Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check smooth #441

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,7 @@ export(check_posterior_predictions)
export(check_predictions)
export(check_residuals)
export(check_singularity)
export(check_smooth)
export(check_sphericity)
export(check_sphericity_bartlett)
export(check_symmetry)
Expand Down
243 changes: 243 additions & 0 deletions R/check_smooth.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
#' Check Smooth Term in GAMs
#'
#' Randomization test looking for evidence of residual pattern attributable to covariates of each smooth.
#'
#' @param x An [mgcv] GAM model.
#' @param iterations Number of permutations.
#' @param ... Other arguments to be passed to other functions (not used for now).
#'
#' @examples
#' if (require("mgcv")) {
#' model <- mgcv::gam(Sepal.Length ~ s(Petal.Length, k = 3) + s(Petal.Width, k = 5), data = iris)
#' check_smooth(model)
#' }
#' @export
check_smooth <- function(x, iterations = 400, ...) {
# Based on k.check in https://github.com/cran/mgcv/blob/master/R/plots.r
n_smooths <- length(x$smooth) # TODO: fix that for brms

if (n_smooths == 0) stop("No smooth terms were detected.")

Check warning on line 19 in R/check_smooth.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/check_smooth.R,line=19,col=23,[condition_call_linter] Use stop(., call. = FALSE) not to display the call in an error message.

Check warning on line 19 in R/check_smooth.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/check_smooth.R,line=19,col=23,[condition_call_linter] Use stop(., call. = FALSE) not to display the call in an error message.

rsd <- insight::get_residuals(x)
if (insight::model_info(x)$is_bayesian) {
# TODO: change depending on https://github.com/easystats/insight/issues/540
rsd <- rsd[, 1]
}

# Initialize containers
ve <- rep(0, iterations)
p.val <- v.obs <- kc <- edf <- rep(0, n_smooths)
snames <- rep("", n_smooths)

# Data length
n <- nrow(x$model)
modelmatrix <- x$model

nr <- length(rsd)

# Iterate through smooth terms
for (s in 1:n_smooths) {
ok <- TRUE
x$smooth[[s]]$by <- "NA" # Can't deal with 'by' variables
dat <- mgcv_ExtractData(x$smooth[[s]], modelmatrix, NULL)$data

# Sanity check that 'dat' is of the good format
if (!is.null(attr(dat, "index")) || !is.null(attr(dat[[1]], "matrix")) || is.matrix(dat[[1]])) ok <- FALSE
if (ok) dat <- as.data.frame(dat)

# Get info
snames[s] <- x$smooth[[s]]$label # Smooth name
idx <- x$smooth[[s]]$first.para:x$smooth[[s]]$last.para # which parameter does it correspond to
kc[s] <- length(idx)
edf[s] <- sum(x$edf[idx])
nc <- x$smooth[[s]]$dim

# drop any by variables
if (ok && ncol(dat) > nc) dat <- dat[, 1:nc, drop = FALSE]
# Check if any factor
for (j in 1:nc) if (is.factor(dat[[j]])) ok <- FALSE
if (!ok) {
p.val[s] <- v.obs[s] <- NA # can't do this test with summation convention/factors
next # Skip iteration
}


if (nc == 1) {
# 1-Dimensional term -----------
e <- diff(rsd[order(dat[, 1])])
v.obs[s] <- mean(e^2) / 2
# Reshuffle n-times
for (i in 1:iterations) {
e <- diff(rsd[sample(1:nr, nr)]) # shuffle

Check warning on line 71 in R/check_smooth.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/check_smooth.R,line=71,col=23,[sample_int_linter] sample.int(n, m, ...) is preferable to sample(1:n, m, ...).

Check warning on line 71 in R/check_smooth.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/check_smooth.R,line=71,col=23,[sample_int_linter] sample.int(n, m, ...) is preferable to sample(1:n, m, ...).
ve[i] <- mean(e^2) / 2
}
} else {
# multidimensional term ---------
# If tensor product (have to consider scaling)
if (!is.null(x$smooth[[s]]$margin)) {
# get the scale factors...
beta <- stats::coef(x)[idx]

Check warning on line 79 in R/check_smooth.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/check_smooth.R,line=79,col=9,[object_overwrite_linter] 'beta' is an exported object from package 'base'. Avoid re-using such symbols.
f0 <- mgcv::PredictMat(x$smooth[[s]], dat) %*% beta
gr.f <- rep(0, ncol(dat))
for (i in 1:nc) {
datp <- dat
dx <- diff(range(dat[, i])) / 1000
datp[, i] <- datp[, i] + dx
fp <- mgcv::PredictMat(x$smooth[[s]], datp) %*% beta
gr.f[i] <- mean(abs(fp - f0)) / dx
}
# Rescale distances
for (i in 1:nc) {
dat[, i] <- dat[, i] - min(dat[, i])
dat[, i] <- gr.f[i] * dat[, i] / max(dat[, i])
}
}
nn <- 3
ni <- mgcv_nearest(nn, as.matrix(dat))$ni # TODO: this function calls an mgcv internal

e <- rsd - rsd[ni[, 1]]
for (j in 2:nn) e <- c(e, rsd - rsd[ni[, j]])
v.obs[s] <- mean(e^2) / 2

# Reshuffle n-times
for (i in 1:iterations) {
rsdr <- rsd[sample(1:nr, nr)] ## shuffle

Check warning on line 104 in R/check_smooth.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/check_smooth.R,line=104,col=21,[sample_int_linter] sample.int(n, m, ...) is preferable to sample(1:n, m, ...).
e <- rsdr - rsdr[ni[, 1]]
for (j in 2:nn) e <- c(e, rsdr - rsdr[ni[, j]])
ve[i] <- mean(e^2) / 2
}
}
p.val[s] <- mean(ve < v.obs[s])
v.obs[s] <- v.obs[s] / mean(rsd^2)
}
out <- data.frame(Term = snames, EDF_max = kc, EDF = edf, k = v.obs, p = p.val)
out
}




