Skip to content

Commit

Permalink
I've arrived at a good 'smallsim' example.
Browse files Browse the repository at this point in the history
  • Loading branch information
pcarbo committed Jun 24, 2024
1 parent 8a8630e commit 4e9dfbe
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 51 deletions.
108 changes: 73 additions & 35 deletions analysis/smallsim_hard.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ source("../code/smallsim_functions.R")
Simulate a $100 \times 400$ counts matrix from a multinomial topic
model with $K = 6$ topics.

```{r sim-data, fig.height=1.5, fig.width=6, message=FALSE, warning=FALSE}
set.seed(2)
```{r sim-data}
set.seed(4)
n <- 100
m <- 400
k <- 6
Expand All @@ -42,50 +42,88 @@ L <- out$L
major_topic <- out$major_topic
s <- simulate_sizes(n)
X <- simulate_multinom_counts(L,F,s)
X <- X[,colSums(X > 0) > 0]
topic_colors <- c("dodgerblue","darkorange","forestgreen","darkblue",
"gold","skyblue")
simdata_structure_plot(L,major_topic,topic_colors)
cols <- which(colSums(X > 0) > 0)
F <- F[cols,]
X <- X[,cols]
```

ADD TEXT HERE.
We fit the multinomial topic model by performing either (i) 180 EM
updates (without extrapolation), or (ii) 180 SCD updates (the first 80
without extrapolation, the remaining 100 with extrapolation). Both
fits are initialized by running 20 EM updates.

```{r, eval=FALSE}
control <- list(extrapolate = FALSE,numiter = 4, eval=FALSE)
fit0 <- fit_poisson_nmf(X,k,numiter = 20,method = "em",control = control)
```{r fit-models, message=FALSE, results="hide"}
control <- list(extrapolate = FALSE,numiter = 4)
fit0 <- fit_poisson_nmf(X,k,numiter=20,method="em",control=control)
fit1 <- fit_poisson_nmf(X,fit0=fit0,numiter=180,method="em",control=control)
fit2 <- fit_poisson_nmf(X,fit0=fit0,numiter=80,method="scd",control=control)
control$extrapolate <- TRUE
fit2 <- fit_poisson_nmf(X,fit0=fit2,numiter=100,method="scd",control=control)
fit0 <- poisson2multinom(fit0)
fit1 <- poisson2multinom(fit1)
fit2 <- poisson2multinom(fit2)
print(loadings_scatterplot(fit1$L,fit2$L,topic_colors,"em","scd"))
```

