diff --git a/README.md b/README.md
index 7a64580..20457de 100644
--- a/README.md
+++ b/README.md
@@ -84,6 +84,7 @@ TorchCP has implemented the following methods:
 | 2023 | [**Conformal Prediction Sets for Ordinal Classification**](https://proceedings.neurips.cc/paper_files/paper/2023/file/029f699912bf3db747fe110948cc6169-Paper-Conference.pdf) | NeurIPS'23 | | classification.trainer.ordinal |
 | 2022 | [**Training Uncertainty-Aware Classifiers with Conformalized Deep Learning**](https://arxiv.org/abs/2205.05878) | NeurIPS'22 | [Link](https://github.com/bat-sheva/conformal-learning) | classification.loss.uncertainty_aware |
 | 2022 | [**Learning Optimal Conformal Classifiers**](https://arxiv.org/abs/2110.09192) | ICLR'22 | [Link](https://github.com/google-deepmind/conformal_training/tree/main) | classification.loss.conftr |
+| 2021 | [**Optimized conformal classification using gradient descent approximation**](https://arxiv.org/abs/2105.11255) | arXiv | | classification.loss.scpo |
 | 2021 | [**Uncertainty Sets for Image Classifiers using Conformal Prediction**](https://arxiv.org/abs/2009.14193 ) | ICLR'21 | [Link](https://github.com/aangelopoulos/conformal_classification) | classification.score.raps classification.score.topk |
 | 2020 | [**Classification with Valid and Adaptive Coverage**](https://proceedings.neurips.cc/paper/2020/file/244edd7e85dc81602b7615cd705545f5-Paper.pdf) | NeurIPS'20 | [Link](https://github.com/msesia/arc) | classification.score.aps |
 | 2019 | [**Conformal Prediction Under Covariate Shift**](https://arxiv.org/abs/1904.06019) | NeurIPS'19 | [Link](https://github.com/ryantibs/conformal/) | classification.predictor.weight |
@@ -128,7 +129,6 @@ TorchCP is still under active development. We will add the following features/it
 |------|-----------------------------------------------------------------------------------------------------------------|---------|----------------------------------------------------------------------------|
 | 2022 | [**Adaptive Conformal Predictions for Time Series**](https://arxiv.org/abs/2202.07282) | ICML'22 | [Link](https://github.com/mzaffran/AdaptiveConformalPredictionsTimeSeries) |
 | 2022 | [**Conformal Prediction Sets with Limited False Positives**](https://arxiv.org/abs/2202.07650) | ICML'22 | [Link](https://github.com/ajfisch/conformal-fp) |
-| 2021 | [**Optimized conformal classification using gradient descent approximation**](https://arxiv.org/abs/2105.11255) | Arxiv | |
 
 ## Installation
 
diff --git a/torchcp/classification/loss/scpo.py b/torchcp/classification/loss/scpo.py
new file mode 100644
index 0000000..75b4b29
--- /dev/null
+++ b/torchcp/classification/loss/scpo.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2023-present, SUSTech-ML.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+
+__all__ = ["SCPOLoss"]
+
+import torch
+
+from torchcp.classification.loss.confts import ConfTSLoss
+from torchcp.classification.loss.conftr import ConfTrLoss
+
+
+class SCPOLoss(ConfTSLoss):
+    """
+    Surrogate Conformal Predictor Optimization (SCPO).
+
+    This class implements the loss function of Surrogate Conformal Predictor Optimization,
+    an approach that trains the conformal predictor directly with maximum predictive
+    efficiency as the optimization objective. The conformal predictor is approximated by a
+    differentiable objective function, and gradient descent is used to optimize it.
+ + Args: + predictor (torchcp.classification.Predictor): An instance of the CP predictor class. + alpha (float): The significance level for each training batch. + lambda_val (float): Weight for the coverage loss term. + gamma_val (float): Inverse of the temperature value. + loss_transform (str, optional): A transform for loss. Default is "log". + Can be "log" or "neg_inv". + + Examples:: + >>> predictor = torchcp.classification.SplitPredictor(score_function=THR(score_type="identity")) + >>> scpo = SCPOLoss(predictor=predictor, alpha=0.01) + >>> logits = torch.randn(100, 10) + >>> labels = torch.randint(0, 2, (100,)) + >>> loss = scpo(logits, labels) + >>> loss.backward() + + Reference: + Bellotti et al. "Optimized conformal classification using gradient descent approximation", http://arxiv.org/abs/2105.11255 + + """ + + def __init__(self, predictor, alpha, lambda_val=500, gamma_val=5, loss_transform="log"): + super(SCPOLoss, self).__init__(predictor, alpha) + self.lambda_val = lambda_val + + if loss_transform == "log": + self.transform = torch.log + elif loss_transform == "neg_inv": + self.transform = lambda x: -1 / x + else: + raise ValueError("loss_transform should be log or neg_inv.") + + self.size_loss_fn = ConfTrLoss(predictor, + alpha, + fraction=0.5, + epsilon=1/gamma_val, + loss_type="valid", + target_size=0, + loss_transform="abs") + self.coverage_loss_fn = ConfTrLoss(predictor, + alpha, + fraction=0.5, + epsilon=1/gamma_val, + loss_type="coverage") + + def forward(self, logits, labels): + logits = logits.to(self.device) + labels = labels.to(self.device) + + test_scores = self.predictor.score_function(logits) + test_labels = labels + + return self.compute_loss(test_scores, test_labels, 1) + + def compute_loss(self, test_scores, test_labels, tau): + size_loss = self.size_loss_fn.compute_loss(test_scores, test_labels, tau) + coverage_loss = self.coverage_loss_fn.compute_loss(test_scores, test_labels, tau) + return self.transform(size_loss + self.lambda_val * coverage_loss) \ No newline at end of file diff --git a/torchcp/classification/trainer/__init__.py b/torchcp/classification/trainer/__init__.py index cfd6f91..ae2169c 100644 --- a/torchcp/classification/trainer/__init__.py +++ b/torchcp/classification/trainer/__init__.py @@ -11,3 +11,4 @@ from .ts_trainer import TSTrainer from .ua_trainer import UncertaintyAwareTrainer from .ordinal_trainer import OrdinalTrainer +from .scpo_trainer import SCPOTrainer \ No newline at end of file diff --git a/torchcp/classification/trainer/model_zoo.py b/torchcp/classification/trainer/model_zoo.py index bb5f85b..6477c7c 100644 --- a/torchcp/classification/trainer/model_zoo.py +++ b/torchcp/classification/trainer/model_zoo.py @@ -165,3 +165,55 @@ def forward(self, x): # the unimodal distribution x = self.varphi_function(x) return x + + +class SurrogateCPModel(nn.Module): + """ + This model wraps a given base model and adds a linear layer on top of its final feature output. + The base model's parameters are frozen to prevent updates during training, so only the added + linear layer is trainable. 
+ + Args: + base_model (nn.Module): Pre-trained model + + Shape: + - Input: Same as base_model input + - Output: (batch_size, num_classes) 1 - logits + + Examples: + >>> base_model = resnet18(pretrained=True) + >>> model = SurrogateCPModel(base_model) + >>> inputs = torch.randn(10, 3, 224, 224) + >>> logits = model(inputs) + """ + + def __init__(self, base_model: nn.Module): + super().__init__() + self.base_model = base_model + self.linear = nn.Linear(in_features=base_model.fc.out_features, + out_features=base_model.fc.out_features, + bias=False) + + # Freeze base model parameters + self.freeze_base_model() + + def freeze_base_model(self): + """Freeze all parameters in base model""" + for param in self.base_model.parameters(): + param.requires_grad = False + self.base_model.eval() + + def is_base_model_frozen(self) -> bool: + """Check if base model parameters are frozen""" + return not any(p.requires_grad for p in self.base_model.parameters()) + + def forward(self, x: Tensor) -> Tensor: + with torch.no_grad(): # Ensure no gradients flow through base model + logits = self.base_model(x) + + return 1 - self.linear(logits) + + def train(self, mode: bool = True): + super().train(mode) # Set training mode for TemperatureScalingModel + self.base_model.eval() # Keep base_model in eval mode + return self \ No newline at end of file diff --git a/torchcp/classification/trainer/scpo_trainer.py b/torchcp/classification/trainer/scpo_trainer.py new file mode 100644 index 0000000..7fed6ce --- /dev/null +++ b/torchcp/classification/trainer/scpo_trainer.py @@ -0,0 +1,65 @@ +# Copyright (c) 2023-present, SUSTech-ML. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch + +from torchcp.classification.loss.scpo import SCPOLoss +from torchcp.classification.predictor import SplitPredictor +from torchcp.classification.score import THR +from torchcp.classification.trainer.base_trainer import Trainer +from torchcp.classification.trainer.model_zoo import SurrogateCPModel + + +class SCPOTrainer(Trainer): + """ + Trainer for Surrogate Conformal Predictor Optimization. + + Args: + alpha (float): The significance level for each training batch. + model (torch.nn.Module): Base neural network model to be calibrated. + device (torch.device, optional): Device to run the model on. If None, will automatically use GPU ('cuda') if available, otherwise CPU ('cpu') + Default: None + verbose (bool): Whether to display training progress. Default: True. + lr (float): Learning rate for the optimizer. Default is 0.1. + lambda_val (float): Weight for the coverage loss term. + gamma_val (float): Inverse of the temperature value. + + Examples: + >>> # Define base model + >>> backbone = torchvision.models.resnet18(pretrained=True) + >>> + >>> # Create SCPO trainer + >>> trainer = SCPOTrainer( + ... alpha=0.01, + ... model=model, + ... device=device, + ... verbose=True) + >>> + >>> # Train model + >>> trainer.train( + ... train_loader=train_loader, + ... num_epochs=10 + ... 
) + """ + + def __init__( + self, + alpha: float, + model: torch.nn.Module, + device: torch.device = None, + verbose: bool = True, + lr: float = 0.1, + lambda_val: float = 10000, + gamma_val: float = 1): + + model = SurrogateCPModel(model) + super().__init__(model, device=device, verbose=verbose) + predictor = SplitPredictor(score_function=THR(score_type="identity"), model=model) + + self.optimizer = torch.optim.Adam(self.model.linear.parameters(), lr=lr) + self.loss_fn = SCPOLoss(predictor=predictor, alpha=alpha, + lambda_val=lambda_val, gamma_val=gamma_val)