# Utils
from scipy.stats import bernoulli
from scipy.stats import beta
import matplotlib.pyplot as plt
import numpy as np
from numpy.random import multinomial
from numpy.random import dirichlet
%matplotlib inline
# Vocabulary size: one word per pixel of a square image.
W = 25
# Side length of the image. Converted to int because it is used below as an
# array dimension and as an index bound; np.sqrt alone returns a float, which
# NumPy does not accept as a shape.
L = int(np.sqrt(W))
# Number of topics.
T = 10
# Number of documents.
D = 100
# Number of words per document.
N = 15
# phi[t] is the word distribution of topic t, laid out as an L x L image.
# The first L topics are uniform over one row each (horizontal bars); the
# following topics are uniform over one column each (vertical bars).
phi = [np.zeros((L, L)) for _ in range(T)]
for line, phi_t in enumerate(phi):
    if line >= L:
        # Past the row topics: fill column (line - L) instead.
        phi_t[:, line - L] = 1.0 / L
    else:
        phi_t[line] = 1.0 / L
# Draw a one-row strip of axes: a text label followed by one image per topic.
f, axs = plt.subplots(1, T + 1, figsize=(15, 1))

# The first axes only carries the caption, so hide its frame and both axes.
label_ax, topic_axes = axs[0], axs[1:]
label_ax.text(0, 0.4, "Topics: ", fontsize=16)
label_ax.axis("off")
for hidden_axis in (label_ax.get_yaxis(), label_ax.get_xaxis()):
    hidden_axis.set_visible(False)

# One greyscale image per topic, without tick marks.
for topic_ax, topic_img in zip(topic_axes, phi):
    topic_ax.get_yaxis().set_visible(False)
    topic_ax.get_xaxis().set_visible(False)
    topic_ax.imshow(topic_img, interpolation='none', cmap='Greys_r')

f.savefig('phi_train.png', format='png')
# Per-document topic proportions, each drawn from a symmetric Dirichlet
# with alpha = 1.
theta = [dirichlet(np.ones(T)) for _ in range(D)]

# Sample every document as a grid of per-word counts.
side = int(L)  # reshape needs integer dimensions even if L is a float
B = []
for d in range(D):
    # Integer dtype: the counts are later used as list repetition factors,
    # and a list cannot be multiplied by a float.
    doc = np.zeros(W, dtype=int)
    # Split the N words of this document across the topics.
    theta_sample = multinomial(N, theta[d])
    for t, count in enumerate(theta_sample):
        # Draw `count` words from topic t's distribution over the vocabulary.
        doc += multinomial(count, phi[t].flatten())
    B.append(doc.reshape(side, side))

# Expand the count grids into one flat list of word indices; each document
# contributes its N words in order of increasing word index.
words = []
for b in B:
    for idx, el in enumerate(b.flatten()):
        words.extend([idx] * int(el))
