unplugged-system/external/rappor/analysis/tensorflow/fast_em.py

#!/usr/bin/python
"""
fast_em.py: Tensorflow implementation of expectation maximization for RAPPOR
association analysis.
TODO:
- Use TensorFlow ops for reading input (so that reading input can be
distributed)
- Reduce the number of ops (currently proportional to the number of reports).
May require new TensorFlow ops.
- Fix performance bug (v_split is probably being recomputed on every
iteration):
bin$ ./test.sh decode-assoc-cpp - 1.1 seconds (single-threaded C++)
bin$ ./test.sh decode-assoc-tensorflow - 226 seconds on GPU
"""

import sys

import numpy as np
import tensorflow as tf


def log(msg, *args):
  if args:
    msg = msg % args
  print >>sys.stderr, msg


def ExpectTag(f, expected):
  """Read and consume a 4 byte tag from the given file."""
  b = f.read(4)
  if b != expected:
    raise RuntimeError('Expected %r, got %r' % (expected, b))


def ReadListOfMatrices(f):
  """
  Read a big list of conditional probability matrices from a binary file.
  """
  ExpectTag(f, 'ne \0')
  num_entries = np.fromfile(f, np.uint32, count=1)[0]
  log('Number of entries: %d', num_entries)

  ExpectTag(f, 'es \0')
  entry_size = np.fromfile(f, np.uint32, count=1)[0]
  log('Entry size: %d', entry_size)

  ExpectTag(f, 'dat\0')
  vec_length = num_entries * entry_size
  v = np.fromfile(f, np.float64, count=vec_length)

  log('Values read: %d', len(v))
  log('v: %s', v[:10])
  #print 'SUM', sum(v)

  # NOTE: We're not reshaping because we're using one TensorFlow tensor object
  # per matrix, since it makes the algorithm expressible with current
  # TensorFlow ops.
  #v = v.reshape((num_entries, entry_size))
  return num_entries, entry_size, v
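

# Illustration only (not part of the original pipeline): a minimal sketch of
# the inverse of ReadListOfMatrices(), assuming the framing documented above.
# Handy for generating a small test input by hand; the helper name
# _WriteListOfMatrices is illustrative and is not referenced elsewhere.
def _WriteListOfMatrices(f, matrices):
  """Write a (num_entries, entry_size) float64 array in the input format.

  The file f must be opened in binary mode ('wb').
  """
  matrices = np.asarray(matrices, np.float64)
  num_entries, entry_size = matrices.shape

  f.write('ne \0')  # 4 byte tag, NUL terminated
  np.array([num_entries], np.uint32).tofile(f)

  f.write('es \0')
  np.array([entry_size], np.uint32).tofile(f)

  f.write('dat\0')
  matrices.ravel().tofile(f)  # row-major: one entry_size block per report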


def WriteTag(f, tag):
  if len(tag) != 3:
    raise AssertionError("Tags should be 3 bytes. Got %r" % tag)
  f.write(tag + '\0')  # NUL terminated


def WriteResult(f, num_em_iters, pij):
  WriteTag(f, 'emi')
  emi = np.array([num_em_iters], np.uint32)
  emi.tofile(f)

  WriteTag(f, 'pij')
  pij.tofile(f)


def DebugSum(num_entries, entry_size, v):
  """Sum the entries as a sanity check."""
  cond_prob = tf.placeholder(tf.float64, shape=(num_entries * entry_size,))
  debug_sum = tf.reduce_sum(cond_prob)
  with tf.Session() as sess:
    s = sess.run(debug_sum, feed_dict={cond_prob: v})
  log('Debug sum: %f', s)


def BuildEmIter(num_entries, entry_size, v):
  """Build the TensorFlow expression for a single EM iteration.

  Returns a (pij_in placeholder, em_iter_expr) pair; see RunEm below.
  """
  # Placeholder for the value from the previous iteration.
  pij_in = tf.placeholder(tf.float64, shape=(entry_size,))

  # Split along dimension 0, giving one tensor per report.
  # TODO:
  # - make sure this doesn't get run for every EM iteration
  # - investigate using tf.tile() instead?  (this may cost more memory)
  v_split = tf.split(0, num_entries, v)

  # E step: weight each report's conditional probabilities by the current Pij
  # estimate, then normalize so each report's weights sum to 1.
  z_numerator = [report * pij_in for report in v_split]
  sum_z = [tf.reduce_sum(report) for report in z_numerator]
  z = [z_numerator[i] / sum_z[i] for i in xrange(num_entries)]

  # Concat per-report tensors and reshape.  This is probably inefficient?
  z_concat = tf.concat(0, z)
  z_concat = tf.reshape(z_concat, [num_entries, entry_size])

  # M step: average the per-report weights.  This whole expression represents
  # an EM iteration.  Bind the pij_in placeholder and get a new estimate of Pij.
  em_iter_expr = tf.reduce_sum(z_concat, 0) / num_entries

  return pij_in, em_iter_expr
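

# Illustration only: a pure-NumPy sketch of the same update that BuildEmIter()
# expresses as a TensorFlow graph, for reference while reading the graph code
# above.  _NumpyEmIter is not called anywhere in this script.
def _NumpyEmIter(num_entries, entry_size, v, pij):
  """One EM iteration: weight each report by pij, normalize, then average."""
  cond_prob = v.reshape((num_entries, entry_size))
  z = cond_prob * pij                 # E step: weight by the current Pij
  z /= z.sum(axis=1, keepdims=True)   # normalize each report to sum to 1
  return z.mean(axis=0)               # M step: average over all reports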


def RunEm(pij_in, entry_size, em_iter_expr, max_em_iters, epsilon=1e-6):
  """Run the iterative EM algorithm (using the TensorFlow API).

  Args:
    pij_in: placeholder for the previous Pij estimate (from BuildEmIter)
    entry_size: total number of cells in each matrix
    em_iter_expr: expression for one EM iteration (from BuildEmIter)
    max_em_iters: maximum number of EM iterations
    epsilon: convergence threshold for early termination

  Returns:
    num_em_iters: the number of EM iterations actually run
    pij: numpy.ndarray (e.g. vector of length 8)
  """
  # Initial value is the uniform distribution.
  pij = np.ones(entry_size) / entry_size

  i = 0  # visible outside the loop

  # Do EM iterations.
  with tf.Session() as sess:
    for i in xrange(max_em_iters):
      print 'PIJ', pij
      new_pij = sess.run(em_iter_expr, feed_dict={pij_in: pij})
      dif = max(abs(new_pij - pij))
      log('EM iteration %d, dif = %e', i, dif)
      pij = new_pij
      if dif < epsilon:
        log('Early EM termination: %e < %e', dif, epsilon)
        break

  # If i = 9, then we did 10 iterations.
  return i + 1, pij


def sep():
  print '-' * 80


def main(argv):
  input_path = argv[1]
  output_path = argv[2]
  max_em_iters = int(argv[3])

  sep()
  # Open in binary mode, since ReadListOfMatrices uses np.fromfile.
  with open(input_path, 'rb') as f:
    num_entries, entry_size, cond_prob = ReadListOfMatrices(f)

  sep()
  DebugSum(num_entries, entry_size, cond_prob)

  sep()
  pij_in, em_iter_expr = BuildEmIter(num_entries, entry_size, cond_prob)
  num_em_iters, pij = RunEm(pij_in, entry_size, em_iter_expr, max_em_iters)

  sep()
  log('Final Pij: %s', pij)

  with open(output_path, 'wb') as f:
    WriteResult(f, num_em_iters, pij)
  log('Wrote %s', output_path)


if __name__ == '__main__':
  try:
    main(sys.argv)
  except RuntimeError, e:
    print >>sys.stderr, 'FATAL: %s' % e
    sys.exit(1)