Feature/optimized cp #67

Open · wants to merge 6 commits into `development`
2 changes: 1 addition & 1 deletion README.md
@@ -84,6 +84,7 @@ TorchCP has implemented the following methods:
| 2023 | [**Conformal Prediction Sets for Ordinal Classification**](https://proceedings.neurips.cc/paper_files/paper/2023/file/029f699912bf3db747fe110948cc6169-Paper-Conference.pdf) | NeurIPS'23 | | classification.trainer.ordinal |
| 2022 | [**Training Uncertainty-Aware Classifiers with Conformalized Deep Learning**](https://arxiv.org/abs/2205.05878) | NeurIPS'22 | [Link](https://github.com/bat-sheva/conformal-learning) | classification.loss.uncertainty_aware |
| 2022 | [**Learning Optimal Conformal Classifiers**](https://arxiv.org/abs/2110.09192) | ICLR'22 | [Link](https://github.com/google-deepmind/conformal_training/tree/main) | classification.loss.conftr |
| 2021 | [**Optimized conformal classification using gradient descent approximation**](https://arxiv.org/abs/2105.11255) | arXiv | | classification.loss.scpo |
| 2021 | [**Uncertainty Sets for Image Classifiers using Conformal Prediction**](https://arxiv.org/abs/2009.14193) | ICLR'21 | [Link](https://github.com/aangelopoulos/conformal_classification) | classification.score.raps classification.score.topk |
| 2020 | [**Classification with Valid and Adaptive Coverage**](https://proceedings.neurips.cc/paper/2020/file/244edd7e85dc81602b7615cd705545f5-Paper.pdf) | NeurIPS'20 | [Link](https://github.com/msesia/arc) | classification.score.aps |
| 2019 | [**Conformal Prediction Under Covariate Shift**](https://arxiv.org/abs/1904.06019) | NeurIPS'19 | [Link](https://github.com/ryantibs/conformal/) | classification.predictor.weight |
@@ -128,7 +129,6 @@ TorchCP is still under active development. We will add the following features/items:
|------|-----------------------------------------------------------------------------------------------------------------|---------|----------------------------------------------------------------------------|
| 2022 | [**Adaptive Conformal Predictions for Time Series**](https://arxiv.org/abs/2202.07282) | ICML'22 | [Link](https://github.com/mzaffran/AdaptiveConformalPredictionsTimeSeries) |
| 2022 | [**Conformal Prediction Sets with Limited False Positives**](https://arxiv.org/abs/2202.07650) | ICML'22 | [Link](https://github.com/ajfisch/conformal-fp) |
| 2021 | [**Optimized conformal classification using gradient descent approximation**](https://arxiv.org/abs/2105.11255) | arXiv | |

## Installation

82 changes: 82 additions & 0 deletions torchcp/classification/loss/scpo.py
@@ -0,0 +1,82 @@
# Copyright (c) 2023-present, SUSTech-ML.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

__all__ = ["SCPO"]

import torch

from torchcp.classification.loss.confts import ConfTSLoss
from torchcp.classification.loss.conftr import ConfTrLoss


class SCPOLoss(ConfTSLoss):
"""
Surrogate Conformal Predictor Optimization (SCPO).

    This class implements the loss function of Surrogate Conformal Predictor
    Optimization, an approach that trains the conformal predictor directly with
    maximum predictive efficiency as the optimization objective. The conformal
    predictor is approximated by a differentiable objective function, and
    gradient descent is used to optimize it.

Args:
predictor (torchcp.classification.Predictor): An instance of the CP predictor class.
alpha (float): The significance level for each training batch.
lambda_val (float): Weight for the coverage loss term.
gamma_val (float): Inverse of the temperature value.
loss_transform (str, optional): A transform for loss. Default is "log".
Can be "log" or "neg_inv".

Examples::
>>> predictor = torchcp.classification.SplitPredictor(score_function=THR(score_type="identity"))
>>> scpo = SCPOLoss(predictor=predictor, alpha=0.01)
>>> logits = torch.randn(100, 10)
        >>> labels = torch.randint(0, 10, (100,))
>>> loss = scpo(logits, labels)
>>> loss.backward()

Reference:
Bellotti et al. "Optimized conformal classification using gradient descent approximation", http://arxiv.org/abs/2105.11255

"""

def __init__(self, predictor, alpha, lambda_val=500, gamma_val=5, loss_transform="log"):
super(SCPOLoss, self).__init__(predictor, alpha)
self.lambda_val = lambda_val

if loss_transform == "log":
self.transform = torch.log
elif loss_transform == "neg_inv":
self.transform = lambda x: -1 / x
else:
raise ValueError("loss_transform should be log or neg_inv.")

        # Efficiency term: a ConfTr "valid" loss with target_size=0, penalizing
        # prediction-set size; epsilon = 1/gamma_val sets the smoothing temperature.
        self.size_loss_fn = ConfTrLoss(predictor,
                                       alpha,
                                       fraction=0.5,
                                       epsilon=1 / gamma_val,
                                       loss_type="valid",
                                       target_size=0,
                                       loss_transform="abs")
        # Coverage term: penalizes miscoverage of the true label.
        self.coverage_loss_fn = ConfTrLoss(predictor,
                                           alpha,
                                           fraction=0.5,
                                           epsilon=1 / gamma_val,
                                           loss_type="coverage")

def forward(self, logits, labels):
logits = logits.to(self.device)
labels = labels.to(self.device)

test_scores = self.predictor.score_function(logits)
test_labels = labels

return self.compute_loss(test_scores, test_labels, 1)

def compute_loss(self, test_scores, test_labels, tau):
size_loss = self.size_loss_fn.compute_loss(test_scores, test_labels, tau)
coverage_loss = self.coverage_loss_fn.compute_loss(test_scores, test_labels, tau)
return self.transform(size_loss + self.lambda_val * coverage_loss)
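For intuition, the composed objective reduces to `transform(size + lambda_val * coverage)`. Below is a self-contained sketch, assuming set membership is smoothed by a sigmoid with temperature `gamma_val`; the helper `scpo_objective` and its surrogates are illustrative, not the ConfTr internals this file delegates to:

```python
import torch

def scpo_objective(scores, labels, tau=1.0, lambda_val=500.0, gamma_val=5.0):
    # Smooth indicator that class y enters the prediction set {y : score(x, y) <= tau}
    membership = torch.sigmoid(gamma_val * (tau - scores))      # (batch, num_classes)
    size_loss = membership.sum(dim=1).mean()                    # expected set size
    true_member = membership.gather(1, labels.unsqueeze(1)).squeeze(1)
    coverage_loss = (1.0 - true_member).mean()                  # miscoverage surrogate
    return torch.log(size_loss + lambda_val * coverage_loss)    # "log" transform

scores = torch.randn(32, 10, requires_grad=True)  # e.g. THR scores for a batch
labels = torch.randint(0, 10, (32,))
loss = scpo_objective(scores, labels)
loss.backward()  # gradients flow through both surrogate terms
```

The library version delegates both terms to `ConfTrLoss` with `epsilon=1/gamma_val` and `fraction=0.5`, so its exact surrogates may differ from this sketch.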
1 change: 1 addition & 0 deletions torchcp/classification/trainer/__init__.py
@@ -11,3 +11,4 @@
from .ts_trainer import TSTrainer
from .ua_trainer import UncertaintyAwareTrainer
from .ordinal_trainer import OrdinalTrainer
from .scpo_trainer import SCPOTrainer
52 changes: 52 additions & 0 deletions torchcp/classification/trainer/model_zoo.py
@@ -165,3 +165,55 @@ def forward(self, x):
# the unimodal distribution
x = self.varphi_function(x)
return x


class SurrogateCPModel(nn.Module):
"""
This model wraps a given base model and adds a linear layer on top of its final feature output.
The base model's parameters are frozen to prevent updates during training, so only the added
linear layer is trainable.

Args:
base_model (nn.Module): Pre-trained model

Shape:
- Input: Same as base_model input
        - Output: (batch_size, num_classes), computed as ``1 - linear(base_model(x))``

Examples:
>>> base_model = resnet18(pretrained=True)
>>> model = SurrogateCPModel(base_model)
>>> inputs = torch.randn(10, 3, 224, 224)
>>> logits = model(inputs)
"""

def __init__(self, base_model: nn.Module):
super().__init__()
self.base_model = base_model
self.linear = nn.Linear(in_features=base_model.fc.out_features,
out_features=base_model.fc.out_features,
bias=False)

# Freeze base model parameters
self.freeze_base_model()

def freeze_base_model(self):
"""Freeze all parameters in base model"""
for param in self.base_model.parameters():
param.requires_grad = False
self.base_model.eval()

def is_base_model_frozen(self) -> bool:
"""Check if base model parameters are frozen"""
return not any(p.requires_grad for p in self.base_model.parameters())

def forward(self, x: Tensor) -> Tensor:
with torch.no_grad(): # Ensure no gradients flow through base model
logits = self.base_model(x)

return 1 - self.linear(logits)

def train(self, mode: bool = True):
        super().train(mode)  # Set training mode for the wrapper itself
self.base_model.eval() # Keep base_model in eval mode
return self
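A quick sanity check of the freezing behaviour, assuming the module path from this diff (`torchcp.classification.trainer.model_zoo`); the untrained backbone is only a stand-in:

```python
import torch
import torchvision
from torchcp.classification.trainer.model_zoo import SurrogateCPModel

base = torchvision.models.resnet18(weights=None)  # untrained backbone suffices here
model = SurrogateCPModel(base)

assert model.is_base_model_frozen()
print([name for name, p in model.named_parameters() if p.requires_grad])
# expected: ['linear.weight'] only, since the added head is bias-free

out = model(torch.randn(2, 3, 224, 224))
print(out.shape)  # torch.Size([2, 1000]) -- resnet18's fc has 1000 outputs
```

Because `forward` wraps the backbone in `torch.no_grad()`, gradients reach only the linear head even if freezing were ever bypassed.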
65 changes: 65 additions & 0 deletions torchcp/classification/trainer/scpo_trainer.py
@@ -0,0 +1,65 @@
# Copyright (c) 2023-present, SUSTech-ML.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import torch

from torchcp.classification.loss.scpo import SCPOLoss
from torchcp.classification.predictor import SplitPredictor
from torchcp.classification.score import THR
from torchcp.classification.trainer.base_trainer import Trainer
from torchcp.classification.trainer.model_zoo import SurrogateCPModel


class SCPOTrainer(Trainer):
"""
Trainer for Surrogate Conformal Predictor Optimization.

Args:
alpha (float): The significance level for each training batch.
model (torch.nn.Module): Base neural network model to be calibrated.
        device (torch.device, optional): Device to run the model on. If None, automatically
            uses GPU ('cuda') when available, otherwise CPU ('cpu'). Default: None.
verbose (bool): Whether to display training progress. Default: True.
lr (float): Learning rate for the optimizer. Default is 0.1.
lambda_val (float): Weight for the coverage loss term.
gamma_val (float): Inverse of the temperature value.

Examples:
        >>> # Define base model and target device
        >>> model = torchvision.models.resnet18(pretrained=True)
        >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        >>>
>>> # Create SCPO trainer
>>> trainer = SCPOTrainer(
... alpha=0.01,
... model=model,
... device=device,
... verbose=True)
>>>
>>> # Train model
>>> trainer.train(
... train_loader=train_loader,
... num_epochs=10
... )
"""

def __init__(
self,
alpha: float,
model: torch.nn.Module,
device: torch.device = None,
verbose: bool = True,
lr: float = 0.1,
lambda_val: float = 10000,
gamma_val: float = 1):

model = SurrogateCPModel(model)
super().__init__(model, device=device, verbose=verbose)
predictor = SplitPredictor(score_function=THR(score_type="identity"), model=model)

self.optimizer = torch.optim.Adam(self.model.linear.parameters(), lr=lr)
self.loss_fn = SCPOLoss(predictor=predictor, alpha=alpha,
lambda_val=lambda_val, gamma_val=gamma_val)
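Putting the pieces together, a minimal end-to-end sketch; the toy tensors stand in for a real labelled training split, and `train(train_loader=..., num_epochs=...)` follows the docstring example above:

```python
import torch
import torchvision
from torch.utils.data import DataLoader, TensorDataset
from torchcp.classification.trainer import SCPOTrainer

backbone = torchvision.models.resnet18(weights="IMAGENET1K_V1")

# Toy stand-in data; replace with a real labelled training loader.
images = torch.randn(64, 3, 224, 224)
labels = torch.randint(0, 1000, (64,))
train_loader = DataLoader(TensorDataset(images, labels), batch_size=16)

trainer = SCPOTrainer(alpha=0.01, model=backbone, verbose=True)
trainer.train(train_loader=train_loader, num_epochs=1)
```

Since the trainer wraps the backbone in `SurrogateCPModel` and optimizes only `model.linear`, training touches a single bias-free linear head regardless of backbone size.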