Skip to content

Commit

Permalink
Merge pull request #488 from tidymodels/inner_split-bootstrap
Browse files Browse the repository at this point in the history
Add `inner_split()` methods for bootstrap
  • Loading branch information
hfrick authored May 23, 2024
2 parents 776d46f + 5b6dabe commit 7d35a8e
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 5 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ S3method(dim,rsplit)
S3method(get_rsplit,default)
S3method(get_rsplit,rset)
S3method(inner_split,apparent_split)
S3method(inner_split,boot_split)
S3method(inner_split,clustering_split)
S3method(inner_split,group_boot_split)
S3method(inner_split,group_mc_split)
S3method(inner_split,group_vfold_split)
S3method(inner_split,mc_split)
Expand Down
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# rsample (development version)

* The new `inner_split()` function and its methods for various resamples is for usage in tune to create a inner resample of the analysis set to fit the preprocessor and model on one part and the post-processor on the other part (#483).
* The new `inner_split()` function and its methods for various resamples is for usage in tune to create a inner resample of the analysis set to fit the preprocessor and model on one part and the post-processor on the other part (#483, #488).

## Bug fixes

Expand Down
51 changes: 49 additions & 2 deletions R/inner_split.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
#' * For `vfold_split` and `group_vfold_split` objects, it will ignore
#' `split_args$times` and `split_args$repeats`. `split_args$v` will be used to
#' set `split_args$prop` to `1 - 1/v` if `prop` is not already set and otherwise
#' ignored. The method
#' for `group_vfold_split` will always use `split_args$balance = NULL`.
#' ignored. The method for `group_vfold_split` will always use
#' `split_args$balance = NULL`.
#' * For `boot_split` and `group_boot_split` objects, it will ignore
#' `split_args$times`.
#' * For `clustering_split` objects, it will ignore `split_args$repeats`.
#'
#' @keywords internal
Expand Down Expand Up @@ -119,6 +121,51 @@ inner_split.group_vfold_split <- function(x, split_args, ...) {
split_inner
}

# bootstrap --------------------------------------------------------------

#' @rdname inner_split
#' @export
inner_split.boot_split <- function(x, split_args, ...) {
check_dots_empty()

# use unique rows to prevent the same information from entering
# both the inner analysis and inner assessment set
id_outer_analysis <- unique(x$in_id)
analysis_set <- x$data[id_outer_analysis, , drop = FALSE]

split_args$times <- 1
split_inner <- rlang::inject(
bootstraps(analysis_set, !!!split_args)
)
split_inner <- split_inner$splits[[1]]

class_inner <- paste0(class(x)[1], "_inner")
class(split_inner) <- c(class_inner, class(x))
split_inner
}

#' @rdname inner_split
#' @export
inner_split.group_boot_split <- function(x, split_args, ...) {
check_dots_empty()

# use unique rows to prevent the same information from entering
# both the inner analysis and inner assessment set
id_outer_analysis <- unique(x$in_id)
analysis_set <- x$data[id_outer_analysis, , drop = FALSE]

split_args$times <- 1
split_inner <- rlang::inject(
group_bootstraps(analysis_set, !!!split_args)
)
split_inner <- split_inner$splits[[1]]

class_inner <- paste0(class(x)[1], "_inner")
class(split_inner) <- c(class_inner, class(x))
split_inner
}


# clustering -------------------------------------------------------------

#' @rdname inner_split
Expand Down
12 changes: 10 additions & 2 deletions man/inner_split.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

57 changes: 57 additions & 0 deletions tests/testthat/test-inner_split.R
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,63 @@ test_that("group_vfold_split", {
})


# bootstrap --------------------------------------------------------------

test_that("boot_split", {
set.seed(11)
r_set <- bootstraps(warpbreaks, times = 2)
split_args <- .get_split_args(r_set)
r_split <- get_rsplit(r_set, 1)

isplit <- inner_split(r_split, split_args)

expect_lte(
nrow(isplit$data),
analysis(r_split) %>% nrow()
)

expect_identical(
analysis(isplit),
isplit$data[isplit$in_id, ],
ignore_attr = "row.names"
)
expect_identical(
assessment(isplit),
isplit$data[complement(isplit), ],
ignore_attr = "row.names"
)
})

test_that("group_boot_split", {
skip_if_not_installed("modeldata")

data(ames, package = "modeldata", envir = rlang::current_env())

set.seed(11)
r_set <- group_bootstraps(ames, group = "MS_SubClass", times = 2)
split_args <- .get_split_args(r_set)
r_split <- get_rsplit(r_set, 1)

isplit <- inner_split(r_split, split_args)

expect_lte(
nrow(isplit$data),
analysis(r_split) %>% nrow()
)

expect_identical(
analysis(isplit),
isplit$data[isplit$in_id, ],
ignore_attr = "row.names"
)
expect_identical(
assessment(isplit),
isplit$data[complement(isplit), ],
ignore_attr = "row.names"
)
})


# clustering -------------------------------------------------------------

test_that("clustering_split", {
Expand Down

0 comments on commit 7d35a8e

Please sign in to comment.