Skip to content

Commit

Permalink
Merge pull request #446 from tidymodels/initial_validation_time_split
Browse files Browse the repository at this point in the history
Add `initial_validation_time_split()`
  • Loading branch information
hfrick authored Aug 16, 2023
2 parents a6b5288 + 55408cb commit 1fc7cf2
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 15 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ export(group_vfold_cv)
export(initial_split)
export(initial_time_split)
export(initial_validation_split)
export(initial_validation_time_split)
export(int_bca)
export(int_pctl)
export(int_t)
Expand Down
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

* `training()`, `testing()`, `analysis()`, and `assessment()` are now S3 generics with methods for `rsplit` objects. Previously they manually required the input to be an `rsplit` object (#384).

* The new `initial_validation_split()` generates a three-way split of the data into training, validation, and test sets. With the new `validation_set()`, this can be turned into an `rset` object for tuning (#403).
* The new `initial_validation_split()`, along with variants `initial_validation_time_split()` and `group_initial_validation_split()`, generates a three-way split of the data into training, validation, and test sets. With the new `validation_set()`, this can be turned into an `rset` object for tuning (#403, #446).

* Functions which don't use the ellipsis `...` now enforce empty dots (#429).

Expand Down
58 changes: 53 additions & 5 deletions R/initial_validation_split.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,22 @@
#'
#' `initial_validation_split()` creates a random three-way split of the data
#' into a training set, a validation set, and a testing set.
#' `group_initial_validation_split()` creates similar splits of the data based on some
#' grouping variable, so that all data in a "group" are assigned to the same
#' partition.
#' `initial_validation_time_split()` does the same, but instead of a random
#' selection the training, validation, and testing set are in order of the full
#' data set, with the first observations being put into the training set.
#' `group_initial_validation_split()` creates similar random splits of the data
#' based on some grouping variable, so that all data in a "group" are assigned
#' to the same partition.
#' `training()`, `validation()`, and `testing()` can be used to extract the
#' resulting data sets.
#' Use [`validation_set()`] create an `rset` object for use with functions from
#' Use [`validation_set()`] to create an `rset` object for use with functions from
#' the tune package such as `tune::tune_grid()`.
#'
#' @template strata_details
#'
#' @inheritParams vfold_cv
#' @inheritParams make_strata
#' @param prop A length 2 vector of proportions of data to be retained for training and
#' @param prop A length-2 vector of proportions of data to be retained for training and
#' validation data, respectively.
#' @inheritParams rlang::args_dots_empty
#' @param x An object of class `initial_validation_split`.
Expand All @@ -33,6 +36,12 @@
#' validation_data <- validation(car_split)
#' test_data <- testing(car_split)
#'
#' data(drinks, package = "modeldata")
#' drinks_split <- initial_validation_time_split(drinks)
#' train_data <- training(drinks_split)
#' validation_data <- validation(drinks_split)
#' c(max(train_data$date), min(validation_data$date))
#'
#' data(ames, package = "modeldata")
#' set.seed(1353)
#' ames_split <- group_initial_validation_split(ames, group = Neighborhood)
Expand Down Expand Up @@ -137,6 +146,45 @@ check_prop_3 <- function(prop, call = rlang::caller_env()) {
invisible(prop)
}

#' @rdname initial_validation_split
#' @export
initial_validation_time_split <- function(data,
prop = c(0.6, 0.2),
...) {
rlang::check_dots_empty()

check_prop_3(prop)
prop_train <- prop[1]
prop_val <- prop[2] / (1 - prop_train)

n_train <- floor(nrow(data) * prop_train)
n_val <- floor((nrow(data) - n_train) * prop_val)

train_id <- seq(1, n_train, by = 1)
val_id <- seq(n_train + 1, n_train + n_val, by = 1)

res <- list(
data = data,
train_id = train_id,
val_id = val_id,
test_id = NA,
id = "split"
)

# include those so that they can be attached to the `rset` later in `validation_set()`
val_att <- list(
prop = prop
)
attr(res, "val_att") <- val_att

class(res) <- c(
"initial_validation_time_split",
"initial_validation_split",
"three_way_split"
)
res
}

#' @inheritParams make_groups
#' @rdname initial_validation_split
#' @export
Expand Down
2 changes: 1 addition & 1 deletion R/validation_set.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,6 @@ validation.val_split <- function(x, ...) {
testing.val_split <- function(x, ...) {
rlang::abort(
"The testing data is not part of the validation set object.",
i = "It is part of the result of `initial_validation_split()`."
i = "It is part of the result of the initial 3-way split, e.g., with `initial_validation_split()`."
)
}
22 changes: 17 additions & 5 deletions man/initial_validation_split.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 25 additions & 3 deletions tests/testthat/test-initial_validation_split.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ test_that("basic split - accessor functions", {
})
})


test_that("basic split stratified", {
dat <- data.frame(
id = 1:100,
Expand Down Expand Up @@ -75,9 +74,33 @@ test_that("basic split stratified", {
})
})

test_that("time split", {
dat1 <- data.frame(a = 1:109)

test_that("grouped split", {
set.seed(11)
rs1 <- initial_validation_time_split(dat1, prop = c(0.6, 0.2))

expect_s3_class(
rs1,
c("initial_validation_time_split", "initial_validation_split", "three_way_split")
)

exp_size_train <- floor(nrow(dat1) * 0.6)
exp_size_val <- floor((nrow(dat1) - exp_size_train) * 0.2 / (1 - 0.6))

expect_equal(rs1$train_id, seq(1, exp_size_train))
expect_equal(rs1$val_id, seq(exp_size_train + 1, exp_size_train + exp_size_val))
expect_equal(rs1$test_id, NA)

expect_equal(rs1$data, dat1)

good_val <- length(intersect(rs1$train_id, rs1$val_id))
expect_equal(good_val, 0)
good_test <- length(intersect(rs1$val_id, rs1$test_id))
expect_equal(good_test, 0)
})

test_that("grouped split", {
# all observations of each group should be in only one of the 3 data sets
# = all obs in the same group and no intersection in the groups
# from the 3 data sets
Expand Down Expand Up @@ -181,7 +204,6 @@ test_that("grouped split - accessor functions", {
})
})


test_that("check_prop_3() works", {
expect_snapshot(error = TRUE, check_prop_3(0.3))
expect_snapshot(error = TRUE, check_prop_3("zero"))
Expand Down

0 comments on commit 1fc7cf2

Please sign in to comment.