Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conformal prediction with conditional guarantees #455

Open
wants to merge 131 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
131 commits
Select commit Hold shift + click to select a range
4a367bd
ADD: first implementation of the CCP method
May 28, 2024
fffd511
UPD: increase test_results_with_constant_sample_weights assert_allclo…
Damien-Bouet May 28, 2024
56c67cf
MOVE PhiFunction into utils folder
Damien-Bouet May 31, 2024
765c70a
ADD Polynomial and Gaussian PhiFunctions
Damien-Bouet May 31, 2024
c8c004d
UPD docstrings and return self on fit and calibrate
Damien-Bouet May 31, 2024
1767c1e
FIX: tests
Damien-Bouet May 31, 2024
8a46555
ADD: Paper simulations reproduction
Damien-Bouet May 31, 2024
d75a48b
FIX: tests and coverage
Damien-Bouet Jun 3, 2024
e29d0d7
FIX: paper simulation
Damien-Bouet Jun 3, 2024
fb4dbcb
Remove Literal for python 3.7 compatibility
Damien-Bouet Jun 3, 2024
16c7069
UPD: sample_weight_test tol values
Damien-Bouet Jun 3, 2024
b32626e
RMV: seaborn from paper simulation imports
Damien-Bouet Jun 3, 2024
4f4447e
FIX: linting
Damien-Bouet Jun 3, 2024
015bdb7
FIX: useless seaborn grid style
Damien-Bouet Jun 3, 2024
22a901d
FIX: Gaussian exp formula
Damien-Bouet Jun 4, 2024
db23dc0
MOVE: check_parameters in outside of init
Damien-Bouet Jun 4, 2024
7f2cf56
FIX: Gaussian exp formula test
Damien-Bouet Jun 4, 2024
24d8e19
MOVE: PhiFunctions import from regression to regression.utils
Damien-Bouet Jun 4, 2024
023207f
UPD: Add fit/calib _attributes and Base classes inheritence
Damien-Bouet Jun 5, 2024
79efe4e
UPD: Improve parameters checks
Damien-Bouet Jun 5, 2024
c278273
MOVE: check_estimator into utils
Damien-Bouet Jun 5, 2024
595b037
UPD: CCP docstrings
Damien-Bouet Jun 5, 2024
b6735c2
UPD: Return predictions only if alpha is None
Damien-Bouet Jun 5, 2024
e48a8e6
UPD: Improve and functions
Damien-Bouet Jun 5, 2024
5158462
UPD: Convert PhiFunction into a Abstract class
Damien-Bouet Jun 5, 2024
2fcc380
ADD: CustomPhiFunction
Damien-Bouet Jun 5, 2024
88e9273
UPD: Tests
Damien-Bouet Jun 5, 2024
a376dd5
reduce Gibbs paper simulation runtime
Damien-Bouet Jun 5, 2024
c2481f3
RENAME: move CCP on the template fit/predict (with fit_estimator, fit…
Damien-Bouet Jun 5, 2024
3885491
FIX: forgot to stage a line from 'UPD: Convert PhiFunction into a Abs…
Damien-Bouet Jun 5, 2024
f728d77
FIX: array error for np.std
Damien-Bouet Jun 5, 2024
aeb7979
FIX: some forgotten fit_calibrator renaming
Damien-Bouet Jun 6, 2024
f848813
ADD: _is_fitted function (almost the copy of the private sklearn.util…
Damien-Bouet Jun 6, 2024
dc19045
typing
Damien-Bouet Jun 6, 2024
4fdd852
typing again...
Damien-Bouet Jun 6, 2024
b0e22ec
UPD: CustomPhiFunction can now take PhiFunction instances in function…
Damien-Bouet Jun 6, 2024
293f085
Merge branch 'master' into 449-cp-with-conditional-guarantees
Jun 6, 2024
809a39f
RENAME: check_estimator_regression
Damien-Bouet Jun 6, 2024
4cf1a9f
RMV: Exemple from MapieCCPRegressor
Damien-Bouet Jun 6, 2024
c588ec8
UPD: make fit method mandatory and use check_is_fitted from sklearn
Damien-Bouet Jun 10, 2024
8048220
MOVE: Externalise some utils functions
Damien-Bouet Jun 10, 2024
1adf14d
MOVE: PhiFunctions into phi_function folder
Damien-Bouet Jun 10, 2024
ec4ce10
FIX: Coverage
Damien-Bouet Jun 10, 2024
b8be35e
ADD: PhiFunction multiplication
Damien-Bouet Jun 10, 2024
99ea3fe
FIX: coverage
Damien-Bouet Jun 10, 2024
db360f9
MOVE and RENAME: PhiFunctions in calibrators/ccp
Damien-Bouet Jun 11, 2024
0a7ec52
ADD: Abstract 'Calibrator' class
Damien-Bouet Jun 11, 2024
e9e0f43
UPD: externalise MapieCCPRegressor into abstract SplitMapie and move …
Damien-Bouet Jun 12, 2024
f7e8296
RENAME tests
Damien-Bouet Jun 12, 2024
e45bf48
ADD: Draft of Classification, to assess the generaliation capacities …
Damien-Bouet Jun 12, 2024
f08f8ae
UPD: improve abstract calibrator class signature
Damien-Bouet Jun 13, 2024
d437750
FIX: Coverage
Damien-Bouet Jun 13, 2024
33573a3
UPD: docstring
Damien-Bouet Jun 14, 2024
9168201
UPD: docstrings and rename
Damien-Bouet Jun 14, 2024
328130c
ADD: demo notebook CCP Regression
Damien-Bouet Jun 14, 2024
61e52b1
MERGE master
Damien-Bouet Jun 17, 2024
fbd8bd0
UPD: linting, tests and coverage
Damien-Bouet Jun 17, 2024
cfe299f
UPD: move reg_param into init
Damien-Bouet Jun 17, 2024
6db9867
ADD: author
Damien-Bouet Jun 17, 2024
68f704d
FIX: has no attribute
Damien-Bouet Jun 17, 2024
e8914a0
REMOVE: CCP Docstring example
Damien-Bouet Jun 17, 2024
afacd1a
REMOVE: example results
Damien-Bouet Jun 17, 2024
2538ca0
FIX: ccp_regression_demo
Damien-Bouet Jun 17, 2024
73c3a29
ADD: tutorial_ccp_CandC.ipynb
Damien-Bouet Jun 17, 2024
cb80f3b
ADD: ccp_tutorial notebook in readthedocs
Damien-Bouet Jun 17, 2024
f1e3be4
UPD: ccp tuto
Damien-Bouet Jun 17, 2024
950f02d
FIX: remove seaborn import
Damien-Bouet Jun 18, 2024
a7a3297
FIX: isort imports
Damien-Bouet Jun 18, 2024
cc41209
UPD: ccp tutorial
Damien-Bouet Jun 18, 2024
3cfdba1
UPD: remove multipliers from CCPCalibrator init and remove assert
Damien-Bouet Jun 18, 2024
ef8e378
DEL: ccp notebook moved in the doc, not usefull anymore
Damien-Bouet Jun 18, 2024
2ddc543
FIX: typo in docstrings
Damien-Bouet Jun 18, 2024
2e5918f
MOVE: check_calibrator in calibrators.utils
Damien-Bouet Jun 18, 2024
64be82e
UPD: docstrings and minor fix
Damien-Bouet Jun 18, 2024
cad8e28
ADD: notebook tutorial_ccp_CandC in regression notebooks doc
Damien-Bouet Jun 18, 2024
dca25d3
ADD: test to check equivalence of new and old implementation of stand…
Damien-Bouet Jun 18, 2024
b5ed289
UPD: typos
Damien-Bouet Jun 21, 2024
9bb8863
ADD: perfect width in ccp tutorial plots
Damien-Bouet Jun 21, 2024
9115a87
UPD: Change regularization from L2 to L1
Damien-Bouet Jun 21, 2024
1eed6de
FIX: ccp tuto plot
Damien-Bouet Jun 21, 2024
8247e30
FIX: ccp tuto
Damien-Bouet Jun 21, 2024
e04ab4a
UPD: only sample gaussian points where multipliers values are not zer…
Damien-Bouet Jul 15, 2024
909ec93
UPD: gaussian default value set to 20
Damien-Bouet Jul 16, 2024
60bda1a
FIX: calib_kwargs bug fix
Damien-Bouet Jul 16, 2024
d14fa1b
MOVE: ccp null feature warning call
Damien-Bouet Jul 16, 2024
89e31b9
UPD: calib kwargs docstring
Damien-Bouet Jul 16, 2024
43dd443
UPD: multiply default sigma value by dnum of dimensions
Damien-Bouet Jul 16, 2024
0ab0d77
UPD: docstrings and some renaming
Damien-Bouet Jul 20, 2024
f3d272a
FIX: calib_kwargs bug and linting
Damien-Bouet Jul 22, 2024
907aec6
RMV warning in optional arg in custim ccp
Damien-Bouet Jul 22, 2024
3de4ef3
FIX: multiplier impact on normalize and sigma
Damien-Bouet Jul 23, 2024
e704700
UPD: ccp tutorial
Damien-Bouet Jul 23, 2024
2900605
ADD: ccp in api doc
Damien-Bouet Jul 23, 2024
a54aad8
UPD: add ccp tutorial conclusion and remove typo
Damien-Bouet Jul 23, 2024
8f12ac4
FIX typo in ccp_tuto
Damien-Bouet Jul 23, 2024
44c78fa
ADD: CCP theoretical description doc
Damien-Bouet Jul 24, 2024
fea21c4
FIX: ccp theory doc
Damien-Bouet Jul 25, 2024
28ff85b
RENAME compile_functions_warnings_errors utils function and update er…
Damien-Bouet Jul 25, 2024
4c560fd
RENAME ccp calibrator test functions
Damien-Bouet Jul 25, 2024
b31609c
Try to fix the doc
Damien-Bouet Jul 25, 2024
2babfa2
UPD: ccp_CandC notebook
Damien-Bouet Jul 25, 2024
71eaa1c
MOVE CCP doc into a new section
Damien-Bouet Jul 25, 2024
c67bb57
ADD tuto papier reference link
Damien-Bouet Jul 25, 2024
391f529
ADD: calibrator doc
Damien-Bouet Jul 25, 2024
ddc9b52
UPD: minor corrections in the doc
Damien-Bouet Jul 26, 2024
f3e0d32
MOVE: BaseCalibrator import
Damien-Bouet Jul 26, 2024
00bfd27
UPD docstrings and add ccp reference
Damien-Bouet Jul 26, 2024
c1a00dc
RMV not reproductible warning
Damien-Bouet Jul 26, 2024
fefd068
Merge branch 'master' into 449-cp-with-conditional-guarantees
Damien-Bouet Jul 26, 2024
20037f7
REFACTO: Adapte the PR to the new Classifier refacto
Damien-Bouet Jul 29, 2024
8ac6ae9
FIX: tests
Damien-Bouet Jul 29, 2024
add96e6
UNDO changes in sets.utils
Damien-Bouet Jul 29, 2024
4cb0aba
Linting
Damien-Bouet Jul 29, 2024
1afcf27
UPD: change naive by standard in doc
Damien-Bouet Jul 29, 2024
20e94b2
UPD: change optimize to SLSQP
Damien-Bouet Jul 31, 2024
66fe957
UPD: update CandC notebook after changing optimizer
Damien-Bouet Jul 31, 2024
3c0f1c7
FIX tests
Damien-Bouet Jul 31, 2024
e82b1c7
UPD: fix coverage
Damien-Bouet Jul 31, 2024
40bf93a
FIX: typo
Damien-Bouet Jul 31, 2024
8ca56d9
UPD: theoretical description
Damien-Bouet Aug 5, 2024
2fc54fd
ADD: reference in README
Damien-Bouet Aug 5, 2024
26f07da
UPD: HISTORY with new CCP content
Damien-Bouet Aug 5, 2024
c33ed18
UPD: theoretical doc update and typo
Damien-Bouet Aug 7, 2024
db375a3
UPD: remove sample_weights and corrected alpha in the calibration ste…
Damien-Bouet Aug 7, 2024
53837bd
UPD: Typos in the doc
Damien-Bouet Aug 7, 2024
57e15f8
linting
Damien-Bouet Aug 7, 2024
9fa15fe
UPD: test values
Damien-Bouet Aug 7, 2024
ab14963
Merge branch 'master' into 449-cp-with-conditional-guarantees
Damien-Bouet Aug 8, 2024
744d56f
typo
Damien-Bouet Aug 8, 2024
14e05b9
UPD: add :class: tag in docstrings
Damien-Bouet Aug 9, 2024
10a419e
UPD: doc
Damien-Bouet Aug 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,369 @@
"""
======================================================================
Reproduction of part of the paper experiments of Gibbs et al. (2023)
======================================================================

:class:`~mapie.regression.MapieCCPRegressor` is used to reproduce a
part of the paper experiments of Gibbs et al. (2023) in their article [1]
which we argue is a good procedure to get adaptative prediction intervals (PI)
and a guaranteed coverage on all sub groups of interest.

For a given model, the simulation adjusts the MAPIE regressors using the
``CCP`` method, on a synthetic dataset first considered by Romano et al. (2019)
[2], and compares the bounds of the PIs with the standard split CP.

In order to reproduce the results of the standard split conformal prediction
(Split CP), we reuse the Mapie implementation in
:class:`~mapie.regression.MapieRegressor`.

This simulation is carried out to check that the CCP method implemented in
MAPIE gives the same results as [1], and that the bounds of the PIs are
obtained.

[1] Isaac Gibbs, John J. Cherian, Emmanuel J. Candès (2023).
Conformal Prediction With Conditional Guarantees

[2] Yaniv Romano, Evan Patterson, Emmanuel J. Candès (2019).
Conformalized Quantile Regression.
33rd Conference on Neural Information Processing Systems (NeurIPS 2019).
"""
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from mapie.conformity_scores import AbsoluteConformityScore
from mapie.regression import (MapieCCPRegressor, MapieRegressor,
PhiFunction, GaussianPhiFunction)
from scipy.stats import norm
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures

warnings.filterwarnings("ignore")

random_state = 1
np.random.seed(random_state)


###############################################################################
# 1. Global model parameters
# -----------------------------------------------------------------------------

def init_model():
# the degree of the polynomial regression
degree = 4

model = Pipeline(
[
("poly", PolynomialFeatures(degree=degree)),
("linear", LinearRegression())
]
)
return model

###############################################################################
# 2. Generate and present data
# -----------------------------------------------------------------------------


def generate_data(n_train=2000, n_calib=2000, n_test=500):
def f(x):
ax = 0*x
for i in range(len(x)):
ax[i] = (np.random.poisson(np.sin(x[i])**2 + 0.1)
+ 0.03*x[i]*np.random.randn(1))
ax[i] += 25*(np.random.uniform(0, 1, 1) < 0.01)*np.random.randn(1)
return ax.astype(np.float32)

# training features
X_train = np.random.uniform(0, 5.0, size=n_train).astype(np.float32)
X_calib = np.random.uniform(0, 5.0, size=n_calib).astype(np.float32)
X_test = np.random.uniform(0, 5.0, size=n_test).astype(np.float32)

# generate labels
y_train = f(X_train)
y_calib = f(X_calib)
y_test = f(X_test)

# reshape the features
X_train = X_train.reshape(-1, 1)
X_calib = X_calib.reshape(-1, 1)
X_test = X_test.reshape(-1, 1)

return X_train, y_train, X_calib, y_calib, X_test, y_test


X_train, y_train, X_calib, y_calib, X_test, y_test = generate_data()

fig = plt.figure(figsize=(12, 5))
ax1 = fig.add_subplot(1, 2, 1)
ax1.scatter(X_train[:, 0], y_train, s=1.5, alpha=0.6, label="Train Data")
ax1.set_xlabel("X")
ax1.set_ylabel("Y")
ax1.set_title("Train Data")
ax1.legend()

ax2 = fig.add_subplot(1, 2, 2)
ax2.scatter(X_train[:, 0], y_train, s=1.5, alpha=0.6, label="Train Data")
ax2.set_ylim([-2, 6])
ax2.set_xlabel("X")
ax2.set_ylabel("Y")
ax2.set_title("Zoom")
ax2.legend()

plt.show()

##############################################################################
# 3. Prepare model and show predictions
# -----------------------------------------------------------------------------

model = init_model()

model.fit(X_train, y_train)

sort_order = np.argsort(X_test[:, 0])
x_test_s = X_test[sort_order]
y_pred_s = model.predict(x_test_s)

plt.figure(figsize=(6, 5))
plt.scatter(X_test[:, 0], y_test, s=1.5, alpha=0.6, label="Test Data")
plt.plot(x_test_s, y_pred_s, "-k", label="Prediction")
plt.ylim([-2, 6])
plt.xlabel("X")
plt.ylabel("Y")
plt.title("Test Data (Zoom)")
plt.legend()
plt.show()


##############################################################################
# 4. Prepare Experiments
# -----------------------------------------------------------------------------
# In this experiment, we will use the
# :class:`~mapie.regression.MapieRegressor` and
# :class:`~mapie.regression.MapieCCPRegressor` to compute prediction intervals
# with the basic Split CP method and the paper CCP method.
# The coverages will be computed on 500 different dataset generation, to have
LacombeLouis marked this conversation as resolved.
Show resolved Hide resolved
# a good idea of the true value. Indeed, the empirical coverage of a single
# experiment is stochastic, because of the finite number of calibration and
# test samples.

ALPHA = 0.1


def estimate_coverage(mapie_split, mapie_ccp, group_functs=[]):
_, _, X_calib, y_calib, X_test, y_test = generate_data()

mapie_split.fit(X_calib, y_calib)
_, y_pi_split = mapie_split.predict(X_test, alpha=ALPHA)

mapie_ccp.calibrate(X_calib, y_calib)
_, y_pi_ccp = mapie_ccp.predict(X_test)

cover_split = np.logical_or(y_test < y_pi_split[:, 0, 0],
y_test > y_pi_split[:, 1, 0])
cover_ccp = np.logical_or(y_test < y_pi_ccp[:, 0, 0],
y_test > y_pi_ccp[:, 1, 0])
group_covers = []
marginal_cover = np.asarray((cover_split.mean(), cover_ccp.mean()))
for funct in group_functs:
group_cover = np.zeros((2,))
group_cover[0] = (funct(X_test).flatten()
* cover_split).sum() / funct(X_test).sum()
group_cover[1] = (funct(X_test).flatten()
* cover_ccp).sum() / funct(X_test).sum()
group_covers.append(group_cover)
return marginal_cover, np.array(group_covers)


def plot_results(X_test, y_test, n_trials=10,
experiment="Groups", split_sym=True):

# Split CP
mapie_split = MapieRegressor(
model, method="base", cv="prefit",
conformity_score=AbsoluteConformityScore(sym=split_sym)
)
mapie_split.conformity_score.eps = 1e-5
mapie_split.fit(X_calib, y_calib)
_, y_pi_split = mapie_split.predict(X_test, alpha=ALPHA)

if experiment == "Groups":
# CCP Groups
phi_groups = PhiFunction([
lambda X, t=t: np.logical_and(X >= t, X < t + 0.5).astype(int)
for t in np.arange(0, 5.5, 0.5)
])
mapie_ccp = MapieCCPRegressor(
model, phi=phi_groups, alpha=ALPHA, cv="prefit",
conformity_score=AbsoluteConformityScore(sym=False),
random_state=None
)
mapie_ccp.conformity_score_.eps = 1e-5
mapie_ccp.calibrate(X_calib, y_calib)
_, y_pi_ccp = mapie_ccp.predict(X_test)
else:
# CCP Shifts
eval_locs = [1.5, 3.5]
eval_scale = 0.2
other_locs = [0.5, 2.5, 4.5]
other_scale = 1

phi_shifts = GaussianPhiFunction(
points=(
np.array(eval_locs+other_locs).reshape(-1, 1),
[eval_scale]*len(eval_locs) + [other_scale]*len(other_locs),
),
marginal_guarantee=True,
normalized=False,
)
mapie_ccp = MapieCCPRegressor(
model, phi=phi_shifts, alpha=ALPHA, cv="prefit",
conformity_score=AbsoluteConformityScore(sym=False),
random_state=None
)
mapie_ccp.conformity_score_.eps = 1e-5
mapie_ccp.calibrate(X_calib, y_calib)
_, y_pi_ccp = mapie_ccp.predict(X_test)

# =========== n_trials run to get average marginal coverage ============
if experiment == "Groups":
eval_functions = [
lambda X, a=a, b=b: np.logical_and(X >= a, X <= b).astype(int)
for a, b in zip([1, 3], [2, 4])
]
eval_names = ["[1, 2]", "[3, 4]"]
else:
eval_functions = [
lambda x: norm.pdf(x, loc=1.5, scale=0.2).reshape(-1, 1),
lambda x: norm.pdf(x, loc=3.5, scale=0.2).reshape(-1, 1)
]
eval_names = ["f1", "f2"]

marginal_cov = np.zeros((n_trials, 2))
group_cov = np.zeros((len(eval_functions), n_trials, 2))
for j in range(n_trials):
marginal_cov[j], group_cov[:, j, :] = estimate_coverage(
mapie_split, mapie_ccp, eval_functions
)

coverageData = pd.DataFrame()

for group, cov in zip(["Marginal"]+eval_names,
[marginal_cov] + list(group_cov)):
for i, name in enumerate(["Split", "CCP"]):
coverageData = pd.concat(
[coverageData,
pd.DataFrame({'Method': [name] * len(cov),
'Range': [group] * len(cov),
'Miscoverage': np.asarray(cov)[:, i]})],
axis=0
)

# ================== results plotting ==================
cp = plt.get_cmap('tab10').colors

# Set font and style
plt.rcParams['font.family'] = 'DejaVu Sans'
plt.rcParams['axes.grid'] = False

fig = plt.figure()
fig.set_size_inches(17, 6)

sort_order = np.argsort(X_test[:, 0])
x_test_s = X_test[sort_order]
y_test_s = y_test[sort_order]
y_pred_s = model.predict(x_test_s)

ax1 = fig.add_subplot(1, 3, 1)
ax1.plot(x_test_s, y_test_s, '.', alpha=0.2)
ax1.plot(x_test_s, y_pred_s, lw=1, color='k')
ax1.plot(x_test_s, y_pi_split[sort_order, 0, 0], color=cp[0], lw=2)
ax1.plot(x_test_s, y_pi_split[sort_order, 1, 0], color=cp[0], lw=2)
ax1.fill_between(x_test_s.flatten(), y_pi_split[sort_order, 0, 0],
y_pi_split[sort_order, 1, 0],
color=cp[0], alpha=0.4, label='split prediction interval')
ax1.set_ylim(-2, 6.5)
ax1.tick_params(axis='both', which='major', labelsize=14)
ax1.set_xlabel("$X$", fontsize=16, labelpad=10)
ax1.set_ylabel("$Y$", fontsize=16, labelpad=10)
ax1.set_title("Split calibration", fontsize=18, pad=12)

if experiment == 'Groups':
ax1.axvspan(1, 2, facecolor='grey', alpha=0.25)
ax1.axvspan(3, 4, facecolor='grey', alpha=0.25)
else:
for loc in eval_locs:
ax1.plot(x_test_s, norm.pdf(x_test_s, loc=loc, scale=eval_scale),
color='grey', ls='--', lw=3)

ax2 = fig.add_subplot(1, 3, 2, sharex=ax1, sharey=ax1)
ax2.plot(x_test_s, y_test_s, '.', alpha=0.2)
ax2.plot(x_test_s, y_pred_s, color='k', lw=1)
ax2.plot(x_test_s, y_pi_ccp[sort_order, 0, 0], color=cp[1], lw=2)
ax2.plot(x_test_s, y_pi_ccp[sort_order, 1, 0], color=cp[1], lw=2)
ax2.fill_between(x_test_s.flatten(), y_pi_ccp[sort_order, 0, 0],
y_pi_ccp[sort_order, 1, 0], color=cp[1], alpha=0.4,
label='conditional calibration')
ax2.tick_params(axis='both', which='major', direction='out', labelsize=14)
ax2.set_xlabel("$X$", fontsize=16, labelpad=10)
ax2.set_ylabel("$Y$", fontsize=16, labelpad=10)
ax2.set_title("Conditional calibration", fontsize=18, pad=12)

if experiment == 'Groups':
ax2.axvspan(1, 2, facecolor='grey', alpha=0.25)
ax2.axvspan(3, 4, facecolor='grey', alpha=0.25)
else:
for loc in eval_locs:
ax2.plot(x_test_s, norm.pdf(x_test_s, loc=loc, scale=eval_scale),
color='grey', ls='--', lw=3)

ax3 = fig.add_subplot(1, 3, 3)

ranges = coverageData['Range'].unique()
methods = coverageData['Method'].unique()
bar_width = 0.8 / len(methods)
for i, method in enumerate(methods):
method_data = coverageData[coverageData['Method'] == method]
x = np.arange(len(ranges)) + i * bar_width
ax3.bar(x, method_data.groupby("Range")['Miscoverage'].mean(),
width=bar_width, label=method, color=cp[i])

ax3.set_xticks(np.arange(len(ranges)) + bar_width * (len(methods) - 1) / 2)
ax3.set_xticklabels(ranges)

ax3.axhline(0.1, color='red')
ax3.legend()
ax3.set_ylabel("Miscoverage", fontsize=18, labelpad=10)
ax3.set_xlabel(experiment, fontsize=18, labelpad=10)
ax3.set_ylim(0., 0.2)
ax3.tick_params(axis='both', which='major', labelsize=14)

plt.tight_layout(pad=2)
plt.show()


##############################################################################
# 5. Reproduce experiment and results
# -----------------------------------------------------------------------------

plot_results(X_test, y_test, 500, experiment="Groups")

plot_results(X_test, y_test, 500, experiment="Shifts")


##############################################################################
# We succesfully reproduced the experiement of the Gibbs et al. paper [1].

##############################################################################
# 6. Variant of the experiments: let's compare what is comparable
# -----------------------------------------------------------------------------
#
# In the paper, the proposed method (used with not symetrical PI) is compared
# to the split method with symetrical PI. Let's compare it to the split CP with
# unsymetrical PI, to have a fair comparison.

plot_results(X_test, y_test, 500, experiment="Groups")

plot_results(X_test, y_test, 500, experiment="Shifts", split_sym=False)
9 changes: 8 additions & 1 deletion mapie/regression/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
from .quantile_regression import MapieQuantileRegressor
from .regression import MapieRegressor
from .ccp_regression import MapieCCPRegressor
from .utils.ccp_phi_function import (PhiFunction, PolynomialPhiFunction,
GaussianPhiFunction)
from .time_series_regression import MapieTimeSeriesRegressor

__all__ = [
"MapieRegressor",
"MapieQuantileRegressor",
"MapieTimeSeriesRegressor"
"MapieTimeSeriesRegressor",
"MapieCCPRegressor",
"PhiFunction",
"PolynomialPhiFunction",
"GaussianPhiFunction",
thibaultcordier marked this conversation as resolved.
Show resolved Hide resolved
]
Loading
Loading