-
Notifications
You must be signed in to change notification settings - Fork 1
/
predict_responses.Rmd
119 lines (99 loc) · 3.98 KB
/
predict_responses.Rmd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
---
title: "Twitter Response Generator"
output: html_document
---
```{r setup, include=FALSE}
# Required R package installation:
# Install each required package if it is not already available, then load it.
# requireNamespace() is used for the availability check instead of require(),
# which returns FALSE instead of erroring and is discouraged for loading.
# Set the correct default repository (HTTPS — plain HTTP mirrors are deprecated)
r <- getOption("repos")
r["CRAN"] <- "https://cran.rstudio.com"
options(repos = r)
for (pkg in c("knitr", "kableExtra", "httr")) {
  if (!requireNamespace(pkg, quietly = TRUE)) {
    install.packages(pkg)
  }
  library(pkg, character.only = TRUE)
}
knitr::opts_chunk$set(echo = TRUE)
# Provides clean_text(), used when building sampler requests below.
source("text_helpers.R")
```
```{r, echo=FALSE}
# Endpoint of the locally running batch response-sampler service.
response_sampler_url <- "http://localhost:8080/batchsampleresponses"
# Sentence-embedding model used by the service. The embedding width depends on
# the model: 512 for the large Universal Sentence Encoder, 384 otherwise (sbert).
embedding_type <- "sbert"
if (embedding_type == "use_large") {
  embedding_dims <- 512
} else {
  embedding_dims <- 384
}
```
### Predict responses to a prompt tweet:
```{r}
# Prompt authors whose tweets we generate responses for
# (e.g., WHO, CDCgov, CDCDirector, etc...)
prompt_authors <- list("WHO", "CDCgov", "CDCDirector")#, "ECDC_EU")
# Prompt tweet texts; each is paired with every author above.
prompt_messages <- list(
  "4 steps to stay fit to beat #COVID19: eat a healthy diet #notobacco be physically active stop harmful use of alcohol",
  "Thank you for doing your part beat #COVID19! We are excited to share these 4 steps anybody can take to stay fit against #COVID19: eat a healthy diet, reduce tobacco use, be physically active, and drink alcohol in moderation."
)
# Number of responses to predict per prompt, and the output-format toggle
# (LaTeX tables vs. HTML tables).
response_sample_size <- 30
output_latex <- FALSE
# Generation hyperparameters forwarded to the sampler service.
num_beams <- 3
temperature <- 1.5
random_seed <- 42
```
```{r, echo=FALSE}
# Request a batch of generated responses for one (author, message) prompt from
# the local response-sampler service and reshape the JSON reply.
#
# Args:
#   prompt_author: screen name of the prompt tweet's author (e.g., "WHO").
#   prompt_message: prompt tweet text; cleaned with clean_text() before sending.
#   response_sample_size: number of responses to request per prompt.
#   num_beams, temperature: generation hyperparameters forwarded to the service.
#   random_seed: optional seed for reproducible sampling; omitted from the
#     request body when NULL or NA.
#
# Returns:
#   A list with one element per prompt in the reply. Each element is a list of
#   $responses (data.frame: prompt author/text, generated text, sentiment) and
#   $responses.vectors (numeric matrix, response_sample_size x embedding_dims).
sample_model_responses <- function(prompt_author, prompt_message, response_sample_size, num_beams,
                                   temperature, random_seed) {
  body <- list(sample_size = response_sample_size,
               num_beams = num_beams,
               temperature = temperature,
               prompts = list(list(author = prompt_author, message = clean_text(prompt_message))))
  # Only include the seed when one was actually supplied.
  if (!is.null(random_seed) && !is.na(random_seed)) {
    body[["random_state"]] <- random_seed
  }
  res <- POST(url = response_sampler_url, encode = "json", body = body)
  res.list <- content(res)
  # Per prompt, the service replies with [[1]] generated texts, [[2]] embedding
  # vectors (one list of numbers per response), [[3]] sentiment scores.
  # seq_along() handles an empty reply safely (1:length() would yield c(1, 0)).
  lapply(seq_along(res.list), function(i) {
    list(responses = data.frame(quoted_status.screen_name = body$prompts[[i]]$author,
                                quoted_status.full_text = body$prompts[[i]]$message,
                                full_text = unlist(res.list[[i]][[1]]),
                                sentiment = unlist(res.list[[i]][[3]])),
         # Flatten the list-of-embeddings row by row into a numeric matrix.
         responses.vectors = matrix(unlist(res.list[[i]][[2]]),
                                    nrow = response_sample_size, ncol = embedding_dims,
                                    byrow = TRUE))
  })
}
```
```{r, echo=FALSE}
# Render a table of sampled responses for every (message, author) pair.
# Results accumulate either as a plain list of LaTeX tables (output_latex = TRUE)
# or as an htmltools tagList so every table renders in the knitted HTML document.
if (output_latex) {
  output <- list()
} else {
  output <- htmltools::tagList()
}
idx <- 1
for (prompt_message in prompt_messages) {
  for (prompt_author in prompt_authors) {
    results <- sample_model_responses(prompt_author, prompt_message,
                                      response_sample_size, num_beams,
                                      temperature, random_seed)
    sentiments <- results[[1]]$responses$sentiment
    caption <- paste0("Responses to ", prompt_author, ": ", prompt_message,
                      " (mean: ", round(mean(sentiments), 3),
                      "; sd: ", round(sd(sentiments), 3), ")")
    table_data <- results[[1]]$responses[, c("full_text", "sentiment")]
    if (output_latex) {
      output[[idx]] <- kable(table_data, "latex",
                             col.names = c("Generated Responses:", "Sentiment:"),
                             caption = caption, booktabs = TRUE)
    } else {
      styled <- kable_styling(kable(table_data,
                                    col.names = c("Generated Responses:", "Sentiment:"),
                                    caption = caption))
      output[[idx]] <- htmltools::HTML(styled)
    }
    idx <- idx + 1
  }
}
# Printing the accumulated list/tagList emits all tables into the document.
output
```