Skip to content
This repository has been archived by the owner on Feb 11, 2024. It is now read-only.

Commit

Permalink
Make tolower work for dfm() ref #27 (#45)
Browse files Browse the repository at this point in the history
* Make `tolower` work for dfm()

* Update doc, add tests
  • Loading branch information
chainsawriot authored Nov 22, 2023
1 parent 6272667 commit 5f08a69
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 14 deletions.
30 changes: 24 additions & 6 deletions R/get_dist.R
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ tokens_proximity <- function(x, pattern, get_min = TRUE, valuetype = c("glob", "
quanteda::meta(toks, field = "pattern") <- pp(pattern)
attr(toks, "pattern") <- pattern ## custom field
quanteda::meta(toks, field = "get_min") <- get_min
quanteda::meta(toks, field = "valuetype") <- valuetype
quanteda::meta(toks, field = "case_insensitive") <- case_insensitive
quanteda::meta(toks, field = "count_from") <- count_from
quanteda::meta(toks, field = "tolower") <- tolower
quanteda::meta(toks, field = "keep_acronyms") <- keep_acronyms
class(toks) <- c("tokens_with_proximity")
Expand Down Expand Up @@ -174,17 +177,29 @@ convert.tokens_with_proximity <- function(x, to = c("data.frame"), ...) {
return(do.call(rbind, result_list))
}

tokens_proximity_tolower <- function(x) {
## update from inside, docvars(x, "proximity") is updated too.
return(tokens_proximity(x, pattern = attr(x, "pattern"),
get_min = quanteda::meta(x, "get_min"),
valuetype = quanteda::meta(x, "valuetype"),
case_insensitive = quanteda::meta(x, "case_insensitive"),
count_from = quanteda::meta(x, "count_from"),
tolower = TRUE, keep_acronyms = quanteda::meta(x, "count_from"))
)
}

#' Create a document-feature matrix
#'
#' Construct a sparse document-feature matrix from the output of [tokens_proximity()].
#' @param x output of [tokens_proximity()]
#' @param tolower ignored
#' @param remove_padding ignored
#' @param remove_docvars_proximity boolean, remove the "proximity" document variable
#' @param verbose ignored
#' @param weight_function a weight function, default to invert distance
#' @param x output of [tokens_proximity()].
#' @param tolower convert all features to lowercase.
#' @param remove_padding ignored.
#' @param remove_docvars_proximity boolean, remove the "proximity" document variable.
#' @param verbose ignored,
#' @param weight_function a weight function, default to invert distance,
#' @param ... not used.
#' @importFrom quanteda dfm
#' @return a [quanteda::dfm-class] object
#' @details By default, words closer to keywords are weighted higher. You might change that with another `weight_function`. Please also note that `tolower` and `remove_padding` have no effect. It is because changing tokens at this point would need to recalculate the proximity vectors. Please do all the text manipulation before running [tokens_proximity()].
#' @examples
#' library(quanteda)
Expand Down Expand Up @@ -217,6 +232,9 @@ dfm.tokens_with_proximity <- function(x, tolower = TRUE, remove_padding = FALSE,
weight_function = function(x) {
1 / x
}, ...) {
if (!quanteda::meta(x, "tolower") && tolower) {
x <- tokens_proximity_tolower(x)
}
x_attrs <- attributes(x)
x_docvars <- quanteda::docvars(x)
x_docnames <- attr(x, "docvars")$docname_
Expand Down
15 changes: 9 additions & 6 deletions man/dfm.tokens_with_proximity.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 16 additions & 2 deletions tests/testthat/test-dfm.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ test_that("normal", {
suppressPackageStartupMessages(library(quanteda))
testdata <-
c("Turkish President Tayyip Erdogan, in his strongest comments yet on the Gaza conflict, said on Wednesday the Palestinian militant group Hamas was not a terrorist organisation but a liberation group fighting to protect Palestinian lands.")
res <- testdata %>% tokens() %>% tokens_tolower() %>% tokens_proximity(pattern = "turkish")
res <- testdata %>% tokens() %>% tokens_proximity(pattern = "turkish")
res %>% dfm() -> output
expect_equal(as.numeric(output[1,"in"]), 0.166666, tolerance = 0.0001)
})
Expand All @@ -11,7 +11,21 @@ test_that("weight function", {
suppressPackageStartupMessages(library(quanteda))
testdata <-
c("Turkish President Tayyip Erdogan, in his strongest comments yet on the Gaza conflict, said on Wednesday the Palestinian militant group Hamas was not a terrorist organisation but a liberation group fighting to protect Palestinian lands.")
res <- testdata %>% tokens() %>% tokens_tolower() %>% tokens_proximity(pattern = "turkish")
res <- testdata %>% tokens() %>% tokens_proximity(pattern = "turkish")
res %>% dfm(weight_function = identity) -> output2
expect_equal(as.numeric(output2[1,","]), 20, tolerance = 0.0001)
})

test_that("tolower", {
suppressPackageStartupMessages(library(quanteda))
testdata <-
c("Turkish President Tayyip Erdogan, in his strongest comments yet on the Gaza conflict, said on Wednesday the Palestinian militant group Hamas was not a terrorist organisation but a liberation group fighting to protect Palestinian lands.")
res <- testdata %>% tokens() %>% tokens_proximity(pattern = "turkish", tolower = FALSE)
res %>% dfm(tolower = TRUE) -> output
expect_true("turkish" %in% colnames(output))
res %>% dfm(tolower = FALSE) -> output
expect_false("turkish" %in% colnames(output))
res <- testdata %>% tokens() %>% tokens_proximity(pattern = phrase("Tayyip Erdogan"), tolower = FALSE)
res %>% dfm(tolower = TRUE) -> output
expect_true("turkish" %in% colnames(output))
})

0 comments on commit 5f08a69

Please sign in to comment.