# Write the Aprob input program: doc/word/topic facts, one observe/2 fact per
# document, the two Dirichlet-share declarations and the plate template.
with open('../in/lda_artificial.pl', 'w') as fout:
    fout.write('% docs\n')
    for d in range(D):
        fout.write('doc(d{}).\n'.format(d))
    fout.write('% words\n')
    # Only words that actually occur in some document are declared.
    for w in set(words):
        fout.write('word(w{}).\n'.format(w))
    fout.write('% topics\n')
    for t in range(T):
        fout.write('topic(t{}).\n'.format(t))
    fout.write('% data\n')
    for d in range(D):
        # `words` holds N consecutive entries per document.
        doc_words = words[d*N:d*N+N]
        str_word_list = '[' + ','.join('w{}'.format(x) for x in doc_words) + ']'
        fout.write('observe(d{}, {}).\n'.format(d, str_word_list))
    fout.write('% prob distribs\n')
    # .tolist() yields plain Python floats, so the list is rendered as
    # "[1.0, 1.0, ...]". Formatting a list of NumPy scalars instead prints
    # "np.float64(1.0)" entries on NumPy >= 2.
    topic_prior = np.ones(T).tolist()
    word_prior = np.ones(W).tolist()
    fout.write(('aprob_dirichlet_share({}, D, theta(T,D)) :-'+
                'topic(T), doc(D).\n').format(topic_prior))
    fout.write(('aprob_dirichlet_share({}, T, phi(W,T)) :-'+
                'word(W), topic(T).\n').format(word_prior))
    # The template below is emitted verbatim.
    fout.write('''
% aprob flags
%:- set_value(dbg_read,2).
%:- set_value(dbg_query,2).
%:- set_value(dbg_write,2).
% template for lda
%lda :-
% theta(T,d0),
% phi(w0,T).
% plate iterator
% aprob_plate(+IteratorQuery, +PlateQuery)
% IteratorQuery - query that iterates through plates
% PlateQuery - plate specific query
aprob_plate([observe(D, WordList), member(Word, WordList)],
[theta(T,D),phi(Word,T)]).''')
%%bash
# Print the Prolog query file, followed by a one-line description of it.
cat ../in/lda_query.pl
echo 'This is the query file we use to run Aprob (the logical inference part).'
% load the Aprob prolog file :- ['../src/aprob.pl']. % load the input file and query :- load_and_query_aprob('../in/lda_artificial.pl'). % stop prolog :- halt. This is the query file we use to run Aprob (the logical inference part).
%%time
%%bash
# Hand the query file to SICStus Prolog; that file loads aprob.pl, runs
# load_and_query_aprob on ../in/lda_artificial.pl and then halts.
sicstus -l ../in/lda_query.pl
% compiling /home/rares/p/aprob/in/lda_query.pl... % compiling /home/rares/p/aprob/src/aprob.pl... % module abduction imported into user % loading /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/clpfd.po... % module clpfd imported into abduction % loading /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/atts.po... % module attributes imported into clpfd % loading /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/types.po... % module types imported into attributes % loaded /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/types.po in module types, 0 msec 1424 bytes % loaded /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/atts.po in module attributes, 0 msec 27824 bytes % loading /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/avl.po... % module avl imported into clpfd % loaded /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/avl.po in module avl, 0 msec 47376 bytes % loading /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/lists.po... % module lists imported into clpfd % module types imported into lists % loaded /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/lists.po in module lists, 0 msec 109104 bytes % loading /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/ordsets.po... % module ordsets imported into clpfd % loaded /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/ordsets.po in module ordsets, 0 msec 37248 bytes % loading /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/trees.po... % module trees imported into clpfd % loaded /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/trees.po in module trees, 0 msec 9856 bytes % module types imported into clpfd % loading /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/samsort.po... 
% module samsort imported into clpfd % loaded /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/samsort.po in module samsort, 0 msec 21424 bytes % loading /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/timeout.po... % module timeout imported into clpfd % module types imported into timeout % loading foreign resource /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/x86_64-linux-glibc2.7/timeout.so in module timeout % loaded /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/timeout.po in module timeout, 0 msec 12320 bytes % loading foreign resource /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/x86_64-linux-glibc2.7/clpfd.so in module clpfd % loaded /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/clpfd.po in module clpfd, 20 msec 1812320 bytes % loading /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/clpr.po... % module clpr imported into abduction % module arith_r imported into clpr % module types imported into arith_r % module types imported into clpr % module attributes imported into clpr % module geler_r imported into clpr % module attributes imported into geler_r % module nfr imported into clpr % module arith_r imported into nfr % module clpr imported into nfr % module types imported into nfr % loading /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/terms.po... % module terms imported into nfr % module types imported into terms % module avl imported into terms % loaded /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/terms.po in module terms, 0 msec 40752 bytes % module geler_r imported into nfr % module classr imported into clpr % module clpr imported into classr % module types imported into classr % loading /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/ugraphs.po... 
% module ugraphs imported into classr % module ordsets imported into ugraphs % module lists imported into ugraphs % module avl imported into ugraphs % loading /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/random.po... % module random imported into ugraphs % module types imported into random % loading foreign resource /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/x86_64-linux-glibc2.7/random.so in module random % loaded /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/random.po in module random, 10 msec 23520 bytes % module types imported into ugraphs % loaded /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/ugraphs.po in module ugraphs, 10 msec 76624 bytes % module attributes imported into classr % module ordsets imported into clpr % module terms imported into clpr % loading /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/assoc3.po... % module assoc3 imported into clpr % module avl imported into assoc3 % loaded /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/assoc3.po in module assoc3, 0 msec 11664 bytes % loaded /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/clpr.po in module clpr, 10 msec 668336 bytes % module attributes imported into abduction % module ordsets imported into abduction % module terms imported into abduction % module lists imported into abduction % loading /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/sets.po... % module sets imported into abduction % module lists imported into sets % loaded /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/sets.po in module sets, 0 msec 20640 bytes % loading /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/process.po... % module process imported into abduction % loaded /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/process.po in module process, 0 msec 4976 bytes % loading /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/file_systems.po... 
% module file_systems imported into abduction % module types imported into file_systems % loading /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/system.po... % module system imported into file_systems % module types imported into system % loaded /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/system.po in module system, 0 msec 4288 bytes % module lists imported into file_systems % loaded /usr/local/sicstus4.2.3/bin/sp-4.2.3/sicstus-4.2.3/library/file_systems.po in module file_systems, 0 msec 47936 bytes % module timeout imported into abduction % module system imported into abduction * [F,L2,LNew,L2New,Paths,Time3,L2NoNegNew,AbdOrder,AbdOrderNew,PathsNew,L2NoNeg] - singleton variables * Approximate lines: 3197-3308, file: '/home/rares/p/aprob/src/aprob.pl' % compiled /home/rares/p/aprob/src/aprob.pl in module abduction, 300 msec 3833072 bytes % compiled /home/rares/p/aprob/in/lda_query.pl in module user, 28580 msec 9044800 bytes
CPU times: user 8 ms, sys: 0 ns, total: 8 ms Wall time: 29.4 s
# The rest of the Aprob pipeline starts here.
# 1. Parse the Formula (and the other files)
from formula_gen import Formula

