Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add total.label argument for groupingsets, cube, rollup #5973

Merged
merged 17 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,42 @@ rowwiseDT(

2. Limited support for subsetting or aggregating columns of type `expression`, [#5596](https://github.com/Rdatatable/data.table/issues/5596). Thanks to @tsp for the report, and @ben-schwen for the fix.

3. `groupingsets.data.table()`, `cube.data.table()`, and `rollup.data.table()` gain a `label` argument, which allows the user to specify a label for each grouping variable, to be included in the grouping variable column in the output in rows where the variable has been aggregated, [#5351](https://github.com/Rdatatable/data.table/issues/5351). Thanks to @markseeto for the request, @jangorecki and @markseeto for specifying the desired behaviour, and @markseeto for implementing.

```r
DT = data.table(V1 = rep(c("a1", "a2"), each = 5),
V2 = rep(rep(c("b1", "b2"), c(3, 2)), 2),
V3 = rep(c("c1", "c2"), c(3, 7)),
V4 = rep(1:2, c(6, 4)),
V5 = rep(1:2, c(9, 1)),
V6 = rep(c(1.1, 1.2), c(2, 8)))

# Call groupingsets() and specify a label for V1, a different label for the other character grouping
# variables, a label for the integer grouping variables, and a label for the numeric grouping variable.

groupingsets(DT, .N, by = c("V1", "V2", "V3", "V4", "V5", "V6"),
sets = list(c("V1", "V2", "V3"), c("V1", "V4"), c("V4", "V6"), "V2", "V5", character()),
label = list(V1 = "All values", character = "Total", integer = 999L, numeric = NaN))

# V1 V2 V3 V4 V5 V6 N
# <char> <char> <char> <int> <int> <num> <int>
# 1: a1 b1 c1 999 999 NaN 3
# 2: a1 b2 c2 999 999 NaN 2
# 3: a2 b1 c2 999 999 NaN 3
# 4: a2 b2 c2 999 999 NaN 2
# 5: a1 Total Total 1 999 NaN 5
# 6: a2 Total Total 1 999 NaN 1
# 7: a2 Total Total 2 999 NaN 4
# 8: All values Total Total 1 999 1.1 2
# 9: All values Total Total 1 999 1.2 4
# 10: All values Total Total 2 999 1.2 4
# 11: All values b1 Total 999 999 NaN 6
# 12: All values b2 Total 999 999 NaN 4
# 13: All values Total Total 999 1 NaN 9
# 14: All values Total Total 999 2 NaN 1
# 15: All values Total Total 999 999 NaN 10
```

## BUG FIXES

1. Using `print.data.table()` with character truncation using `datatable.prettyprint.char` no longer errors with `NA` entries, [#6441](https://github.com/Rdatatable/data.table/issues/6441). Thanks to @r2evans for the bug report, and @joshhwuu for the fix.
Expand Down
71 changes: 66 additions & 5 deletions R/groupingsets.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
rollup = function(x, ...) {
UseMethod("rollup")
}
rollup.data.table = function(x, j, by, .SDcols, id = FALSE, ...) {
rollup.data.table = function(x, j, by, .SDcols, id = FALSE, label = NULL, ...) {
# input data type basic validation
if (!is.data.table(x))
stopf("Argument 'x' must be a data.table object")
Expand All @@ -13,13 +13,13 @@ rollup.data.table = function(x, j, by, .SDcols, id = FALSE, ...) {
sets = lapply(length(by):0L, function(i) by[0L:i])
# redirect to workhorse function
jj = substitute(j)
groupingsets.data.table(x, by=by, sets=sets, .SDcols=.SDcols, id=id, jj=jj)
groupingsets.data.table(x, by=by, sets=sets, .SDcols=.SDcols, id=id, jj=jj, label=label)
}

cube = function(x, ...) {
UseMethod("cube")
}
cube.data.table = function(x, j, by, .SDcols, id = FALSE, ...) {
cube.data.table = function(x, j, by, .SDcols, id = FALSE, label = NULL, ...) {
# input data type basic validation
if (!is.data.table(x))
stopf("Argument 'x' must be a data.table object")
Expand All @@ -35,13 +35,13 @@ cube.data.table = function(x, j, by, .SDcols, id = FALSE, ...) {
sets = lapply((2L^n):1L, function(jj) by[keepBool[jj, ]])
# redirect to workhorse function
jj = substitute(j)
groupingsets.data.table(x, by=by, sets=sets, .SDcols=.SDcols, id=id, jj=jj)
groupingsets.data.table(x, by=by, sets=sets, .SDcols=.SDcols, id=id, jj=jj, label=label)
}

groupingsets = function(x, ...) {
UseMethod("groupingsets")
}
groupingsets.data.table = function(x, j, by, sets, .SDcols, id = FALSE, jj, ...) {
groupingsets.data.table = function(x, j, by, sets, .SDcols, id = FALSE, jj, label = NULL, ...) {
# input data type basic validation
if (!is.data.table(x))
stopf("Argument 'x' must be a data.table object")
Expand All @@ -57,6 +57,14 @@ groupingsets.data.table = function(x, j, by, sets, .SDcols, id = FALSE, jj, ...)
stopf("Argument 'sets' must be a list of character vectors.")
if (!is.logical(id))
stopf("Argument 'id' must be a logical scalar.")
if (!(is.null(label) ||
(is.atomic(label) && length(label) == 1L) ||
(is.list(label) && all(vapply_1b(label, is.atomic)) && all(lengths(label) == 1L) && !is.null(names(label)))))
stopf("Argument 'label', if not NULL, must be a scalar or a named list of scalars.")
if (is.list(label) && !is.null(names(label)) && ("" %chin% names(label) || anyNA(names(label))))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this !is.null(names(label)) check redundant? Since we have is.list(label) && ... !is.null(names(label)) in the above requirement?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do find it a bit surprising that the above check requires "a named list of scalars" but we have a separate test for "all list elements must be named", maybe best to add in the check for ""/NA names to the above condition?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can address this as a small follow-up PR if you agree, don't want to hold the PR back further.

Copy link
Contributor Author

@markseeto markseeto Sep 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this !is.null(names(label)) check redundant?

Yes, I think you're right.

I do find it a bit surprising that the above check requires "a named list of scalars" but we have a separate test for "all list elements must be named", maybe best to add in the check for ""/NA names to the above condition?

Maybe, but with separate checks the error messages can be more specific, the second one being for the situation where label is a named list but not all elements have a name. If we combine the error messages into one, it would be something like "Argument 'label', if not NULL, must be (1) a scalar, or (2) a named list with each element being named and each element being a scalar." Or "Argument 'label', if not NULL, must be (1) a scalar, or (2) a named list with no names being "" or NA and each element being a scalar." I think these are less clear and less helpful than having separate error messages depending on the situation. If you disagree, please let me know. It's not something I feel strongly about.

stopf("When argument 'label' is a list, all of the list elements must be named.")
if (is.list(label) && anyDuplicated(names(label)))
stopf("When argument 'label' is a list, the element names must not contain duplicates.")
# logic constraints validation
if (!all((sets.all.by <- unique(unlist(sets))) %chin% by))
stopf("All columns used in 'sets' argument must be in 'by' too. Columns used in 'sets' but not present in 'by': %s", brackify(setdiff(sets.all.by, by)))
Expand All @@ -66,6 +74,36 @@ groupingsets.data.table = function(x, j, by, sets, .SDcols, id = FALSE, jj, ...)
stopf("Character vectors in 'sets' list must not have duplicated column names within a single grouping set.")
if (length(sets) > 1L && (idx<-anyDuplicated(lapply(sets, sort))))
warningf("'sets' contains a duplicate (i.e., equivalent up to sorting) element at index %d; as such, there will be duplicate rows in the output -- note that grouping by A,B and B,A will produce the same aggregations. Use `sets=unique(lapply(sets, sort))` to eliminate duplicates.", idx)
if (is.list(label)) {
other.allowed.names = c("character", "integer", "numeric", "factor", "Date", "IDate")
allowed.label.list.names = c(by, vapply_1c(.shallow(x, by), function(u) class(u)[1]),
other.allowed.names)
label.names = names(label)
if (!all(label.names %in% allowed.label.list.names))
stopf("When argument 'label' is a list, all element names must be (1) in 'by', or (2) the first element of the class in the data.table 'x' of a variable in 'by', or (3) one of %s. Element names not satisfying this condition: %s",
brackify(other.allowed.names), brackify(setdiff(label.names, allowed.label.list.names)))
label.classes = lapply(label, class)
label.names.in.by = intersect(label.names, by)
label.names.not.in.by = setdiff(label.names, label.names.in.by)
label.names.in.by.classes = label.classes[label.names.in.by]
x.label.names.in.by.classes = lapply(.shallow(x, label.names.in.by), class)
label.names.not.in.by.classes1 = vapply_1c(label.classes[label.names.not.in.by], function(u) u[1])
if (!all(idx <- mapply(identical, label.names.in.by.classes, x.label.names.in.by.classes))) {
info = gettextf(
"%s (label: %s; data: %s)",
label.names.in.by[!idx],
vapply_1c(label.names.in.by.classes[!idx], toString),
vapply_1c(x.label.names.in.by.classes[!idx], toString))
stopf("When argument 'label' is a list, the class of each 'label' element with name in 'by' must match the class of the corresponding column of the data.table 'x'. Class mismatch for: %s", brackify(info))
}
if (!all(idx <- label.names.not.in.by == label.names.not.in.by.classes1)) {
info = gettextf(
"(label name: %s; label class[1]: %s)",
label.names.not.in.by[!idx],
label.names.not.in.by.classes1[!idx])
stopf("When argument 'label' is a list, the name of each element of 'label' not in 'by' must match the first element of the class of the element value. Mismatches: %s", brackify(info))
}
}
# input arguments handling
jj = if (!missing(jj)) jj else substitute(j)
av = all.vars(jj, TRUE)
Expand All @@ -85,6 +123,27 @@ groupingsets.data.table = function(x, j, by, sets, .SDcols, id = FALSE, jj, ...)
set(empty, j = "grouping", value = integer())
setcolorder(empty, c("grouping", by, setdiff(names(empty), c("grouping", by))))
}
# Define variables related to label
if (!is.null(label)) {
total.vars = intersect(by, unlist(lapply(sets, function(u) setdiff(by, u))))
if (is.list(label)) {
by.vars.not.in.label = setdiff(by, names(label))
by.vars.not.in.label.class1 = vapply_1c(x, function(u) class(u)[1L])[by.vars.not.in.label]
labels.by.vars.not.in.label = label[by.vars.not.in.label.class1[by.vars.not.in.label.class1 %in% label.names.not.in.by]]
names(labels.by.vars.not.in.label) <- by.vars.not.in.label[by.vars.not.in.label.class1 %in% label.names.not.in.by]
label.expanded = c(label[label.names.in.by], labels.by.vars.not.in.label)
label.expanded = label.expanded[intersect(by, names(label.expanded))] # reorder
} else {
by.vars.matching.scalar.class1 = by[vapply_1c(x, function(u) class(u)[1L])[by] == class(label)[1L]]
label.expanded = as.list(rep(label, length(by.vars.matching.scalar.class1)))
names(label.expanded) <- by.vars.matching.scalar.class1
}
label.use = label.expanded[intersect(total.vars, names(label.expanded))]
if (any(idx <- vapply_1b(names(label.expanded), function(u) label.expanded[[u]] %in% x[[u]]))) {
info = gettextf("%s (label: %s)", names(label.expanded)[idx], vapply_1c(label.expanded[idx], as.character))
warningf("For the following variables, the 'label' value was already in the data: %s", brackify(info))
}
}
# workaround for rbindlist fill=TRUE on integer64 #1459
int64.cols = vapply_1b(empty, inherits, "integer64")
int64.cols = names(int64.cols)[int64.cols]
Expand All @@ -105,6 +164,8 @@ groupingsets.data.table = function(x, j, by, sets, .SDcols, id = FALSE, jj, ...)
missing.int64.by.cols = setdiff(int64.by.cols, by.set)
if (length(missing.int64.by.cols)) r[, (missing.int64.by.cols) := bit64::as.integer64(NA)]
}
if (!is.null(label) && length(by.label.use.vars <- intersect(setdiff(by, by.set), names(label.use))) > 0L)
r[, (by.label.use.vars) := label.use[by.label.use.vars]]
r
}
# actually processing everything here
Expand Down
Loading
Loading