From cf1acd98c5991814c110a3082cba283beab28f6e Mon Sep 17 00:00:00 2001 From: chainsawriot Date: Wed, 22 Nov 2023 11:11:06 +0100 Subject: [PATCH] Fix #46 --- R/get_dist.R | 6 +++++- tests/testthat/test-dfm.R | 7 +++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/R/get_dist.R b/R/get_dist.R index ed4657f..e2f7132 100644 --- a/R/get_dist.R +++ b/R/get_dist.R @@ -238,9 +238,13 @@ dfm.tokens_with_proximity <- function(x, tolower = TRUE, remove_padding = FALSE, x_attrs <- attributes(x) x_docvars <- quanteda::docvars(x) x_docnames <- attr(x, "docvars")$docname_ - type <- attr(x, "types") temp <- unclass(x) index <- unlist(temp, use.names = FALSE) + type <- attr(x, "types") + if (0 %in% index) { + index <- index + 1 + type <- c("", type) + } val <- weight_function(unlist(quanteda::docvars(x, "proximity"), use.names = FALSE)) temp <- Matrix::sparseMatrix( j = index, diff --git a/tests/testthat/test-dfm.R b/tests/testthat/test-dfm.R index eaa3b38..8ea5b74 100644 --- a/tests/testthat/test-dfm.R +++ b/tests/testthat/test-dfm.R @@ -29,3 +29,10 @@ test_that("tolower", { res %>% dfm(tolower = TRUE) -> output expect_true("turkish" %in% colnames(output)) }) + +test_that("Padding #46", { + suppressPackageStartupMessages(library(quanteda)) + toks <- tokens(c("a b c", "A B C D")) %>% tokens_remove("b", padding = TRUE) + expect_error(toks %>% tokens_proximity("a") %>% dfm(), NA) +}) +