unplugged-system/external/rappor/tests/gen_counts.R

214 lines
7.0 KiB
R
Executable File

#!/usr/bin/env Rscript
#
# Copyright 2015 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
source('analysis/R/read_input.R')
RandomPartition <- function(total, weights) {
# Outputs a random partition according to a specified distribution
# Args:
# total - number of samples
# weights - weights that are proportional to the probability density
# function of the target distribution
# Returns:
# a histogram sampled according to the pdf
# Example:
# > RandomPartition(100, c(3, 2, 1, 0, 1))
# [1] 47 24 15 0 14
if (any(weights < 0))
stop("Probabilities cannot be negative")
if (sum(weights) == 0)
stop("Probabilities cannot sum up to 0")
bins <- length(weights)
result <- rep(0, bins)
# idiomatic way:
# rnd_list <- sample(strs, total, replace = TRUE, weights)
# apply(as.array(strs), 1, function(x) length(rnd_list[rnd_list == x]))
#
# The following is much faster for larger totals. We can replace a loop with
# (tail) recusion, but R chokes with the recursion depth > 850.
w <- sum(weights)
for (i in 1:bins)
if (total > 0) { # if total == 0, nothing else to do
# invariant: w = sum(weights[i:bins])
# rather than computing sum every time leading to quadratic time, keep
# updating it
# The probability p is clamped to [0, 1] to avoid under/overflow errors.
p <- min(max(weights[i] / w, 0), 1)
# draw the number of balls falling into the current bin
rnd_draw <- rbinom(n = 1, size = total, prob = p)
result[i] <- rnd_draw # push rnd_draw balls from total to result[i]
total <- total - rnd_draw
w <- w - weights[i]
}
names(result) <- names(weights)
return(result)
}
GenerateCounts <- function(params, true_map, partition, reports_per_client) {
# Fast simulation of the marginal table for RAPPOR reports
# Args:
# params - parameters of the RAPPOR reporting process
# true_map - hashed true inputs
# partition - allocation of clients between true values
# reports_per_client - number of reports (IRRs) per client
if (nrow(true_map) != (params$m * params$k)) {
stop(cat("Map does not match the params file!",
"mk =", params$m * params$k,
"nrow(map):", nrow(true_map),
sep = " "))
}
# For each reporting type computes its allocation to cohorts.
# Output is an m x strs matrix.
cohorts <- as.matrix(
apply(as.data.frame(partition), 1,
function(count) RandomPartition(count, rep(1, params$m))))
# Expands to (m x k) x strs matrix, where each element (corresponding to the
# bit in the aggregate Bloom filter) is repeated k times.
expanded <- apply(cohorts, 2, function(vec) rep(vec, each = params$k))
# For each bit, the number of clients reporting this bit:
clients_per_bit <- rep(apply(cohorts, 1, sum), each = params$k)
# Computes the true number of bits set to one BEFORE PRR.
true_ones <- apply(expanded * true_map, 1, sum)
ones_in_prr <-
unlist(lapply(true_ones,
function(x) rbinom(n = 1, size = x, prob = 1 - params$f / 2))) +
unlist(lapply(clients_per_bit - true_ones, # clients where the bit is 0
function(x) rbinom(n = 1, size = x, prob = params$f / 2)))
# Number of IRRs where each bit is reported (either as 0 or as 1)
reports_per_bit <- clients_per_bit * reports_per_client
ones_before_irr <- ones_in_prr * reports_per_client
ones_after_irr <-
unlist(lapply(ones_before_irr,
function(x) rbinom(n = 1, size = x, prob = params$q))) +
unlist(lapply(reports_per_bit - ones_before_irr,
function(x) rbinom(n = 1, size = x, prob = params$p)))
counts <- cbind(apply(cohorts, 1, sum) * reports_per_client,
matrix(ones_after_irr, nrow = params$m, ncol = params$k, byrow = TRUE))
if(any(is.na(counts)))
stop("Failed to generate bit counts. Likely due to integer overflow.")
counts
}
ComputePdf <- function(distr, range) {
# Outputs discrete probability density function for a given distribution
# These are the five distributions in gen_sim_input.py
if (distr == 'exp') {
pdf <- dexp(1:range, rate = 5 / range)
} else if (distr == 'gauss') {
half <- range / 2
left <- -half + 1
pdf <- dnorm(left : half, sd = range / 6)
} else if (distr == 'unif') {
# e.g. for N = 4, weights are [0.25, 0.25, 0.25, 0.25]
pdf <- dunif(1:range, max = range)
} else if (distr == 'zipf1') {
# Since the distrubition defined over a finite set, we allow the parameter
# of the Zipf distribution to be 1.
pdf <- sapply(1:range, function(x) 1 / x)
} else if (distr == 'zipf1.5') {
pdf <- sapply(1:range, function(x) 1 / x^1.5)
}
else {
stop(sprintf("Invalid distribution '%s'", distr))
}
pdf <- pdf / sum(pdf) # normalize
pdf
}
# Usage:
#
# $ ./gen_counts.R exp 10000 1 foo_params.csv foo_true_map.csv foo
#
# Inputs:
# distribution name
# number of clients
# reports per client
# parameters file
# map file
# prefix for output files
# Outputs:
# foo_counts.csv
# foo_hist.csv
#
# Warning: the number of reports in any cohort must be less than
# .Machine$integer.max
main <- function(argv) {
distr <- argv[[1]]
num_clients <- as.integer(argv[[2]])
reports_per_client <- as.integer(argv[[3]])
params_file <- argv[[4]]
true_map_file <- argv[[5]]
out_prefix <- argv[[6]]
params <- ReadParameterFile(params_file)
true_map <- ReadMapFile(true_map_file, params)
num_unique_values <- length(true_map$strs)
pdf <- ComputePdf(distr, num_unique_values)
# Computes the number of clients reporting each string
# according to the pre-specified distribution.
partition <- RandomPartition(num_clients, pdf)
# Histogram
true_hist <- data.frame(string = true_map$strs, count = partition)
counts <- GenerateCounts(params, true_map$map, partition, reports_per_client)
# Now create a CSV file
# Opposite of ReadCountsFile in read_input.R
# http://stackoverflow.com/questions/6750546/export-csv-without-col-names
counts_path <- paste0(out_prefix, '_counts.csv')
write.table(counts, file = counts_path,
row.names = FALSE, col.names = FALSE, sep = ',')
cat(sprintf('Wrote %s\n', counts_path))
# TODO: Don't write strings that appear 0 times?
hist_path <- paste0(out_prefix, '_hist.csv')
write.csv(true_hist, file = hist_path, row.names = FALSE)
cat(sprintf('Wrote %s\n', hist_path))
}
if (length(sys.frames()) == 0) {
main(commandArgs(TRUE))
}