Fixed issues with creating a model, and added progress indicators
This commit is contained in:
parent
33e0f5482f
commit
a269979b29
1 changed files with 70 additions and 29 deletions
|
@ -6,7 +6,10 @@ import numpy as np
|
||||||
import json
|
import json
|
||||||
from time import strftime, localtime
|
from time import strftime, localtime
|
||||||
import pickle
|
import pickle
|
||||||
|
import functools
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
|
||||||
ready: bool = True
|
ready: bool = True
|
||||||
MODEL_MATCH_STRING = "[0-9]{2}_[0-9]{2}_[0-9]{4}-[0-9]{2}_[0-9]{2}"
|
MODEL_MATCH_STRING = "[0-9]{2}_[0-9]{2}_[0-9]{4}-[0-9]{2}_[0-9]{2}"
|
||||||
|
@ -30,6 +33,34 @@ except ImportError:
|
||||||
))
|
))
|
||||||
ready = False
|
ready = False
|
||||||
|
|
||||||
|
class TFCallback(keras.callbacks.Callback):
    """Keras callback that mirrors training progress into a Discord embed.

    Each hook appends a field to ``progress_embed`` and edits ``message`` so
    the embed shown in Discord reflects the latest training state.

    The hooks are plain synchronous methods invoked by ``model.fit``. In this
    file ``fit`` is started through ``run_in_executor``, i.e. on a worker
    thread, so coroutines must be handed to the bot's event loop with a
    thread-safe call rather than ``loop.create_task``.
    """

    def __init__(self, bot, progress_embed: discord.Embed, message):
        """Store the bot, the embed to extend and the message to edit.

        Args:
            bot: Bot whose ``loop`` runs the Discord coroutines.
            progress_embed: Embed that receives one field per progress event.
            message: Message object whose ``edit`` is awaited with the embed.
        """
        super().__init__()
        self.embed: discord.Embed = progress_embed
        self.bot: commands.Bot = bot
        self.message = message
        # Wall-clock timestamps (seconds); the first entry is creation time.
        self.times: List[float] = [time.time()]

    def _schedule(self, coro) -> None:
        """Submit ``coro`` to the bot's event loop from any thread."""
        future = asyncio.run_coroutine_threadsafe(coro, self.bot.loop)
        # Surface failures instead of leaving them unnoticed on the future.
        future.add_done_callback(self._report_failure)

    @staticmethod
    def _report_failure(future) -> None:
        """Print the exception of a finished progress update, if any."""
        if future.cancelled():
            return
        error = future.exception()
        if error is not None:
            print(f"Progress update failed: {error!r}")

    def on_train_begin(self, logs=None):
        """No-op; nothing is reported when training begins."""
        pass

    async def send_message(self, message: str, description: str, **kwargs):
        """Append a timestamped field to the embed and push it to Discord.

        Args:
            message: Text placed after the timestamp in the field name.
            description: Field value. Replaced by a timing string when the
                ``epoch`` keyword is present.
            **kwargs: Passing ``epoch`` (any value) records a timestamp and
                reports the mean interval between recorded timestamps.
        """
        if "epoch" in kwargs:
            self.times.append(time.time())
            # NOTE(review): this is the mean time between recorded timestamps,
            # not the remaining time, although the label says "ETA".
            average_epoch_time: float = np.average(np.diff(np.array(self.times)))
            description = f"ETA: {round(average_epoch_time)}s"
        self.embed.add_field(
            name=f"<t:{round(time.time())}:t> - {message}",
            value=description,
            inline=False,
        )
        await self.message.edit(embed=self.embed)

    def on_train_end(self, logs=None):
        """Report that training has stopped."""
        self._schedule(self.send_message("Training stopped", "training has been stopped."))

    def on_epoch_begin(self, epoch, logs=None):
        """Report the start of ``epoch`` together with the timing estimate."""
        self._schedule(
            self.send_message(f"Starting epoch {epoch}", "This might take a while", epoch=True)
        )

    def on_epoch_end(self, epoch, logs=None):
        """Report the end of ``epoch`` and the accuracy found in ``logs``."""
        # ``logs`` defaults to None; fall back to an empty mapping.
        accuracy = (logs or {}).get("accuracy", 0.0)
        self._schedule(
            self.send_message(f"Epoch {epoch} ended", f"Accuracy: {round(accuracy, 4)}")
        )