```{r, eval=FALSE}
pdat <- rbind(data.frame(iter = 1:200,
ll = fit1$progress$loglik.multinom,
res = fit1$progress$res,
method = "em"),
data.frame(iter = 1:200,
EM and SCD produce quite different estimates and, of the two, the
SCD estimates are much closer to the truth.

```{r compare-fits, fig.height=4, fig.width=5, message=FALSE, warning=FALSE}
topic_colors <- c("dodgerblue","darkorange","forestgreen","darkblue",
"gold","skyblue")
loadings_order <- order(major_topic,L[,1])
k_set <- c(1,3,5,2,4,6)
p1 <- simdata_structure_plot(L,loadings_order,topic_colors,title = "true")
p2 <- simdata_structure_plot(poisson2multinom(fit1)$L,loadings_order,
topic_colors[k_set],title = "EM")
p3 <- simdata_structure_plot(poisson2multinom(fit2)$L,loadings_order,
topic_colors[k_set],title = "SCD")
plot_grid(p1,p2,p3,nrow = 3,ncol = 1)
```

Indeed, the SCD estimates also improve upon the EM estimates in terms
of log-likelihood, with a total improvement of 600 log-likelihood units,

```{r compare-logliks}
c("em" = sum(loglik_multinom_topic_model(X,fit1)),
"scd" = sum(loglik_multinom_topic_model(X,fit2)))
```

or an average of 6 log-likelihood units per document,

```{r compare-logliks-2}
c("em" = sum(loglik_multinom_topic_model(X,fit1)),
"scd" = sum(loglik_multinom_topic_model(X,fit2)))/n
```

Reassuringly, if we continue to perform the EM updates, we eventually
arrive at the same solution as SCD. But switching from EM to SCD
"rescues" the EM estimates much more quickly, after performing just a
few SCD updates.

```{r fit-models-2, message=FALSE, results="hide", fig.height=2.5, fig.width=3.5}
control$extrapolate <- FALSE
fit3 <- fit_poisson_nmf(X,fit0=fit1,numiter=600,method="em",control=control)
control$extrapolate <- TRUE
fit2 <- fit_poisson_nmf(X,fit0=fit2,numiter=600,method="scd",control=control)
fit4 <- fit_poisson_nmf(X,fit0=fit1,numiter=600,method="scd",control=control)
fit1 <- poisson2multinom(fit1)
fit2 <- poisson2multinom(fit2)
fit3 <- poisson2multinom(fit3)
fit4 <- poisson2multinom(fit4)
# loadings_scatterplot(F[,k_set],fit1$F,topic_colors,"true","em")
# loadings_scatterplot(F[,k_set],fit2$F,topic_colors,"true","scd")
pdat <- rbind(data.frame(iter = 1:800,
ll = fit2$progress$loglik.multinom,
res = fit2$progress$res,
method = "scd"))
pdat <- transform(pdat,
ll = max(ll) - ll + 0.1)
p <- ggplot(pdat,aes(x = iter,y = ll,color = method)) +
geom_line(size = 0.75) +
scale_y_continuous(trans = "log10") +
scale_color_manual(values = c("dodgerblue","darkorange")) +
method = "scd"),
data.frame(iter = 1:800,
ll = fit3$progress$loglik.multinom,
method = "em"),
data.frame(iter = 1:800,
ll = fit4$progress$loglik.multinom,
method = "em+scd"))
pdat <- transform(pdat,ll = max(ll) - ll + 0.1)
ggplot(pdat,aes(x = iter,y = ll,color = method)) +
geom_line(linewidth = 0.75) +
scale_x_continuous(breaks = seq(0,800,100)) +
scale_y_continuous(trans = "log10",breaks = 10^seq(-1,4)) +
scale_color_manual(values = c("dodgerblue","darkorange","magenta")) +
labs(x = "iteration",y = "loglik difference") +
theme_cowplot(font_size = 10)
print(p)
```

```{r, eval=FALSE}
p1 <- simdata_structure_plot(L,topic_colors)
p2 <- simdata_structure_plot(fit1$L,topic_colors)
p3 <- simdata_structure_plot(fit2$L,topic_colors)
plot_grid(p1,p2,p3,nrow = 3,ncol = 1)
```
30 changes: 14 additions & 16 deletions code/smallsim_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ normalize.cols <- function (A)
simulate_sizes <- function (n)
ceiling(10^rnorm(n,3,0.20))

# Randomly generate an m x k factors matrix for a multinomial topic
# model with k topics. Each row of the factors matrix are generated
# Randomly generate an m x k factors (F) matrix for a multinomial topic
# model with k topics. Each row of the factors (F) matrix is generated
# according to the following procedure: generate u = |r| - 5, where r
# ~ N(0,3); for each k, generate the Poisson rates as exp(max(t,-5)),
# where t ~ 0.8 * N(u,s/10) + 0.2 * N(u,s), and s = exp(-u/3).
Expand All @@ -28,11 +28,10 @@ simulate_factors <- function (m, k) {
return(normalize.cols(F))
}

# Randomly generate an n x k loadings matrix (i.e., mixture
# proportions matrix) for a multinomial topic model with k topics. The
# first topic is present in varying proportions in all documents. In
# most documents, a single "major" topic predominates. Note that k
# should be 3 or more.
# Randomly generate an n x k loadings (L) matrix for a multinomial topic
# model with k topics. The first topic is present in varying
# proportions in all documents. In most documents, a single "major"
# topic predominates. Note that k should be 3 or more.
simulate_loadings <- function (n, k) {
L <- matrix(0,n,k)
L[,1] <- runif(n,0,2)
Expand Down Expand Up @@ -60,18 +59,17 @@ simulate_multinom_counts <- function (L, F, s) {
return(X)
}

# Create a Structure plot to visualize the mixture proportions L.
simdata_structure_plot <- function (L, major_topic, topic_colors)
# Create a 'Structure plot' to visualize the mixture proportions L.
simdata_structure_plot <- function (L, loadings_order, topic_colors, title = "")
structure_plot(L,topics = 1:k,colors = topic_colors,
loadings_order = order(major_topic)) +
loadings_order = loadings_order) +
scale_x_continuous(breaks = seq(0,100,10)) +
xlab("document") +
theme(axis.text.x = element_text(angle = 0,hjust = 0))
labs(x = "document",title = title) +
theme(axis.text.x = element_text(angle = 0,hjust = 0),
plot.title = element_text(size = 10,face = "plain"))

# Compare two estimates of L (the topic proportions matrix) in a
# scatterplot.
loadings_scatterplot <- function (L1, L2, colors, xlab = "fit1",
ylab = "fit2") {
# Compare two estimates of L or F in a scatterplot.
loadings_scatterplot <- function (L1, L2, colors, xlab = "fit1", ylab = "fit2") {
n <- nrow(L1)
k <- ncol(L2)
pdat <- data.frame(x = as.vector(L1),
Expand Down

0 comments on commit 4e9dfbe

Please sign in to comment.