# Utils -------------------------------------------------------------------


# https://github.com/cran/mgcv/blob/c263c882daf8b2ed55e6e3d1fb712cf20d79a710/R/smooth.r#L3614
#' @keywords internal
mgcv_ExtractData <- function(object, data, knots) {

insight::check_if_installed("mgcv")

# https://github.com/cran/mgcv/blob/c263c882daf8b2ed55e6e3d1fb712cf20d79a710/R/smooth.r#L318
get.var <- function(txt, data, vecMat = TRUE)
# txt contains text that may be a variable name and may be an expression
# for creating a variable. get.var first tries data[[txt]] and if that
# fails tries evaluating txt within data (only). Routine returns NULL
# on failure, or if result is not numeric or a factor.
# matrices are coerced to vectors, which facilitates matrix arguments
# to smooths.
{

Check warning on line 137 in R/check_smooth.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/check_smooth.R,line=137,col=3,[brace_linter] Opening curly braces should never go on their own line and should always be followed by a new line.
x <- data[[txt]]
if (is.null(x)) {
x <- try(eval(parse(text = txt), data, enclos = NULL), silent = TRUE)
if (inherits(x, "try-error")) x <- NULL
}
if (!is.numeric(x) && !is.factor(x)) x <- NULL
if (is.matrix(x)) {
if (ncol(x) == 1) {
x <- as.numeric(x)
ismat <- FALSE
} else {
ismat <- TRUE
}
} else {
ismat <- FALSE
}
if (vecMat && is.matrix(x)) x <- x[1:prod(dim(x))] ## modified from x <- as.numeric(x) to allow factors

Check warning on line 154 in R/check_smooth.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/check_smooth.R,line=154,col=40,[seq_linter] Use seq_len(prod(...)) instead of 1:prod(...), which is likely to be wrong in the empty edge case.
if (ismat) attr(x, "matrix") <- TRUE
x
}


## `data' and `knots' contain the data needed to evaluate the `terms', `by'
## and `knots' elements of `object'. This routine does so, and returns
## a list with element `data' containing just the evaluated `terms',
## with the by variable as the last column. If the `terms' evaluate matrices,
## then a check is made of whether repeat evaluations are being made,
## and if so only the unique evaluation points are returned in data, along
## with the `index' attribute required to re-assemble the full dataset.
knt <- dat <- list()
## should data be processed as for summation convention with matrix arguments?
vecMat <- if (is.null(object$xt$sumConv)) TRUE else object$xt$sumConv
for (i in 1:length(object$term)) {

Check warning on line 170 in R/check_smooth.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/check_smooth.R,line=170,col=13,[seq_linter] Use seq_along(...) instead of 1:length(...), which is likely to be wrong in the empty edge case.
dat[[object$term[i]]] <- get.var(object$term[i], data, vecMat = vecMat)
knt[[object$term[i]]] <- get.var(object$term[i], knots, vecMat = vecMat)
}
names(dat) <- object$term
m <- length(object$term)
if (!is.null(attr(dat[[1]], "matrix")) && vecMat) { ## strip down to unique covariate combinations
n <- length(dat[[1]])
X <- matrix(unlist(dat), n, m)
if (is.numeric(X)) {
X <- mgcv::uniquecombs(X)
if (nrow(X) < n * .9) { ## worth the hassle

Check warning on line 181 in R/check_smooth.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/check_smooth.R,line=181,col=25,[numeric_leading_zero_linter] Include the leading zero for fractional numeric constants.
for (i in 1:m) dat[[i]] <- X[, i] ## return only unique rows
attr(dat, "index") <- attr(X, "index") ## index[i] is row of dat[[i]] containing original row i
}
} ## end if(is.numeric(X))
}
if (object$by != "NA") {
by <- get.var(object$by, data)

Check warning on line 188 in R/check_smooth.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/check_smooth.R,line=188,col=5,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.
if (!is.null(by)) {
dat[[m + 1]] <- by
names(dat)[m + 1] <- object$by
}
}
list(data = dat, knots = knt)
}