|
||||||
|
|
||||||
|
|
||||||
class Ai:
|
class Ai:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
model_path = settings.get("model_path")
|
model_path = settings.get("model_path")
|
||||||
|
@ -39,10 +70,7 @@ class Ai:
|
||||||
self.batch_size = 64
|
self.batch_size = 64
|
||||||
|
|
||||||
def get_model_name_from_path(self, path: str) -> str:
    """Extract the model name from a model path.

    The name is the first substring of ``path`` that matches
    ``MODEL_MATCH_STRING``.

    Args:
        path: Path expected to contain a model name.

    Returns:
        Exactly the matched substring, without any surrounding path parts.

    Raises:
        ValueError: If ``path`` contains no substring matching the pattern.
    """
    match = re.search(MODEL_MATCH_STRING, path)
    if match is None:
        raise ValueError(f"No model name found in path: {path!r}")
    # group(0) is the matched text itself. Slicing with match.end() after
    # already slicing from match.start() would keep trailing characters.
    return match.group(0)
|
||||||
|
|
||||||
def generate_model_name(self) -> str:
|
def generate_model_name(self) -> str:
|
||||||
|
@ -96,7 +124,11 @@ class Ai:
|
||||||
model_path:str = settings.get("model_path")
|
model_path:str = settings.get("model_path")
|
||||||
if model_path:
|
if model_path:
|
||||||
self.model = self.__load_model(model_path)
|
self.model = self.__load_model(model_path)
|
||||||
|
self.is_loaded = True
|
||||||
|
|
||||||
|
async def run_async(self, func, bot, *args, **kwargs):
    """Run ``func(*args, **kwargs)`` in the default executor of ``bot.loop``.

    Args:
        func: Callable to execute off the event loop.
        bot: Object exposing the event loop as ``bot.loop``.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Whatever ``func`` returns.
    """
    # run_in_executor only forwards positional arguments, so bind everything
    # into a zero-argument callable first.
    bound_call = functools.partial(func, *args, **kwargs)
    event_loop = bot.loop
    return await event_loop.run_in_executor(None, bound_call)
