Skip to content

Commit

Permalink
Merge pull request #547 from tidymodels/check-strata
Browse files Browse the repository at this point in the history
Check strata input
  • Loading branch information
hfrick authored Sep 20, 2024
2 parents 3f30ddb + e9d1440 commit b02f57d
Show file tree
Hide file tree
Showing 13 changed files with 72 additions and 30 deletions.
2 changes: 1 addition & 1 deletion R/boot.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ bootstraps <-
if (length(strata) == 0) strata <- NULL
}

strata_check(strata, data)
check_strata(strata, data)

split_objs <-
boot_splits(
Expand Down
4 changes: 2 additions & 2 deletions R/initial_validation_split.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ initial_validation_split <- function(data,
strata <- NULL
}
}
strata_check(strata, data)
check_strata(strata, data)

split_train <- mc_cv(
data = data,
Expand Down Expand Up @@ -209,7 +209,7 @@ group_initial_validation_split <- function(data,
strata <- NULL
}
}
strata_check(strata, data)
check_strata(strata, data)

if (missing(strata)) {
split_train <- group_mc_cv(
Expand Down
2 changes: 1 addition & 1 deletion R/mc.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ mc_cv <- function(data, prop = 3 / 4, times = 25,
if (length(strata) == 0) strata <- NULL
}

strata_check(strata, data)
check_strata(strata, data)

split_objs <-
mc_splits(
Expand Down
17 changes: 9 additions & 8 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -98,16 +98,17 @@ add_class <- function(x, cls) {
x
}

strata_check <- function(strata, data) {
check_strata <- function(strata, data, call = caller_env()) {
check_string(strata, allow_null = TRUE, call = call)

if (!is.null(strata)) {
if (!is.character(strata) | length(strata) != 1) {
cli_abort("{.arg strata} should be a single name or character value.")
}
if (inherits(data[, strata], "Surv")) {
cli_abort("{.arg strata} cannot be a {.cls Surv} object. Use the time or event variable directly.")
}
if (!(strata %in% names(data))) {
cli_abort("{strata} is not in {.arg data}.")
cli_abort(c(
"{.field strata} cannot be a {.cls Surv} object.",
"i" = "Use the time or event variable directly."
),
call = call
)
}
}
invisible(NULL)
Expand Down
2 changes: 1 addition & 1 deletion R/validation_split.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ validation_split <- function(data, prop = 3 / 4,
}
}

strata_check(strata, data)
check_strata(strata, data)

split_objs <-
mc_splits(
Expand Down
2 changes: 1 addition & 1 deletion R/vfold.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ vfold_cv <- function(data, v = 10, repeats = 1,
if (length(strata) == 0) strata <- NULL
}

strata_check(strata, data)
check_strata(strata, data)
check_repeats(repeats)

if (repeats == 1) {
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/_snaps/boot.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
Code
bootstraps(warpbreaks, strata = c("tension", "wool"))
Condition
Error in `strata_check()`:
! `strata` should be a single name or character value.
Error in `bootstraps()`:
! `strata` must be a single string or `NULL`, not a character vector.

---

Expand Down
7 changes: 4 additions & 3 deletions tests/testthat/_snaps/make_strata.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,9 @@
# don't stratify on Surv objects

Code
strata_check("surv", df)
check_strata("surv", df)
Condition
Error in `strata_check()`:
! `strata` cannot be a <Surv> object. Use the time or event variable directly.
Error:
! strata cannot be a <Surv> object.
i Use the time or event variable directly.

4 changes: 2 additions & 2 deletions tests/testthat/_snaps/mc.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
Code
mc_cv(warpbreaks, strata = c("tension", "wool"))
Condition
Error in `strata_check()`:
! `strata` should be a single name or character value.
Error in `mc_cv()`:
! `strata` must be a single string or `NULL`, not a character vector.

# printing

Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/_snaps/validation_split.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@
Code
validation_split(warpbreaks, strata = c("tension", "wool"))
Condition
Error in `strata_check()`:
! `strata` should be a single name or character value.
Error in `validation_split()`:
! `strata` must be a single string or `NULL`, not a character vector.

# printing

Expand Down
23 changes: 20 additions & 3 deletions tests/testthat/_snaps/vfold.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
Stratifying groups that make up 1% of the data may be statistically risky.
* Consider increasing `pool` to at least 0.1

# bad args
# strata arg is checked

Code
vfold_cv(iris, strata = iris$Species)
Expand All @@ -21,11 +21,28 @@
Code
vfold_cv(iris, strata = c("Species", "Sepal.Width"))
Condition
Error in `strata_check()`:
! `strata` should be a single name or character value.
Error in `vfold_cv()`:
! `strata` must be a single string or `NULL`, not a character vector.

---

Code
vfold_cv(iris, strata = NA)
Condition
Error in `vfold_cv()`:
! Selections can't have missing values.

---

Code
vfold_cv(dat, strata = b)
Condition
Error in `vfold_cv()`:
! strata cannot be a <Surv> object.
i Use the time or event variable directly.

# bad args

Code
vfold_cv(iris, v = -500)
Condition
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test-make_strata.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ test_that("bad data", {



# strata_check() ----------------------------------------------------------
# check_strata() ----------------------------------------------------------

test_that("don't stratify on Surv objects", {
df <- data.frame(
Expand All @@ -58,6 +58,6 @@ test_that("don't stratify on Surv objects", {
)

expect_snapshot(error = TRUE, {
strata_check("surv", df)
check_strata("surv", df)
})
})
27 changes: 25 additions & 2 deletions tests/testthat/test-vfold.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,37 @@ test_that("strata", {
)
})


test_that("bad args", {
test_that("strata arg is checked", {
expect_snapshot(error = TRUE, {
vfold_cv(iris, strata = iris$Species)
})

# errors from `check_strata()`
expect_snapshot(error = TRUE, {
vfold_cv(iris, strata = c("Species", "Sepal.Width"))
})

expect_snapshot(error = TRUE, {
vfold_cv(iris, strata = NA)
})

# make Surv object without a dependency on the survival package
surv_obj <- structure(
c(306, 455, 1010, 210, 883, 1, 1, 0, 1, 1),
dim = c(5L, 2L),
dimnames = list(NULL, c("time", "status")),
type = "right",
class = "Surv"
)
dat <- data.frame(a = 1:5)
# add Surv object like this for older R versions (<= 4.2.3)
dat$b <- surv_obj
expect_snapshot(error = TRUE, {
vfold_cv(dat, strata = b)
})
})

test_that("bad args", {
expect_snapshot(error = TRUE, {
vfold_cv(iris, v = -500)
})
Expand Down

0 comments on commit b02f57d

Please sign in to comment.