option = 'read_sat'
option_args = dict(
    probs_file='../out/out.probs',
    sat_file='../out/out.sat',
    plate_file='../out/out.plate',
)
fr = Formula(option, option_args)
# Last expression of the cell, so the notebook displays the formula string.
fr.formula.to_formula()
'((10&1)|((11&2)&~1)|((12&3)&(~1&~2))|((13&4)&(~1&~2&~3))|((14&5)&(~1&~2&~3&~4))|((15&6)&(~1&~2&~3&~4&~5))|((16&7)&(~1&~2&~3&~4&~5&~6))|((17&8)&(~1&~2&~3&~4&~5&~6&~7))|((18&9)&(~1&~2&~3&~4&~5&~6&~7&~8))|(19&(~1&~2&~3&~4&~5&~6&~7&~8&~9))|(10&1)|((11&2)&~1)|((12&3)&(~1&~2))|((13&4)&(~1&~2&~3))|((14&5)&(~1&~2&~3&~4))|((15&6)&(~1&~2&~3&~4&~5))|((16&7)&(~1&~2&~3&~4&~5&~6))|((17&8)&(~1&~2&~3&~4&~5&~6&~7))|((18&9)&(~1&~2&~3&~4&~5&~6&~7&~8))|(19&(~1&~2&~3&~4&~5&~6&~7&~8&~9))|(10&1)|((11&2)&~1)|((12&3)&(~1&~2))|((13&4)&(~1&~2&~3))|((14&5)&(~1&~2&~3&~4))|((15&6)&(~1&~2&~3&~4&~5))|((16&7)&(~1&~2&~3&~4&~5&~6))|((17&8)&(~1&~2&~3&~4&~5&~6&~7))|((18&9)&(~1&~2&~3&~4&~5&~6&~7&~8))|(19&(~1&~2&~3&~4&~5&~6&~7&~8&~9)))'
# 2. Compile the parsed formula to a BDD (using the 'pycudd' option)
from knowledge_compilation import BDD

bdd = BDD(
    formula=fr.formula,
    option='pycudd',
)
bdd.compile()
%%time
# 3. Run the plate Gibbs sampler on the compiled BDD
from prob_inference import PyCUDDInference

gibbs_options = {'N': 100}
inf = PyCUDDInference(bdd)
inf.gibbs_sampler_plate(gibbs_options)