|
||||||
|
|
||||||
class Learning(Ai):
|
class Learning(Ai):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -117,7 +149,8 @@ class Learning(Ai):
|
||||||
|
|
||||||
return x,y, tokenizer
|
return x,y, tokenizer
|
||||||
|
|
||||||
def create_model(self,memory: List[str], iters:int=2):
|
def create_model(self,memory: list, iters:int=2):
|
||||||
|
memory = memory[:2000]
|
||||||
X,y,tokenizer = self.__generate_labels_and_inputs(memory)
|
X,y,tokenizer = self.__generate_labels_and_inputs(memory)
|
||||||
maxlen:int = max([len(x) for x in X])
|
maxlen:int = max([len(x) for x in X])
|
||||||
x_pad = pad_sequences(X, maxlen=maxlen, padding="pre")
|
x_pad = pad_sequences(X, maxlen=maxlen, padding="pre")
|
||||||
|
@ -130,8 +163,10 @@ class Learning(Ai):
|
||||||
model.add(Dense(VOCAB_SIZE, activation="softmax"))
|
model.add(Dense(VOCAB_SIZE, activation="softmax"))
|
||||||
|
|
||||||
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
|
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
|
||||||
history = model.fit(x_pad, y, epochs=iters, batch_size=32)
|
history = model.fit(x_pad, y, epochs=iters, batch_size=64, callbacks=[tf_callback])
|
||||||
self.save_model(model, tokenizer, history)
|
self.save_model(model, tokenizer, history)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
def add_training(self,memory: List[str], iters:int=2):
|
def add_training(self,memory: List[str], iters:int=2):
|
||||||
tokenizer_path = os.path.join(settings.get("model_path"),"tokenizer.pkl")
|
tokenizer_path = os.path.join(settings.get("model_path"),"tokenizer.pkl")
|
||||||
|
@ -144,8 +179,9 @@ class Learning(Ai):
|
||||||
x_pad = pad_sequences(X, maxlen=maxlen, padding="pre")
|
x_pad = pad_sequences(X, maxlen=maxlen, padding="pre")
|
||||||
y = np.array(y)
|
y = np.array(y)
|
||||||
|
|
||||||
history = self.model.fit(x_pad,y, epochs=iters, validation_data=(x_pad,y), batch_size=64) # Idelaly, validation data would be seperate from the actual data
|
history = self.model.fit(x_pad,y, epochs=iters, validation_data=(x_pad,y), batch_size=64, callbacks=[tf_callback]) # Ideally, validation data would be separate from the actual data
|
||||||
self.save_model(self.model,tokenizer,history,self.get_model_name_from_path(settings.get("model_path")))
|
self.save_model(self.model,tokenizer,history,self.get_model_name_from_path(settings.get("model_path")))
|
||||||
|
return
|
||||||
|
|
||||||
class Generation(Ai):
|
class Generation(Ai):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -173,16 +209,16 @@ class Generation(Ai):
|
||||||
|
|
||||||
|
|
||||||
VOCAB_SIZE = 100_000
|
VOCAB_SIZE = 100_000
|
||||||
|
|
||||||
SETTINGS_TYPE = TypedDict("SETTINGS_TYPE", {
|
SETTINGS_TYPE = TypedDict("SETTINGS_TYPE", {
|
||||||
"model_path":str, # path to the base folder of the model, aka .../models/05-01-2025-22_31/
|
"model_path":str, # path to the base folder of the model, aka .../models/05-01-2025-22_31/
|
||||||
"tokenizer_path":str,
|
"tokenizer_path":str,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
tf_callback:TFCallback
|
||||||
model_dropdown_items = []
|
model_dropdown_items = []
|
||||||
settings: SETTINGS_TYPE = {}
|
settings: SETTINGS_TYPE = {}
|
||||||
|
|
||||||
|
target_message:int
|
||||||
learning:Learning
|
learning:Learning
|
||||||
generation: Generation
|
generation: Generation
|
||||||
|
|
||||||
|
@ -237,26 +273,14 @@ class DropdownView(discord.ui.View):
|
||||||
|
|
||||||
|
|
||||||
class Tf(commands.Cog):
|
class Tf(commands.Cog):
|
||||||
@staticmethod
|
|
||||||
def needs_ready(func):
|
|
||||||
def inner(args:tuple, kwargs:dict):
|
|
||||||
if not ready:
|
|
||||||
raise AttributeError("Not ready!")
|
|
||||||
a = func(*args, **kwargs)
|
|
||||||
return a
|
|
||||||
return inner
|
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, bot):
    """Prepare the models directory, load settings and build the AI helpers.

    Args:
        bot: Bot instance, stored as ``self.bot``.
    """
    global learning, generation, ready

    models_dir = os.path.join(".", "models")
    os.makedirs(models_dir, exist_ok=True)

    # Settings are loaded before the helpers are constructed.
    Settings().load()
    self.bot = bot

    learning = Learning()
    generation = Generation()
|
||||||
|
|
||||||
|
|
||||||
@commands.command()
|
@commands.command()
|
||||||
async def start(self,ctx):
|
async def start(self,ctx):
|
||||||
await ctx.defer()
|
await ctx.defer()
|
||||||
|
@ -268,19 +292,36 @@ class Tf(commands.Cog):
|
||||||
await ctx.send(generation.generate_sentence(word_amount,seed))
|
await ctx.send(generation.generate_sentence(word_amount,seed))
|
||||||
|
|
||||||
@commands.command()
async def create(self, ctx: commands.Context, epochs: int = 3):
    """Create a new model from ``memory.json`` and report progress in an embed.

    Args:
        ctx: Invocation context used to send the progress messages.
        epochs: Number of epochs passed on to ``learning.create_model``.
    """
    global tf_callback
    await ctx.defer()
    with open("memory.json", "r") as f:
        memory: List[str] = json.load(f)
    await ctx.send("Initializing tensorflow")
    embed = discord.Embed(title="Creating a model...", description="Progress of creating a model")
    embed.set_footer(text="Note: Progress tracking might report delayed / wrong data, since the function is run asynchronously")
    target_message: discord.Message = await ctx.send(embed=embed)

    # The callback stores this same embed object and appends fields to it.
    tf_callback = TFCallback(self.bot, embed, target_message)
    await learning.run_async(learning.create_model, self.bot, memory, epochs)

    # Reuse the shared embed instead of target_message.embeds[0]: the message
    # object still holds the embed as it was when sent, so editing with that
    # copy would drop every progress field added during training.
    embed.add_field(name=f"<t:{round(time.time())}:t> Finished", value="Model saved.")
    await target_message.edit(embed=embed)
|
||||||
|
|
||||||
|
|
||||||
@commands.command()
async def train(self, ctx, epochs: int = 2):
    """Continue training the current model on ``memory.json``.

    Args:
        ctx: Invocation context used to send the progress messages.
        epochs: Number of epochs passed on to ``learning.add_training``.
    """
    global tf_callback

    await ctx.defer()

    with open("memory.json", "r") as memory_file:
        memory: List[str] = json.load(memory_file)

    progress_embed = discord.Embed(
        title="Training model...",
        description="Progress of training model",
    )
    progress_message = await ctx.send(embed=progress_embed)

    # Stored in the module-level global that the fit() call reads.
    tf_callback = TFCallback(self.bot, progress_embed, progress_message)

    await learning.run_async(learning.add_training, self.bot, memory, epochs)
    await ctx.send("Finished!")
|
||||||
|
|
||||||
@commands.command()
|
@commands.command()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue