crap
This commit is contained in:
parent
679b1822cd
commit
f3e25e314c
4 changed files with 94 additions and 277 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -2,3 +2,4 @@
|
||||||
__pycache__
|
__pycache__
|
||||||
current_version.txt
|
current_version.txt
|
||||||
MEMORY_LOADED
|
MEMORY_LOADED
|
||||||
|
memory.json
|
18
bot.py
18
bot.py
|
@ -25,9 +25,10 @@ import shutil
|
||||||
print(splashtext) # you can use https://patorjk.com/software/taag/ for 3d text or just remove this entirely
|
print(splashtext) # you can use https://patorjk.com/software/taag/ for 3d text or just remove this entirely
|
||||||
|
|
||||||
def download_json():
|
def download_json():
|
||||||
|
locales_dir = "locales"
|
||||||
response = requests.get(f"{VERSION_URL}/goob/locales/{LOCALE}.json")
|
response = requests.get(f"{VERSION_URL}/goob/locales/{LOCALE}.json")
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
locales_dir = "locales"
|
|
||||||
if not os.path.exists(locales_dir):
|
if not os.path.exists(locales_dir):
|
||||||
os.makedirs(locales_dir)
|
os.makedirs(locales_dir)
|
||||||
file_path = os.path.join(locales_dir, f"{LOCALE}.json")
|
file_path = os.path.join(locales_dir, f"{LOCALE}.json")
|
||||||
|
@ -38,6 +39,7 @@ def download_json():
|
||||||
file.write(response.text)
|
file.write(response.text)
|
||||||
|
|
||||||
if not os.path.exists(os.path.join(locales_dir, "en.json")):
|
if not os.path.exists(os.path.join(locales_dir, "en.json")):
|
||||||
|
|
||||||
response = requests.get(f"{VERSION_URL}/goob/locales/en.json")
|
response = requests.get(f"{VERSION_URL}/goob/locales/en.json")
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
with open(os.path.join(locales_dir, "en.json"), "w", encoding="utf-8") as file:
|
with open(os.path.join(locales_dir, "en.json"), "w", encoding="utf-8") as file:
|
||||||
|
@ -101,7 +103,7 @@ def register_name(NAME):
|
||||||
token = data.get("token")
|
token = data.get("token")
|
||||||
|
|
||||||
if not os.getenv("gooberTOKEN"):
|
if not os.getenv("gooberTOKEN"):
|
||||||
print(f"{GREEN}{get_translation(LOCALE, 'add_token').format(token=token)} gooberTOKEN=<your_token>.{RESET}")
|
print(f"{GREEN}{get_translation(LOCALE, 'add_token').format(token=token)} gooberTOKEN=<token>.{RESET}")
|
||||||
quit()
|
quit()
|
||||||
else:
|
else:
|
||||||
print(f"{GREEN}{RESET}")
|
print(f"{GREEN}{RESET}")
|
||||||
|
@ -210,8 +212,8 @@ def check_for_update():
|
||||||
gooberhash = latest_version_info.get("hash")
|
gooberhash = latest_version_info.get("hash")
|
||||||
if gooberhash == currenthash:
|
if gooberhash == currenthash:
|
||||||
if local_version < latest_version:
|
if local_version < latest_version:
|
||||||
print(f"{YELLOW}{get_translation(LOCALE, 'new_version')}{RESET}")
|
print(f"{YELLOW}{get_translation(LOCALE, 'new_version').format(latest_version=latest_version, local_version=local_version)}{RESET}")
|
||||||
print(f"{YELLOW}{get_translation(LOCALE, 'changelog').format(VERSION_URL=VERSION_URL)}")
|
print(f"{YELLOW}{get_translation(LOCALE, 'changelog').format(VERSION_URL=VERSION_URL)}{RESET}")
|
||||||
elif local_version > latest_version:
|
elif local_version > latest_version:
|
||||||
if IGNOREWARNING == False:
|
if IGNOREWARNING == False:
|
||||||
print(f"\n{RED}{get_translation(LOCALE, 'invalid_version').format(local_version=local_version)}")
|
print(f"\n{RED}{get_translation(LOCALE, 'invalid_version').format(local_version=local_version)}")
|
||||||
|
@ -411,19 +413,19 @@ async def retrain(ctx):
|
||||||
return
|
return
|
||||||
data_size = len(memory)
|
data_size = len(memory)
|
||||||
processed_data = 0
|
processed_data = 0
|
||||||
processing_message_ref = await send_message(ctx, f"{get_translation(LOCALE, 'command_markov_retraining')}")
|
processing_message_ref = await send_message(ctx, f"{get_translation(LOCALE, 'command_markov_retraining').format(processed_data=processed_data, data_size=data_size)}")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
for i, data in enumerate(memory):
|
for i, data in enumerate(memory):
|
||||||
processed_data += 1
|
processed_data += 1
|
||||||
if processed_data % 1000 == 0 or processed_data == data_size:
|
if processed_data % 1000 == 0 or processed_data == data_size:
|
||||||
await send_message(ctx, f"{get_translation(LOCALE, 'command_markov_retraining')}", edit=True, message_reference=processing_message_ref)
|
await send_message(ctx, f"{get_translation(LOCALE, 'command_markov_retraining').format(processed_data=processed_data, data_size=data_size)}", edit=True, message_reference=processing_message_ref)
|
||||||
|
|
||||||
global markov_model
|
global markov_model
|
||||||
|
|
||||||
markov_model = train_markov_model(memory)
|
markov_model = train_markov_model(memory)
|
||||||
save_markov_model(markov_model)
|
save_markov_model(markov_model)
|
||||||
|
|
||||||
await send_message(ctx, f"{get_translation(LOCALE, 'command_markov_retrain_successful')}", edit=True, message_reference=processing_message_ref)
|
await send_message(ctx, f"{get_translation(LOCALE, 'command_markov_retrain_successful').format(data_size=data_size)}", edit=True, message_reference=processing_message_ref)
|
||||||
|
|
||||||
@bot.hybrid_command(description=f"{get_translation(LOCALE, 'command_desc_talk')}")
|
@bot.hybrid_command(description=f"{get_translation(LOCALE, 'command_desc_talk')}")
|
||||||
async def talk(ctx):
|
async def talk(ctx):
|
||||||
|
@ -483,7 +485,7 @@ async def help(ctx):
|
||||||
custom_commands = []
|
custom_commands = []
|
||||||
for cog_name, cog in bot.cogs.items():
|
for cog_name, cog in bot.cogs.items():
|
||||||
for command in cog.get_commands():
|
for command in cog.get_commands():
|
||||||
if command.name not in command_categories[f"{get_translation(LOCALE, 'command_help_categories_general')}"] and command.name not in command_categories["Administration"]:
|
if command.name not in command_categories[f"{get_translation(LOCALE, 'command_help_categories_general')}"] and command.name not in command_categories[f"{get_translation(LOCALE, 'command_help_categories_admin')}"]:
|
||||||
custom_commands.append(command.name)
|
custom_commands.append(command.name)
|
||||||
|
|
||||||
if custom_commands:
|
if custom_commands:
|
||||||
|
|
|
@ -1,64 +1,60 @@
|
||||||
import discord
|
import discord
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
import os
|
import os
|
||||||
from typing import List, TypedDict
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import json
|
import json
|
||||||
from time import strftime, localtime
|
|
||||||
import pickle
|
import pickle
|
||||||
import functools
|
import functools
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
ready: bool = True
|
ready = True
|
||||||
MODEL_MATCH_STRING = "[0-9]{2}_[0-9]{2}_[0-9]{4}-[0-9]{2}_[0-9]{2}"
|
MODEL_MATCH_STRING = r"[0-9]{2}_[0-9]{2}_[0-9]{4}-[0-9]{2}_[0-9]{2}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow import keras
|
from tensorflow import keras
|
||||||
from keras.preprocessing.text import Tokenizer
|
from tensorflow.keras.preprocessing.text import Tokenizer
|
||||||
from keras_preprocessing.sequence import pad_sequences
|
from tensorflow.keras.preprocessing.sequence import pad_sequences
|
||||||
from keras.models import Sequential
|
from tensorflow.keras.models import Sequential, load_model
|
||||||
from keras.layers import Embedding, LSTM, Dense
|
from tensorflow.keras.layers import Embedding, LSTM, Dense
|
||||||
from keras.models import load_model
|
from tensorflow.keras.backend import clear_session
|
||||||
from keras.backend import clear_session
|
|
||||||
tf.config.optimizer.set_jit(True)
|
if tf.config.list_physical_devices('GPU'):
|
||||||
|
print("Using GPU acceleration")
|
||||||
|
elif tf.config.list_physical_devices('Metal'):
|
||||||
|
print("Using Metal for macOS acceleration")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("ERROR: Failed to import Tensorflow. Here is a list of required dependencies:",(
|
print("ERROR: Failed to import TensorFlow. Ensure you have the correct dependencies:")
|
||||||
"tensorflow==2.10.0"
|
print("tensorflow>=2.15.0")
|
||||||
"(for Nvidia users: tensorflow-gpu==2.10.0)"
|
print("For macOS (Apple Silicon): tensorflow-metal")
|
||||||
"(for macOS: tensorflow-metal==0.6.0, tensorflow-macos==2.10.0)"
|
|
||||||
"numpy~=1.23"
|
|
||||||
))
|
|
||||||
ready = False
|
ready = False
|
||||||
|
|
||||||
|
|
||||||
class TFCallback(keras.callbacks.Callback):
|
class TFCallback(keras.callbacks.Callback):
|
||||||
def __init__(self,bot, progress_embed:discord.Embed, message):
|
def __init__(self, bot, progress_embed: discord.Embed, message):
|
||||||
self.embed:discord.Embed = progress_embed
|
self.embed = progress_embed
|
||||||
self.bot:commands.Bot = bot
|
self.bot = bot
|
||||||
self.message = message
|
self.message = message
|
||||||
self.times:List[int] = [time.time()]
|
self.times = [time.time()]
|
||||||
|
|
||||||
def on_train_begin(self, logs=None):
|
async def send_message(self, message: str, description: str, **kwargs):
|
||||||
pass
|
|
||||||
|
|
||||||
async def send_message(self,message:str, description:str, **kwargs):
|
|
||||||
if "epoch" in kwargs:
|
if "epoch" in kwargs:
|
||||||
self.times.append(time.time())
|
self.times.append(time.time())
|
||||||
average_epoch_time:int = np.average(np.diff(np.array(self.times)))
|
avg_epoch_time = np.mean(np.diff(self.times))
|
||||||
description = f"ETA: {round(average_epoch_time)}s"
|
description = f"ETA: {round(avg_epoch_time)}s"
|
||||||
self.embed.add_field(name=f"<t:{round(time.time())}:t> - {message}",value=description,inline=False)
|
self.embed.add_field(name=f"<t:{round(time.time())}:t> - {message}", value=description, inline=False)
|
||||||
await self.message.edit(embed=self.embed)
|
await self.message.edit(embed=self.embed)
|
||||||
|
|
||||||
def on_train_end(self,logs=None):
|
def on_train_end(self, logs=None):
|
||||||
self.bot.loop.create_task(self.send_message("Training stopped", "training has been stopped."))
|
self.bot.loop.create_task(self.send_message("Training stopped", "Training has been stopped."))
|
||||||
|
|
||||||
def on_epoch_begin(self, epoch, logs=None):
|
def on_epoch_begin(self, epoch, logs=None):
|
||||||
self.bot.loop.create_task(self.send_message(f"Starting epoch {epoch}","This might take a while", epoch=True))
|
self.bot.loop.create_task(self.send_message(f"Starting epoch {epoch}", "This might take a while", epoch=True))
|
||||||
|
|
||||||
def on_epoch_end(self, epoch, logs=None):
|
def on_epoch_end(self, epoch, logs=None):
|
||||||
self.bot.loop.create_task(self.send_message(f"Epoch {epoch} ended",f"Accuracy: {round(logs.get('accuracy',0.0),4)}"))
|
self.bot.loop.create_task(self.send_message(f"Epoch {epoch} ended", f"Accuracy: {round(logs.get('accuracy', 0.0), 4)}"))
|
||||||
|
|
||||||
|
|
||||||
class Ai:
|
class Ai:
|
||||||
|
@ -69,272 +65,91 @@ class Ai:
|
||||||
self.is_loaded = model_path is not None
|
self.is_loaded = model_path is not None
|
||||||
self.batch_size = 64
|
self.batch_size = 64
|
||||||
|
|
||||||
def get_model_name_from_path(self,path:str):
|
def generate_model_name(self):
|
||||||
match:re.Match = re.search(MODEL_MATCH_STRING, path)
|
return time.strftime('%d_%m_%Y-%H_%M', time.localtime())
|
||||||
return path[match.start():][:match.end()]
|
|
||||||
|
|
||||||
def generate_model_name(self) -> str:
|
def __load_model(self, model_path):
|
||||||
return strftime('%d_%m_%Y-%H_%M', localtime())
|
|
||||||
|
|
||||||
def generate_model_abs_path(self, name:str):
|
|
||||||
name = name or self.generate_model_name()
|
|
||||||
return os.path.join(".","models",self.generate_model_name(),"model.h5")
|
|
||||||
|
|
||||||
def generate_tokenizer_abs_path(self, name:str):
|
|
||||||
name = name or self.generate_model_name()
|
|
||||||
return os.path.join(".","models",name,"tokenizer.pkl")
|
|
||||||
|
|
||||||
def generate_info_abs_path(self,name:str):
|
|
||||||
name = name or self.generate_model_name()
|
|
||||||
return os.path.join(".","models",name,"info.json")
|
|
||||||
|
|
||||||
|
|
||||||
def save_model(self,model, tokenizer, history, _name:str=None):
|
|
||||||
name:str = _name or self.generate_model_name()
|
|
||||||
os.makedirs(os.path.join(".","models",name), exist_ok=True)
|
|
||||||
|
|
||||||
with open(self.generate_info_abs_path(name),"w") as f:
|
|
||||||
json.dump(history.history,f)
|
|
||||||
|
|
||||||
with open(self.generate_tokenizer_abs_path(name), "wb") as f:
|
|
||||||
pickle.dump(tokenizer,f)
|
|
||||||
|
|
||||||
model.save(self.generate_model_abs_path(name))
|
|
||||||
|
|
||||||
|
|
||||||
def __load_model(self, model_path:str):
|
|
||||||
clear_session()
|
clear_session()
|
||||||
self.model = load_model(os.path.join(model_path,"model.h5"))
|
self.model = load_model(os.path.join(model_path, "model.h5"))
|
||||||
|
model_name = os.path.basename(model_path)
|
||||||
model_name:str = self.get_model_name_from_path(model_path)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(self.generate_tokenizer_abs_path(model_name),"rb") as f:
|
with open(os.path.join(model_path, "tokenizer.pkl"), "rb") as f:
|
||||||
self.tokenizer = pickle.load(f)
|
self.tokenizer = pickle.load(f)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
print("Failed to load tokenizer for model... Using default")
|
print("Failed to load tokenizer, using default.")
|
||||||
self.tokenizer = Tokenizer()
|
self.tokenizer = Tokenizer()
|
||||||
|
with open("memory.json", "r") as f:
|
||||||
with open("memory.json","r") as f:
|
self.tokenizer.fit_on_texts(json.load(f))
|
||||||
self.tokenizer.fit_on_sequences(json.load(f))
|
|
||||||
self.is_loaded = True
|
self.is_loaded = True
|
||||||
|
|
||||||
def reload_model(self):
|
def reload_model(self):
|
||||||
clear_session()
|
clear_session()
|
||||||
model_path:str = settings.get("model_path")
|
model_path = settings.get("model_path")
|
||||||
if model_path:
|
if model_path:
|
||||||
self.model = self.__load_model(model_path)
|
self.__load_model(model_path)
|
||||||
self.is_loaded = True
|
self.is_loaded = True
|
||||||
|
|
||||||
async def run_async(self,func,bot,*args,**kwargs):
|
async def run_async(self, func, bot, *args, **kwargs):
|
||||||
func = functools.partial(func,*args,**kwargs)
|
return await bot.loop.run_in_executor(None, functools.partial(func, *args, **kwargs))
|
||||||
return await bot.loop.run_in_executor(None,func)
|
|
||||||
|
|
||||||
class Learning(Ai):
|
class Learning(Ai):
|
||||||
def __init__(self):
|
def create_model(self, memory, epochs=2):
|
||||||
super().__init__()
|
memory = memory[:2000]
|
||||||
|
|
||||||
def __generate_labels_and_inputs(self,memory: List[str], tokenizer=None) -> tuple:
|
|
||||||
if not tokenizer:
|
|
||||||
tokenizer = Tokenizer()
|
tokenizer = Tokenizer()
|
||||||
tokenizer.fit_on_texts(memory)
|
tokenizer.fit_on_texts(memory)
|
||||||
sequences = tokenizer.texts_to_sequences(memory)
|
sequences = tokenizer.texts_to_sequences(memory)
|
||||||
|
X, y = [], []
|
||||||
x = []
|
|
||||||
y = []
|
|
||||||
for seq in sequences:
|
for seq in sequences:
|
||||||
for i in range(1, len(seq)):
|
for i in range(1, len(seq)):
|
||||||
x.append(seq[:i])
|
X.append(seq[:i])
|
||||||
y.append(seq[i])
|
y.append(seq[i])
|
||||||
|
maxlen = max(map(len, X))
|
||||||
return x,y, tokenizer
|
X = pad_sequences(X, maxlen=maxlen, padding="pre")
|
||||||
|
|
||||||
def create_model(self,memory: list, iters:int=2):
|
|
||||||
memory = memory[:2000]
|
|
||||||
X,y,tokenizer = self.__generate_labels_and_inputs(memory)
|
|
||||||
maxlen:int = max([len(x) for x in X])
|
|
||||||
x_pad = pad_sequences(X, maxlen=maxlen, padding="pre")
|
|
||||||
|
|
||||||
y = np.array(y)
|
y = np.array(y)
|
||||||
|
|
||||||
model = Sequential()
|
model = Sequential([
|
||||||
model.add(Embedding(input_dim=VOCAB_SIZE,output_dim=128,input_length=maxlen))
|
Embedding(input_dim=VOCAB_SIZE, output_dim=128, input_length=maxlen),
|
||||||
model.add(LSTM(64))
|
LSTM(64),
|
||||||
model.add(Dense(VOCAB_SIZE, activation="softmax"))
|
Dense(VOCAB_SIZE, activation="softmax")
|
||||||
|
])
|
||||||
|
|
||||||
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
|
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
|
||||||
history = model.fit(x_pad, y, epochs=iters, batch_size=64, callbacks=[tf_callback])
|
history = model.fit(X, y, epochs=epochs, batch_size=64, callbacks=[tf_callback])
|
||||||
self.save_model(model, tokenizer, history)
|
self.save_model(model, tokenizer, history)
|
||||||
return
|
|
||||||
|
|
||||||
|
def save_model(self, model, tokenizer, history, name=None):
|
||||||
|
name = name or self.generate_model_name()
|
||||||
|
model_dir = os.path.join("models", name)
|
||||||
|
os.makedirs(model_dir, exist_ok=True)
|
||||||
|
|
||||||
def add_training(self,memory: List[str], iters:int=2):
|
with open(os.path.join(model_dir, "info.json"), "w") as f:
|
||||||
tokenizer_path = os.path.join(settings.get("model_path"),"tokenizer.pkl")
|
json.dump(history.history, f)
|
||||||
with open(tokenizer_path, "rb") as f:
|
with open(os.path.join(model_dir, "tokenizer.pkl"), "wb") as f:
|
||||||
tokenizer = pickle.load(f)
|
pickle.dump(tokenizer, f)
|
||||||
|
model.save(os.path.join(model_dir, "model.h5"))
|
||||||
|
|
||||||
X,y,_ = self.__generate_labels_and_inputs(memory, tokenizer)
|
|
||||||
|
|
||||||
maxlen:int = max([len(x) for x in X])
|
|
||||||
x_pad = pad_sequences(X, maxlen=maxlen, padding="pre")
|
|
||||||
y = np.array(y)
|
|
||||||
|
|
||||||
history = self.model.fit(x_pad,y, epochs=iters, validation_data=(x_pad,y), batch_size=64, callbacks=[tf_callback]) # Ideally, validation data would be seperate from the actual data
|
|
||||||
self.save_model(self.model,tokenizer,history,self.get_model_name_from_path(settings.get("model_path")))
|
|
||||||
return
|
|
||||||
|
|
||||||
class Generation(Ai):
|
class Generation(Ai):
|
||||||
def __init__(self):
|
def generate_sentence(self, word_amount, seed):
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def generate_sentence(self, word_amount:int, seed:str):
|
|
||||||
if not self.is_loaded:
|
if not self.is_loaded:
|
||||||
return False
|
return False
|
||||||
for _ in range(word_amount):
|
for _ in range(word_amount):
|
||||||
token_list = self.tokenizer.texts_to_sequences([seed])[0]
|
token_list = self.tokenizer.texts_to_sequences([seed])[0]
|
||||||
token_list = pad_sequences([token_list], maxlen=self.model.layers[0].input_shape[1], padding="pre")
|
token_list = pad_sequences([token_list], maxlen=self.model.input_shape[1], padding="pre")
|
||||||
|
predicted_word_index = np.argmax(self.model.predict(token_list, verbose=0), axis=-1)[0]
|
||||||
output_word = "" # Sometimes model fails to predict the word, so using a fallback
|
output_word = next((w for w, i in self.tokenizer.word_index.items() if i == predicted_word_index), "")
|
||||||
|
|
||||||
predicted_probs = self.model.predict(token_list, verbose=0)
|
|
||||||
predicted_word_index = np.argmax(predicted_probs, axis=-1)[0]
|
|
||||||
|
|
||||||
for word, index in self.tokenizer.word_index.items():
|
|
||||||
if index == predicted_word_index:
|
|
||||||
output_word = word
|
|
||||||
break
|
|
||||||
|
|
||||||
seed += " " + output_word
|
seed += " " + output_word
|
||||||
return seed
|
return seed
|
||||||
|
|
||||||
|
|
||||||
VOCAB_SIZE = 100_000
|
VOCAB_SIZE = 100_000
|
||||||
SETTINGS_TYPE = TypedDict("SETTINGS_TYPE", {
|
settings = {}
|
||||||
"model_path":str, # path to the base folder of the model, aka .../models/05-01-2025-22_31/
|
learning = Learning()
|
||||||
"tokenizer_path":str,
|
generation = Generation()
|
||||||
})
|
|
||||||
|
|
||||||
tf_callback:TFCallback
|
tf_callback = None
|
||||||
model_dropdown_items = []
|
|
||||||
settings: SETTINGS_TYPE = {}
|
|
||||||
|
|
||||||
target_message:int
|
|
||||||
learning:Learning
|
|
||||||
generation: Generation
|
|
||||||
|
|
||||||
class Settings:
|
|
||||||
def __init__(self):
|
|
||||||
self.settings_path:str = os.path.join(".","models","settings.json")
|
|
||||||
|
|
||||||
def load(self):
|
|
||||||
global settings
|
|
||||||
try:
|
|
||||||
with open(self.settings_path,"r") as f:
|
|
||||||
settings = json.load(f)
|
|
||||||
except FileNotFoundError:
|
|
||||||
with open(self.settings_path,"w") as f:
|
|
||||||
json.dump({},f)
|
|
||||||
|
|
||||||
def change_model(self,new_model_base_path:str):
|
|
||||||
global settings
|
|
||||||
new_model_path = os.path.join(".","models",new_model_base_path)
|
|
||||||
|
|
||||||
with open(self.settings_path,"r") as f:
|
|
||||||
settings = json.load(f)
|
|
||||||
|
|
||||||
settings["model_path"] = new_model_path
|
|
||||||
|
|
||||||
with open(self.settings_path, "w") as f:
|
|
||||||
json.dump(settings,f)
|
|
||||||
|
|
||||||
|
|
||||||
class Dropdown(discord.ui.Select):
|
|
||||||
def __init__(self, items:List[str]):
|
|
||||||
global model_dropdown_items
|
|
||||||
model_dropdown_items = []
|
|
||||||
|
|
||||||
for item in items:
|
|
||||||
model_dropdown_items.append(
|
|
||||||
discord.SelectOption(label=item)
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__(placeholder="Select model", options=model_dropdown_items)
|
|
||||||
|
|
||||||
async def callback(self, interaction: discord.Interaction):
|
|
||||||
if int(interaction.user.id) != int(os.getenv("ownerid")):
|
|
||||||
await interaction.message.channel.send("KILL YOURSELF")
|
|
||||||
Settings().change_model(self.values[0])
|
|
||||||
await interaction.message.channel.send(f"Changed model to {self.values[0]}")
|
|
||||||
|
|
||||||
class DropdownView(discord.ui.View):
|
|
||||||
def __init__(self, timeout, models):
|
|
||||||
super().__init__(timeout=timeout)
|
|
||||||
self.add_item(Dropdown(models))
|
|
||||||
|
|
||||||
|
|
||||||
class Tf(commands.Cog):
|
|
||||||
def __init__(self,bot):
|
|
||||||
global learning, generation, ready
|
|
||||||
os.makedirs(os.path.join(".","models"),exist_ok=True)
|
|
||||||
Settings().load()
|
|
||||||
self.bot = bot
|
|
||||||
learning = Learning()
|
|
||||||
generation = Generation()
|
|
||||||
|
|
||||||
@commands.command()
|
|
||||||
async def start(self,ctx):
|
|
||||||
await ctx.defer()
|
|
||||||
await ctx.send("hi")
|
|
||||||
|
|
||||||
@commands.command()
|
|
||||||
async def generate(self,ctx,seed:str,word_amount:int=5):
|
|
||||||
await ctx.defer()
|
|
||||||
await ctx.send(generation.generate_sentence(word_amount,seed))
|
|
||||||
|
|
||||||
@commands.command()
|
|
||||||
async def create(self,ctx:commands.Context, epochs:int=3):
|
|
||||||
global tf_callback
|
|
||||||
await ctx.defer()
|
|
||||||
with open("memory.json","r") as f:
|
|
||||||
memory:List[str] = json.load(f)
|
|
||||||
await ctx.send("Initializing tensorflow")
|
|
||||||
embed = discord.Embed(title="Creating a model...", description="Progress of creating a model")
|
|
||||||
embed.set_footer(text="Note: Progress tracking might report delayed / wrong data, since the function is run asynchronously")
|
|
||||||
target_message:discord.Message = await ctx.send(embed=embed)
|
|
||||||
|
|
||||||
tf_callback = TFCallback(self.bot,embed,target_message)
|
|
||||||
await learning.run_async(learning.create_model,self.bot,memory,epochs)
|
|
||||||
embed = target_message.embeds[0]
|
|
||||||
embed.add_field(name=f"<t:{round(time.time())}:t> Finished",value="Model saved.")
|
|
||||||
await target_message.edit(embed=embed)
|
|
||||||
|
|
||||||
|
|
||||||
@commands.command()
|
|
||||||
async def train(self,ctx, epochs:int=2):
|
|
||||||
global tf_callback
|
|
||||||
|
|
||||||
await ctx.defer()
|
|
||||||
with open("memory.json","r") as f:
|
|
||||||
memory:List[str] = json.load(f)
|
|
||||||
|
|
||||||
embed = discord.Embed(title="Training model...", description="Progress of training model")
|
|
||||||
target_message = await ctx.send(embed=embed)
|
|
||||||
tf_callback = TFCallback(self.bot,embed,target_message)
|
|
||||||
|
|
||||||
await learning.run_async(learning.add_training,self.bot,memory,epochs)
|
|
||||||
await ctx.send("Finished!")
|
|
||||||
|
|
||||||
@commands.command()
|
|
||||||
async def change(self,ctx,model:str=None):
|
|
||||||
embed = discord.Embed(title="Change model",description="Which model would you like to use?")
|
|
||||||
if model is None:
|
|
||||||
models:List[str] = os.listdir(os.path.join(".","models"))
|
|
||||||
models = [folder for folder in models if re.match(MODEL_MATCH_STRING,folder)]
|
|
||||||
if len(models) == 0:
|
|
||||||
models = ["No models available."]
|
|
||||||
await ctx.send(embed=embed,view=DropdownView(90,models))
|
|
||||||
learning.reload_model()
|
|
||||||
generation.reload_model()
|
|
||||||
|
|
||||||
async def setup(bot):
|
async def setup(bot):
|
||||||
await bot.add_cog(Tf(bot))
|
await bot.add_cog(Tf(bot))
|
|
@ -35,4 +35,3 @@ GREEN = "\033[32m"
|
||||||
YELLOW = "\033[33m"
|
YELLOW = "\033[33m"
|
||||||
DEBUG = "\033[1;30m"
|
DEBUG = "\033[1;30m"
|
||||||
RESET = "\033[0m"
|
RESET = "\033[0m"
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue