Source code for source.cnn_embeddings

import sys, os
sys.path.append("../../img2vec_privatefork/img2vec_pytorch")  # Adds higher directory to python modules path.
sys.path.append("../img2vec_privatefork/img2vec_pytorch")  # Adds higher directory to python modules path.
from img_to_vec import Img2Vec
from PIL import Image
import json
import pickle
from tqdm import tqdm
import argh
from argh import arg
from collections import defaultdict
from glob import glob
import warnings

from process_embeddings import serialize2npy


cnn_models = ['alexnet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
           'wide_resnet50_2', 'wide_resnet101_2']

[docs]@arg('-cnn', '--cnn_model', choices=cnn_models, default='resnet18') def get_cnn(image_dir, word_index_file, savedir=None, cnn_model='resnet18', agg_maxnum=10, gpu=False, filename_prefix=''): """Extract CNN representations for images in a directory and saves it into a dictionary file.""" img2vec = Img2Vec(model=cnn_model, cuda=gpu) # Dictionary of {words: {img_name: img_representation}} word_img_repr = defaultdict(dict) with open(word_index_file) as f: word_imgs = json.load(f) for word, img_names in tqdm(word_imgs.items()): for imgn in img_names: try: img = Image.open(os.path.join(image_dir, imgn)).convert('RGB') word_img_repr[word][imgn] = img2vec.get_vec(img) except FileNotFoundError: warnings.warn(f'Image {imgn} for word "{word}" is missing.') # Save representations if savedir is None: savedir = image_dir repr_path = os.path.join(savedir, '_'.join([filename_prefix, cnn_model + '.pkl'])) with open(repr_path, 'wb') as f: pickle.dump(word_img_repr, f) # Save aggregated embeddings serialize2npy(repr_path, savedir, agg_maxnum)
[docs]def create_index_from_fnames(image_dir, savepath): files = glob(image_dir + '/*.jpg') word_index_file = defaultdict(list) for f in files: fname = os.path.basename(f) word = fname.split('_')[0] word_index_file[word].append(fname) with open(savepath, 'w') as f: json.dump(word_index_file, f)
if __name__ == '__main__': argh.dispatch_command(get_cnn)