# https://github.com/cran/mgcv/blob/2db5036f529dff6ec8a4a4ba0d2df1804a4d2668/R/sparse.r#L129
#' @keywords internal
mgcv_nearest <- function(k, X, gt.zero = FALSE, get.a = FALSE) {

insight::check_if_installed("mgcv")

## The rows of X contain coordinates of points.
## For each point, this routine finds its k nearest
## neighbours, returning a list of 2, n by k matrices:
## ni - ith row indexes the rows of X containing
## the k nearest neighbours of X[i,]
## dist - ith row is the distances to the k nearest
## neighbours.
## a - area associated with each point, if get.a is TRUE
## ties are broken arbitrarily.
## gt.zero indicates that neighbours must have distances greater
## than zero...

if (gt.zero) {
Xu <- mgcv::uniquecombs(X)
ind <- attr(Xu, "index") ## Xu[ind,] == X
} else {
Xu <- X
ind <- 1:nrow(X)

Check warning on line 222 in R/check_smooth.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/check_smooth.R,line=222,col=12,[seq_linter] Use seq_len(nrow(...)) instead of 1:nrow(...), which is likely to be wrong in the empty edge case.
}
if (k > nrow(Xu)) stop("not enough unique values to find k nearest")
nobs <- length(ind)
n <- nrow(Xu)
d <- ncol(Xu)
dist <- matrix(0, n, k)
if (get.a) a <- 1:n else a <- 1

# TODO: how to get that without the call to an internal????
oo <- .C(mgcv:::C_k_nn,
Xu = as.double(Xu), dist = as.double(dist), a = as.double(a), ni = as.integer(dist),
n = as.integer(n), d = as.integer(d), k = as.integer(k), get.a = as.integer(get.a)
)

dist <- matrix(oo$dist, n, k)[ind, ]
rind <- 1:nobs
rind[ind] <- 1:nobs
ni <- matrix(rind[oo$ni + 1], n, k)[ind, ]
if (get.a) a <- oo$a[ind] else a <- NULL
list(ni = ni, dist = dist, a = a)
}
24 changes: 24 additions & 0 deletions man/check_smooth.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

43 changes: 43 additions & 0 deletions tests/testthat/test-check_smooth.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
if (requiet("testthat") && requiet("performance") && requiet("mgcv") && requiet("brms")) {
test_that("check_smooth | mgcv", {
skip_if_not_installed("mgcv")
set.seed(333)

m1 <- mgcv::gam(Sepal.Length ~ s(Petal.Length, k = 3) + s(Petal.Width, k=5), data = iris)
m2 <- mgcv::gam(Sepal.Length ~ t2(Petal.Length, Petal.Width, k=5), data = iris)
m3 <- mgcv::gam(Sepal.Length ~ s(Petal.Length, k = 3, by = Species), data = iris)

rez1 <- check_smooth(m1, iterations=1000)
rez2 <- check_smooth(m2, iterations=1000)
rez3 <- check_smooth(m3, iterations=1000)

c1 <- as.data.frame(mgcv::k.check(m1, n.rep=1000))
c2 <- as.data.frame(mgcv::k.check(m2, n.rep=1000))
c3 <- as.data.frame(mgcv::k.check(m3, n.rep=1000))

# Deterministic
expect_equal(max(rez1$EDF_max - c1$`k'`), 0, tolerance = 0)
expect_equal(max(rez1$EDF - c1$edf), 0, tolerance = 0)
expect_equal(max(rez2$EDF_max - c2$`k'`), 0, tolerance = 0)
expect_equal(max(rez2$EDF - c2$edf), 0, tolerance = 0)
expect_equal(max(rez3$EDF_max - c3$`k'`), 0, tolerance = 0)
expect_equal(max(rez3$EDF - c3$edf), 0, tolerance = 0)

# Random
expect_equal(max(rez1$k - c1$`k-index`), 0, tolerance = 0.1)
expect_equal(max(rez1$p - c1$`p-value`), 0, tolerance = 0.1)
expect_equal(max(rez2$k - c2$`k-index`), 0, tolerance = 0.1)
expect_equal(max(rez2$p - c2$`p-value`), 0, tolerance = 0.1)
expect_equal(max(rez3$k - c3$`k-index`), 0, tolerance = 0.1)
expect_equal(max(rez3$p - c3$`p-value`), 0, tolerance = 0.1)
})

test_that("check_smooth | brms", {
skip_if_not_installed("brms")

# m1 <- brms::brm(Sepal.Length ~ s(Petal.Length, k = 3) + s(Petal.Width, k=5), data = iris,
# refresh = 0, iter = 200, algorithm = "meanfield")

# brms models currently not supported
})
}
Loading