# Source code for source.embedding

import os
from itertools import tee
import logging
import json
from gensim.models import Word2Vec
from gensim.models.callbacks import CallbackAny2Vec
from matplotlib import pyplot as plt

logger = logging.getLogger(__name__)


class LossLogger(CallbackAny2Vec):
    """Callback to print and record the training loss per batch and per epoch.

    The value returned by ``model.get_latest_training_loss()`` is treated as a
    running total, so every hook converts it into a delta against an earlier
    reading. Batch and epoch deltas use separate baselines so that the
    per-batch bookkeeping does not distort the per-epoch figure.
    """

    def __init__(self, show=False):
        """
        :param show: If True, show loss curve in the end.
        """
        self.epoch = 0
        # Batch counter within the current epoch; reset in on_epoch_begin.
        self.batch = 0
        # Running total seen at the most recent batch end.
        self.prev_cum_loss = 0
        # Running total seen at the most recent epoch end.
        self.prev_epoch_cum_loss = 0
        # One entry per batch; never reset, so it spans all epochs.
        self.batch_losses = []
        # One entry per epoch.
        self.epoch_losses = []
        self.show = show

    def on_epoch_begin(self, model):
        """Restart the per-epoch batch counter."""
        self.batch = 0

    def on_epoch_end(self, model):
        """Print and store the loss accumulated during the epoch that ended."""
        cum_loss = model.get_latest_training_loss()
        # Measure against the previous epoch end, not the previous batch end:
        # the batch baseline moves after every batch and would leave only the
        # tail of the epoch in this difference.
        eloss = cum_loss - self.prev_epoch_cum_loss
        self.prev_epoch_cum_loss = cum_loss
        print("Epoch #{} end, loss: {}".format(self.epoch, eloss))
        self.epoch += 1
        self.epoch_losses.append(eloss)

    def on_batch_end(self, model):
        """Print and store the loss added since the previous batch ended."""
        cum_loss = model.get_latest_training_loss()
        # NOTE(review): abs() hides negative deltas; presumably these show up
        # when calls from several workers interleave -- confirm against
        # gensim's threading model.
        loss = abs(cum_loss - self.prev_cum_loss)
        print(f"Epoch {self.epoch} - Batch {self.batch} end loss: {loss}")
        self.prev_cum_loss = cum_loss
        self.batch_losses.append(loss)
        self.batch += 1

    def on_train_end(self, model):
        """Plot the per-batch losses and optionally display the figure."""
        plt.plot(self.batch_losses)
        if self.show:
            plt.show()
def train(corpus, save_path, load_path=None, size=300, window=5, min_count=10,
          workers=4, epochs=5, max_vocab_size=None, show_loss=False,
          save_loss=False):
    """
    Train w2v.

    A fresh model is created when ``save_path`` does not exist and no
    ``load_path`` is given. Otherwise an existing model is loaded (from
    ``load_path``, falling back to ``save_path``), its vocabulary is extended
    with ``corpus`` and training continues on it. The model is always saved
    to ``save_path`` afterwards.

    :param corpus: list of list strings. Any iterable of token lists works;
        a one-shot iterator (e.g. a generator) is read into a list first
        because the corpus has to be traversed several times.
    :param save_path: Model file path
    :param load_path: Optional path of an existing model to continue from.
    :param size: Passed to ``Word2Vec`` as ``size`` (new models only).
    :param window: Passed to ``Word2Vec`` as ``window`` (new models only).
    :param min_count: Passed to ``Word2Vec`` as ``min_count`` (new models
        only).
    :param workers: Passed to ``Word2Vec`` as ``workers`` (new models only).
    :param epochs: Number of training epochs for a new model. A loaded model
        is trained for its own stored ``iter`` value instead.
    :param max_vocab_size: Passed to ``Word2Vec`` as ``max_vocab_size`` (new
        models only).
    :param show_loss: If True, display the loss curve when training ends.
    :param save_loss: If True, write ``<save_path>_losscurve.png`` and
        ``<save_path>_losscurve.json`` next to the model.
    :return: trained model
    """
    # The corpus is traversed once for the vocabulary scan and once per
    # training epoch, so it must be re-iterable. An object that is its own
    # iterator can only be consumed once; materialise it so later passes do
    # not silently see an empty corpus.
    sentences = list(corpus) if iter(corpus) is corpus else corpus

    loss_logger = LossLogger(show_loss)
    # TODO: loss curve looks weird with multiple workers
    if not os.path.exists(save_path) and load_path is None:
        # Create the model without sentences: passing them to the constructor
        # would trigger an extra, unlogged training run before the one below.
        model = Word2Vec(size=size, window=window, min_count=min_count,
                         workers=workers, max_vocab_size=max_vocab_size,
                         compute_loss=False, hs=0, sg=1, iter=epochs)
        model.build_vocab(sentences)
    else:
        if load_path is None:
            load_path = save_path
        print(f'Loading model {load_path}')
        model = Word2Vec.load(load_path)
        model.build_vocab(sentences, update=True)
        logger.debug('Updates vocab, new size: {}'.format(len(model.wv.vocab)))

    model.train(sentences, total_examples=model.corpus_count,
                epochs=model.iter, callbacks=[loss_logger],
                compute_loss=True)

    print('Saving model')
    model.save(save_path)
    if save_loss:
        # Saves the current pyplot figure, which the loss callback drew on
        # when training ended.
        plt.savefig(save_path + '_losscurve.png')
        with open(save_path + '_losscurve.json', 'w') as f:
            json.dump(loss_logger.batch_losses, f)
    return model