import os
import argh
from argh import arg
import json
import numpy as np
import tensorflow as tf
from tensorboard.plugins import projector
from source.process_embeddings import Embeddings, filter_by_vocab
from source.unsupervised_metrics import wn_category
@arg('--tn-label')
def tensorboard_emb(data_dir, model_name, output_path, tn_label='wn_clusters', label_name='clusters'):
    """
    Visualise embeddings using TensorBoard.
    Code from: https://gist.github.com/BrikerMan/7bd4e4bd0a00ac9076986148afc06507

    :param data_dir: str, directory holding the embedding files
    :param model_name: name of numpy array files: embedding (.npy) and vocab (.vocab)
    :param output_path: str, directory where metadata, checkpoint and projector config are written
    :param tn_label: label dictionary file path or options: {"wn_clusters", "None"}
    :param label_name: str, title for the labeling (e.g.: Cluster)
    :raises ValueError: if tn_label is neither a known option nor an existing file path

    Usage on remote server with port forwarding:
        * when you ssh into the machine, you use the option -L to transfer the port 6006 of the remote server
          into the port 16006 of my machine (for instance):
        * ssh -L 16006:127.0.0.1:6006 alv34@yellowhammer
          What it does is that everything on the port 6006 of the server (in 127.0.0.1:6006) will be forwarded
          to my machine on the port 16006.
        * You can then launch tensorboard on the remote machine using a standard tensorboard --logdir log with
          the default 6006 port
        * On your local machine, go to http://127.0.0.1:16006 and enjoy your remote TensorBoard.
    """
    print('Load embedding')
    embs = Embeddings(data_dir, [model_name])
    if tn_label == 'wn_clusters':
        # Label each word by its WordNet category; drop words without one.
        labeler = wn_category
        print('Filter embedding and vocab by existing cluster names')
        filter_vocab = [w for w in embs.vocabs[0] if labeler(w) is not None]
        model, vocab = filter_by_vocab(embs.embeddings[0], embs.vocabs[0], filter_vocab)
        print('#Vocab after filtering:', len(vocab))
    elif tn_label == 'None':
        # No labels: each word is its own label.
        labeler = lambda w: w
        model = embs.embeddings[0]
        vocab = embs.vocabs[0]
    elif os.path.exists(tn_label):
        # Labels come from a JSON file mapping word -> (int-convertible) label.
        with open(tn_label, 'r') as f:
            label_dict = json.load(f)
        labeler = lambda w: int(label_dict[w])
        model = embs.embeddings[0]
        vocab = embs.vocabs[0]
    else:
        # BUG FIX: previously this branch only printed and fell through,
        # crashing below with a NameError on unbound `vocab`/`model`/`labeler`.
        raise ValueError(
            'Add a valid label dictionary file path or choose between {"wn_clusters", "None"}.')

    file_name = "{}_metadata".format(model_name)
    meta_file = "{}.tsv".format(file_name)
    # Copy embeddings into a float64 array (matches the dtype the original
    # np.zeros placeholder produced) for the TF checkpoint variable.
    placeholder = np.asarray(model, dtype=np.float64)
    # Write the TSV metadata file TensorBoard's projector reads: a header
    # line, then one "word<TAB>label" row per vocabulary entry.
    with open(os.path.join(output_path, meta_file), 'w', encoding='utf-8') as file_metadata:
        file_metadata.write("Word\t{}\n".format(label_name))
        for word in vocab:
            # temporary solution for https://github.com/tensorflow/tensorflow/issues/9094
            if word == '':
                print("Empty line: replaced by a placeholder, otherwise it triggers a TensorBoard bug")
                file_metadata.write("{0}\n".format('<Empty Line>'))
            else:
                file_metadata.write("{0}\t{1}\n".format(word, labeler(word)))

    # Save the embedding matrix as a TF2 checkpoint the projector can load.
    weights = tf.Variable(placeholder, trainable=False, name=file_name)
    checkpoint = tf.train.Checkpoint(embedding=weights)
    checkpoint.save(os.path.join(output_path, "embedding.ckpt"))

    # Set up config
    config = projector.ProjectorConfig()
    embedding = config.embeddings.add()
    # The name of the tensor will be suffixed by `/.ATTRIBUTES/VARIABLE_VALUE`
    embedding.tensor_name = "embedding/.ATTRIBUTES/VARIABLE_VALUE"
    embedding.metadata_path = meta_file
    projector.visualize_embeddings(output_path, config)
    print('Run `tensorboard --logdir={0}` to run visualize result on tensorboard'.format(output_path))
# CLI entry point: argh turns tensorboard_emb's signature into command-line
# arguments (e.g. `python this_script.py DATA_DIR MODEL_NAME OUTPUT_PATH --tn-label ...`).
if __name__ == '__main__':
    argh.dispatch_command(tensorboard_emb)