import cPickle as pickle
import msgpack
import numpy as np
# Load vocabulary w/ word frequencies
with open('wmt11.head.vocab', 'rb') as f:
    vocab = msgpack.load(f)

# Load requisite vector data
# NOTE(review): unpickling can execute arbitrary code; only load vector
# files that come from a trusted source.
with open('wmt11.head.vectors', 'rb') as f:
    W = pickle.load(f)

# Reverse lookup from row id to word. Each vocab value is unpacked as a
# pair whose first element is the row id (the second is ignored here).
id2word = dict((word_id, word) for word, (word_id, _) in vocab.iteritems())

# Remove context word vectors first, so that only the rows that are
# actually kept get normalized.
W = W[:len(vocab), :]

# Normalize word vectors to unit length in one vectorized step; with unit
# rows, a dot product against a unit vector is the cosine similarity.
W /= np.linalg.norm(W, axis=1)[:, np.newaxis]
def most_similar(positive, negative, topn=10, freq_threshold=5):
    """Return the words whose vectors best match a combination of query words.

    The vectors of the `positive` words and the negated vectors of the
    `negative` words are averaged, the average is scaled to unit length, and
    every row of the module-level matrix `W` is scored by its dot product
    with that vector.

    Args:
        positive: words whose vectors are added to the query.
        negative: words whose vectors are subtracted from the query.
        topn: maximum number of results to return.
        freq_threshold: minimum value of the second field of a word's
            `vocab` entry (presumably its corpus frequency) for the word to
            be eligible as a result.

    Returns:
        A list of at most `topn` `(word, score)` pairs, best score first.
        Query words themselves are never returned.

    Raises:
        ValueError: if both `positive` and `negative` are empty.
        KeyError: if a query word is not in `vocab`.
    """
    if not positive and not negative:
        raise ValueError("at least one positive or negative word is required")

    # Build a "mean" vector for the given positive and negative terms
    mean_vecs = [W[vocab[word][0]] for word in positive]
    mean_vecs.extend(-1 * W[vocab[word][0]] for word in negative)
    mean = np.array(mean_vecs).mean(axis=0)
    mean /= np.linalg.norm(mean)

    # Score every word against the mean vector (cosine similarity when the
    # rows of W have unit length).
    dists = np.dot(W, mean)

    excluded = set(positive) | set(negative)
    result = []
    # Walk the whole ranking, best first, and stop once enough words have
    # been kept; a fixed-size candidate pool could come up short after
    # filtering.
    for i in np.argsort(dists)[::-1]:
        if len(result) >= topn:
            break
        word = id2word[i]
        if word in excluded:
            continue
        # Compare the second field of the vocab entry, not the whole entry,
        # against the threshold.
        if vocab[word][1] < freq_threshold:
            continue
        result.append((word, dists[i]))
    return result
# Example analogy query: "king" + "woman" - "man", asking for the 50 best
# matches. The list literal on the following line appears to be the output
# recorded from an interactive run of this call.
most_similar(['king', 'woman'], ['man'], topn=50)
[('queen', 0.69478105985944205), ('ace', 0.63016505472136219), ('trick', 0.62198680411172658), ('library', 0.61596180822198343), ('diamond', 0.61546379436428578), ('club', 0.60882108049620698), ('horse', 0.60577931043391597), ('ski', 0.5980682567370863), ('tennis', 0.59252997663757134), ('chef', 0.58578732345724127), ('museum', 0.58238877554666368), ('grandmother', 0.58148552464037506), ('diamonds', 0.58011253208849856), ('crown', 0.57997983286899146), ('seller', 0.57651369635738636), ('tip', 0.57446540473288965), ('oldest', 0.56849683935598727), ('holder', 0.56698181344304011), ('row', 0.56597681090513596), ('Museum', 0.56365025428845172), ('royal', 0.56291276425071346), ('Royal', 0.56191759370337424), ('farmer', 0.55962264699238262), ('Queen', 0.55947426321308247), ('colony', 0.55792198467607856), ('Maine', 0.55782081129452066), ('hat', 0.55772124209691432), ('dog', 0.5566658071093793), ('Valley', 0.55537812887550264), ('soccer', 0.55403076872031942), ('cinema', 0.55362014730217401), ('Latvia', 0.55191882205612497), ('hero', 0.55175383201232631), ('dancer', 0.55130889459560439), ('spade', 0.55042516938080366), ('Country', 0.54924256358808599), ('Yale', 0.54889249198494516), ('Rock', 0.54845215690150428), ('girlfriend', 0.54695823638084806), ('pool', 0.54691472405799435), ('neighbor', 0.54683901446532124), ('bars', 0.54670398577022916), ('bottle', 0.54548588559461253), ('pope', 0.54363205675334925), ('boyfriend', 0.54230260301709221), ('classic', 0.54154692864168852), ('interior', 0.54121481559609896), ('Buffalo', 0.54109694727311264), ('buyer', 0.5408952311440024), ('sheriff', 0.54027326795728747)]
# Example analogy query: "brought" + "seek" - "bring", asking for the 50
# best matches. The list literal on the following line appears to be the
# output recorded from an interactive run of this call.
most_similar(['brought', 'seek'], ['bring'], topn=50)
[('sought', 0.80168320406931981), ('seeking', 0.73662888334926047), ('forced', 0.69273739435205384), ('attempted', 0.68510971171255386), ('tried', 0.67516210714164604), ('allowed', 0.65480577594618783), ('urged', 0.64988576500767947), ('managed', 0.64642237086872134), ('seeks', 0.64400589953863585), ('refused', 0.63527139994723147), ('intended', 0.63348152287487691), ('unable', 0.62647201702998501), ('demanded', 0.62625912225250269), ('prompted', 0.62515185408955964), ('threatened', 0.62393983356386451), ('determined', 0.62224734077632959), ('attempting', 0.6181202137691465), ('hoped', 0.61761471480150942), ('prepared', 0.61306593357078376), ('encouraged', 0.61228972301898099), ('requested', 0.60900154224998204), ('followed', 0.60838821784825281), ('helped', 0.60657759212041662), ('attempt', 0.60619642784126782), ('failed', 0.6045095128678335), ('led', 0.60300298435099864), ('opted', 0.59973601096877438), ('granted', 0.59786114263781998), ('initiated', 0.59435889304397749), ('chosen', 0.59148716010685476), ('faced', 0.58759291122220392), ('wanted', 0.58484842439765283), ('refusing', 0.58473430252602809), ('addressed', 0.58405417697747952), ('offered', 0.58351159606145453), ('asking', 0.58349246210416927), ('rejected', 0.58123114186229263), ('decided', 0.58080346748275247), ('pledged', 0.57872006649827579), ('pressed', 0.57805915462397062), ('ordered', 0.57777715792227946), ('received', 0.57658184971738702), ('designed', 0.57650776001629578), ('persuaded', 0.57468989479783739), ('urging', 0.57295698093987468), ('accepted', 0.57247322897335085), ('allowing', 0.57049684140481094), ('able', 0.56909602033660478), ('calling', 0.56550493868341545), ('required', 0.565493188089782)]