From 74cc57979cc9a0ada35629116cc1984a9247afc1 Mon Sep 17 00:00:00 2001
From: saichandrapandraju
Date: Mon, 14 Apr 2025 21:51:53 -0400
Subject: [PATCH 1/2] Introduce RemoteLLMAttribution to support LLMAttribution
 for remotely hosted models that provide logprobs (like vLLM)

---
 captum/attr/__init__.py              |   5 +
 captum/attr/_core/llm_attr.py        | 117 +++++++++
 captum/attr/_core/remote_provider.py |  93 +++++++
 setup.py                             |   4 +
 tests/attr/test_llm_attr.py          | 364 ++++++++++++++++++++++++++-
 5 files changed, 581 insertions(+), 2 deletions(-)
 create mode 100644 captum/attr/_core/remote_provider.py
 mode change 100755 => 100644 setup.py

diff --git a/captum/attr/__init__.py b/captum/attr/__init__.py
index a33cd862dd..ee006bbe2d 100644
--- a/captum/attr/__init__.py
+++ b/captum/attr/__init__.py
@@ -27,7 +27,9 @@
     LLMAttribution,
     LLMAttributionResult,
     LLMGradientAttribution,
+    RemoteLLMAttribution,
 )
+from captum.attr._core.remote_provider import RemoteLLMProvider, VLLMProvider
 from captum.attr._core.lrp import LRP
 from captum.attr._core.neuron.neuron_conductance import NeuronConductance
 from captum.attr._core.neuron.neuron_deep_lift import NeuronDeepLift, NeuronDeepLiftShap
@@ -111,6 +113,9 @@
     "LLMAttribution",
     "LLMAttributionResult",
     "LLMGradientAttribution",
+    "RemoteLLMAttribution",
+    "RemoteLLMProvider",
+    "VLLMProvider",
     "InternalInfluence",
     "InterpretableInput",
     "LayerGradCam",
diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py
index 3466ad4996..772b06838f 100644
--- a/captum/attr/_core/llm_attr.py
+++ b/captum/attr/_core/llm_attr.py
@@ -35,6 +35,7 @@
     TextTokenInput,
 )
 from torch import nn, Tensor
+from captum.attr._core.remote_provider import RemoteLLMProvider
 
 DEFAULT_GEN_ARGS: Dict[str, Any] = {
     "max_new_tokens": 25,
@@ -892,3 +893,119 @@ def forward(
 
         # the attribution target is limited to the log probability
         return token_log_probs
+
+
+class RemoteLLMAttribution(LLMAttribution):
+    """
+    Attribution class for large language models that are hosted remotely and offer logprob APIs.
+    """
+    def __init__(
+        self,
+        attr_method: PerturbationAttribution,
+        tokenizer: TokenizerLike,
+        provider: RemoteLLMProvider,
+        attr_target: str = "log_prob",
+    ) -> None:
+        """
+        Args:
+            attr_method: Instance of a supported perturbation attribution class
+            tokenizer (Tokenizer): tokenizer of the llm model used in the attr_method
+            provider: Remote LLM provider that implements the RemoteLLMProvider protocol
+            attr_target: attribute towards log probability or probability.
+                Available values ["log_prob", "prob"]
+                Default: "log_prob"
+        """
+        super().__init__(
+            attr_method=attr_method,
+            tokenizer=tokenizer,
+            attr_target=attr_target,
+        )
+
+        self.provider = provider
+        self.attr_method.forward_func = self._remote_forward_func
+
+    def _get_target_tokens(
+        self,
+        inp: InterpretableInput,
+        target: Union[str, torch.Tensor, None] = None,
+        skip_tokens: Union[List[int], List[str], None] = None,
+        gen_args: Optional[Dict[str, Any]] = None
+    ) -> Tensor:
+        """
+        Get the target tokens for the remote LLM provider.
+        """
+        assert isinstance(
+            inp, self.SUPPORTED_INPUTS
+        ), f"RemoteLLMAttribution does not support input type {type(inp)}"
+
+        if target is None:
+            # generate when None with remote provider
+            assert hasattr(self.provider, "generate") and callable(self.provider.generate), (
+                "The provider does not have a generate function for generating target sequence."
+                " Target must be given for attribution"
+            )
+            if not gen_args:
+                gen_args = DEFAULT_GEN_ARGS
+
+            model_inp = self._format_model_input(inp.to_model_input())
+            target_str = self.provider.generate(model_inp, **gen_args)
+            target_tokens = self.tokenizer.encode(target_str, return_tensors="pt", add_special_tokens=False)[0]
+
+        else:
+            target_tokens = super()._get_target_tokens(inp, target, skip_tokens, gen_args)
+
+        return target_tokens
+
+    def _format_model_input(self, model_input: Union[str, Tensor]) -> str:
+        """
+        Format the model input for the remote LLM provider.
+        """
+        # return str input
+        if isinstance(model_input, Tensor):
+            return self.tokenizer.decode(model_input.flatten())
+        return model_input
+
+    def _remote_forward_func(
+        self,
+        perturbed_tensor: Union[None, Tensor],
+        inp: InterpretableInput,
+        target_tokens: Tensor,
+        use_cached_outputs: bool = False,
+        _inspect_forward: Optional[Callable[[str, str, List[float]], None]] = None,
+    ) -> Tensor:
+        """
+        Forward function for the remote LLM provider.
+        """
+
+        perturbed_input = self._format_model_input(inp.to_model_input(perturbed_tensor))
+
+        target_str:str = self.tokenizer.decode(target_tokens)
+
+        target_token_probs = self.provider.get_logprobs(input_prompt=perturbed_input, target_str=target_str, tokenizer=self.tokenizer)
+
+        assert len(target_token_probs) == target_tokens.size()[0], (
+            f"Number of token logprobs from provider ({len(target_token_probs)}) "
+            f"does not match expected target token length ({target_tokens.size()[0]})"
+        )
+
+        log_prob_list: List[Tensor] = list(map(torch.tensor, target_token_probs))
+
+        total_log_prob = torch.sum(torch.stack(log_prob_list), dim=0)
+        # 1st element is the total prob, rest are the target tokens
+        # add a leading dim for batch even we only support single instance for now
+        if self.include_per_token_attr:
+            target_log_probs = torch.stack(
+                [total_log_prob, *log_prob_list], dim=0
+            ).unsqueeze(0)
+        else:
+            target_log_probs = total_log_prob
+        target_probs = torch.exp(target_log_probs)
+
+        if _inspect_forward:
+            prompt = perturbed_input
+            response = self.tokenizer.decode(target_tokens)
+
+            # callback for externals to inspect (prompt, response, seq_prob)
+            _inspect_forward(prompt, response, target_probs[0].tolist())
+
+        return target_probs if self.attr_target != "log_prob" else target_log_probs
\ No newline at end of file
diff --git a/captum/attr/_core/remote_provider.py b/captum/attr/_core/remote_provider.py
new file mode 100644
index 0000000000..9d5a0461a9
--- /dev/null
+++ b/captum/attr/_core/remote_provider.py
@@ -0,0 +1,93 @@
+from abc import ABC, abstractmethod
+from typing import Any, List, Optional
+from captum._utils.typing import TokenizerLike
+from openai import OpenAI
+import os
+
+class RemoteLLMProvider(ABC):
+    """All remote LLM providers that offer logprobs via an API (like vLLM) extend this class."""
+
+    api_url: str
+
+    @abstractmethod
+    def generate(
+        self,
+        prompt: str,
+        **gen_args: Any
+    ) -> str:
+        """
+        Args:
+            prompt: The input prompt to generate from
+            gen_args: Additional generation arguments
+
+        Returns:
+            The generated text.
+        """
+        ...
+
+    @abstractmethod
+    def get_logprobs(
+        self,
+        input_prompt: str,
+        target_str: str,
+        tokenizer: Optional[TokenizerLike] = None
+    ) -> List[float]:
+        """
+        Get the log probabilities for all tokens in the target string.
+
+        Args:
+            input_prompt: The input prompt
+            target_str: The target string
+            tokenizer: The tokenizer to use
+
+        Returns:
+            A list of log probabilities corresponding to each token in the target string.
+            For a `target_str` of `t` tokens, this method returns a list of logprobs of length `t`.
+        """
+        ...
+
+class VLLMProvider(RemoteLLMProvider):
+    def __init__(self, api_url: str):
+        assert api_url.strip() != "", "API URL is required"
+
+        self.api_url = api_url
+        self.client = OpenAI(base_url=self.api_url,
+                             api_key=os.getenv("OPENAI_API_KEY", "EMPTY")
+                             )
+        self.model_name = self.client.models.list().data[0].id
+
+
+    def generate(self, prompt: str, **gen_args: Any) -> str:
+        if not 'max_tokens' in gen_args:
+            gen_args['max_tokens'] = gen_args.pop('max_new_tokens', 25)
+        if 'do_sample' in gen_args:
+            gen_args.pop('do_sample')
+
+        response = self.client.completions.create(
+            model=self.model_name,
+            prompt=prompt,
+            **gen_args
+        )
+
+        return response.choices[0].text
+
+    def get_logprobs(self, input_prompt: str, target_str: str, tokenizer: Optional[TokenizerLike] = None) -> List[float]:
+        assert tokenizer is not None, "Tokenizer is required for VLLM provider"
+
+        num_target_str_tokens = len(tokenizer.encode(target_str, add_special_tokens=False))
+
+        prompt = input_prompt + target_str
+
+        response = self.client.completions.create(
+            model=self.model_name,
+            prompt=prompt,
+            temperature=0.0,
+            max_tokens=1,
+            extra_body={"prompt_logprobs": 0}
+        )
+        prompt_logprobs = []
+        for probs in response.choices[0].prompt_logprobs[1:]:
+            prompt_logprobs.append(list(probs.values())[0]['logprob'])
+
+        return prompt_logprobs[-num_target_str_tokens:]
+        
\ No newline at end of file
diff --git a/setup.py b/setup.py
old mode 100755
new mode 100644
index 38cb97d5b3..2f473b5327
--- a/setup.py
+++ b/setup.py
@@ -63,9 +63,12 @@ def report(*args):
 
 TEST_REQUIRES = ["pytest", "pytest-cov", "parameterized", "flask", "flask-compress"]
 
+REMOTE_REQUIRES = ["openai"]
+
 DEV_REQUIRES = (
     INSIGHTS_REQUIRES
     + TEST_REQUIRES
+    + REMOTE_REQUIRES
     + [
         "black",
         "flake8",
@@ -169,6 +172,7 @@ def get_package_files(root, subdirs):
         "insights": INSIGHTS_REQUIRES,
         "test": TEST_REQUIRES,
         "tutorials": TUTORIALS_REQUIRES,
+        "remote": REMOTE_REQUIRES,
     },
     package_data={"captum": package_files},
     data_files=[
diff --git a/tests/attr/test_llm_attr.py b/tests/attr/test_llm_attr.py
index d6f1a2a4ea..1f557e8e25 100644
--- a/tests/attr/test_llm_attr.py
+++ b/tests/attr/test_llm_attr.py
@@ -21,14 +21,15 @@
 import torch
 
 from captum._utils.models.linear_model import SkLearnLasso
-from captum._utils.typing import BatchEncodingType
+from captum._utils.typing import BatchEncodingType, TokenizerLike
 from captum.attr._core.feature_ablation import FeatureAblation
 from captum.attr._core.kernel_shap import KernelShap
 from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap
 from captum.attr._core.layer.layer_gradient_x_activation import LayerGradientXActivation
 from captum.attr._core.layer.layer_integrated_gradients import LayerIntegratedGradients
 from captum.attr._core.lime import Lime
-from captum.attr._core.llm_attr import LLMAttribution, LLMGradientAttribution
+from captum.attr._core.llm_attr import LLMAttribution, LLMGradientAttribution, RemoteLLMAttribution
+from captum.attr._core.remote_provider import RemoteLLMProvider
 from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling
 from captum.attr._utils.attribution import GradientAttribution, PerturbationAttribution
 from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput
@@ -669,3 +670,362 @@ def test_llm_attr_with_skip_tensor_target(self) -> None:
         self.assertEqual(token_attr.shape, (5, 4))
         self.assertEqual(res.input_tokens, ["", "a", "b", "c"])
"c"]) self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"]) + +class DummyRemoteLLMProvider(RemoteLLMProvider): + def __init__(self, deterministic_logprobs: bool = False) -> None: + self.api_url = "https://test-api.com" + self.deterministic_logprobs = deterministic_logprobs + + def generate(self, prompt: str, **gen_args: Any) -> str: + assert "mock_response" in gen_args, "must mock response to use DummyRemoteLLMProvider to generate" + return gen_args["mock_response"] + + def get_logprobs(self, input_prompt: str, target_str: str, tokenizer: Optional[TokenizerLike] = None) -> List[float]: + assert tokenizer is not None, "Tokenizer is required" + prompt = input_prompt + target_str + tokens = tokenizer.encode(prompt, add_special_tokens=False) + num_tokens = len(tokens) + + num_target_str_tokens = len(tokenizer.encode(target_str, add_special_tokens=False)) + + logprobs = [] + + for i in range(num_tokens): + # Start with a base value + logprob = -0.1 - (0.01 * i) + + # Make sensitive to key features + if "a" not in prompt: + logprob -= 0.1 + if "c" not in prompt: + logprob -= 0.2 + if "d" not in prompt: + logprob -= 0.3 + if "f" not in prompt: + logprob -= 0.4 + + logprobs.append(logprob) + + return logprobs[-num_target_str_tokens:] + +@parameterized_class( + ("device",), [("cpu",), ("cuda",)] if torch.cuda.is_available() else [("cpu",)] +) +class TestRemoteLLMAttr(BaseTest): + # pyre-fixme[13]: Attribute `device` is never initialized. + device: str + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @parameterized.expand( + [ + ( + AttrClass, + delta, + n_samples, + torch.tensor(true_seq_attr), + torch.tensor(true_tok_attr), + ) + for AttrClass, delta, n_samples, true_seq_attr, true_tok_attr in zip( + (FeatureAblation, ShapleyValueSampling, ShapleyValues), # AttrClass + (0.001, 0.001, 0.001), # delta + (None, 1000, None), # n_samples + ( # true_seq_attr + [0.5, 1.0, 1.5, 2.0], # FeatureAblation + [0.5, 1.0, 1.5, 2.0], # ShapleyValueSampling + [0.5, 1.0, 1.5, 2.0], # ShapleyValues + ), + ( # true_tok_attr + [ # FeatureAblation + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + ], + [ # ShapleyValueSampling + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + ], + [ # ShapleyValues + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + ], + ), + ) + ] + ) + def test_remote_llm_attr( + self, + AttrClass: Type[PerturbationAttribution], + delta: float, + n_samples: Optional[int], + true_seq_attr: Tensor, + true_tok_attr: Tensor, + ) -> None: + attr_kws: Dict[str, int] = {} + if n_samples is not None: + attr_kws["n_samples"] = n_samples + + # In remote mode, we don't need the actual model, this is just a placeholder + placeholder_model = torch.nn.Module() + placeholder_model.device = self.device + + tokenizer = DummyTokenizer() + provider = DummyRemoteLLMProvider(deterministic_logprobs=True) + attr_method = AttrClass(placeholder_model) + remote_llm_attr = RemoteLLMAttribution( + attr_method=attr_method, + tokenizer=tokenizer, + provider=provider, + ) + + # from TestLLMAttr + inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"]) + res = remote_llm_attr.attribute( + inp, + "m n o p q", + skip_tokens=[0], + # use_cached_outputs=self.use_cached_outputs, + # pyre-fixme[6]: In call `LLMAttribution.attribute`, + # for 4th positional argument, expected + # 
+            **attr_kws,  # type: ignore
+        )
+
+        self.assertEqual(res.seq_attr.shape, (4,))
+        self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))
+        self.assertEqual(res.input_tokens, ["a", "c", "d", "f"])
+        self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])
+        self.assertEqual(res.seq_attr.device.type, self.device)
+        self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)
+
+        assertTensorAlmostEqual(
+            self,
+            actual=res.seq_attr,
+            expected=true_seq_attr,
+            delta=delta,
+            mode="max",
+        )
+        assertTensorAlmostEqual(
+            self,
+            actual=res.token_attr,
+            expected=true_tok_attr,
+            delta=delta,
+            mode="max",
+        )
+
+    def test_remote_llm_attr_without_target(self) -> None:
+        # In remote mode, we don't need the actual model, this is just a placeholder
+        placeholder_model = torch.nn.Module()
+        placeholder_model.device = self.device
+
+        tokenizer = DummyTokenizer()
+        provider = DummyRemoteLLMProvider(deterministic_logprobs=True)
+        attr_method = FeatureAblation(placeholder_model)
+        remote_llm_attr = RemoteLLMAttribution(
+            attr_method=attr_method,
+            tokenizer=tokenizer,
+            provider=provider,
+        )
+
+        # from TestLLMAttr
+        inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
+        res = remote_llm_attr.attribute(
+            inp,
+            gen_args={"mock_response": "x y z"},
+            # use_cached_outputs=self.use_cached_outputs,
+        )
+
+        self.assertEqual(res.seq_attr.shape, (4,))
+        self.assertEqual(cast(Tensor, res.token_attr).shape, (3, 4))
+        self.assertEqual(res.input_tokens, ["a", "c", "d", "f"])
+        self.assertEqual(res.output_tokens, ["x", "y", "z"])
+        self.assertEqual(res.seq_attr.device.type, self.device)
+        self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)
+
+    def test_remote_llm_attr_fa_log_prob(self) -> None:
+        # In remote mode, we don't need the actual model, this is just a placeholder
+        placeholder_model = torch.nn.Module()
+        placeholder_model.device = self.device
+
+        tokenizer = DummyTokenizer()
+        provider = DummyRemoteLLMProvider(deterministic_logprobs=True)
+        attr_method = FeatureAblation(placeholder_model)
+        remote_llm_attr = RemoteLLMAttribution(
+            attr_method=attr_method,
+            tokenizer=tokenizer,
+            provider=provider,
+            attr_target="log_prob",
+        )
+
+        # from TestLLMAttr
+        inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
+        res = remote_llm_attr.attribute(
+            inp,
+            "m n o p q",
+            skip_tokens=[0],
+            # use_cached_outputs=self.use_cached_outputs,
+        )
+
+        # With FeatureAblation, the seq attr in log_prob
+        # equals to the sum of each token attr
+        assertTensorAlmostEqual(self, res.seq_attr, cast(Tensor, res.token_attr).sum(0))
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @parameterized.expand(
+        [
+            (
+                AttrClass,
+                delta,
+                n_samples,
+                torch.tensor(true_seq_attr),
+                interpretable_model,
+            )
+            for AttrClass, delta, n_samples, true_seq_attr, interpretable_model in zip(
+                (Lime, KernelShap),
+                (0.003, 0.001),
+                (1000, 2500),
+                (
+                    [0.4956, 0.9957, 1.4959, 1.9959],
+                    [0.5, 1.0, 1.5, 2.0],
+                ),
+                (SkLearnLasso(alpha=0.001), None),
+            )
+        ]
+    )
+    def test_remote_llm_attr_without_token(
+        self,
+        AttrClass: Type[PerturbationAttribution],
+        delta: float,
+        n_samples: int,
+        true_seq_attr: Tensor,
+        interpretable_model: Optional[nn.Module] = None,
+    ) -> None:
+        init_kws = {}
+        if interpretable_model is not None:
+            init_kws["interpretable_model"] = interpretable_model
+        attr_kws: Dict[str, int] = {}
+        if n_samples is not None:
+            attr_kws["n_samples"] = n_samples
+
+        # In remote mode, we don't need the actual model, this is just a placeholder
+        placeholder_model = torch.nn.Module()
+        placeholder_model.device = self.device
+
+        tokenizer = DummyTokenizer()
+        provider = DummyRemoteLLMProvider(deterministic_logprobs=True)
+        attr_method = AttrClass(placeholder_model, **init_kws)
+        remote_llm_attr = RemoteLLMAttribution(
+            attr_method=attr_method,
+            tokenizer=tokenizer,
+            provider=provider,
+            attr_target="log_prob",
+        )
+
+        inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
+        res = remote_llm_attr.attribute(
+            inp,
+            "m n o p q",
+            skip_tokens=[0],
+            # use_cached_outputs=self.use_cached_outputs,
+            **attr_kws,  # type: ignore
+        )
+
+        self.assertEqual(res.seq_attr.shape, (4,))
+        self.assertEqual(res.seq_attr.device.type, self.device)
+        self.assertEqual(res.token_attr, None)
+        self.assertEqual(res.input_tokens, ["a", "c", "d", "f"])
+        self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])
+        assertTensorAlmostEqual(
+            self,
+            actual=res.seq_attr,
+            expected=true_seq_attr,
+            delta=delta,
+            mode="max",
+        )
+    def test_remote_llm_attr_futures_not_implemented(self) -> None:
+        # In remote mode, we don't need the actual model, this is just a placeholder
+        placeholder_model = torch.nn.Module()
+        placeholder_model.device = self.device
+
+        tokenizer = DummyTokenizer()
+        provider = DummyRemoteLLMProvider()
+        attr_method = FeatureAblation(placeholder_model)
+        remote_llm_attr = RemoteLLMAttribution(
+            attr_method=attr_method,
+            tokenizer=tokenizer,
+            provider=provider,
+        )
+
+        # from TestLLMAttr
+        attributions = None
+        with self.assertRaises(NotImplementedError):
+            attributions = remote_llm_attr.attribute_future()
+        self.assertEqual(attributions, None)
+
+    def test_remote_llm_attr_with_no_skip_tokens(self) -> None:
+        # In remote mode, we don't need the actual model, this is just a placeholder
+        placeholder_model = torch.nn.Module()
+        placeholder_model.device = self.device
+
+        tokenizer = DummyTokenizer()
+        provider = DummyRemoteLLMProvider(deterministic_logprobs=True)
+        attr_method = FeatureAblation(placeholder_model)
+        remote_llm_fa = RemoteLLMAttribution(
+            attr_method=attr_method,
+            tokenizer=tokenizer,
+            provider=provider,
+        )
+
+        # from TestLLMAttr
+        inp = TextTokenInput("a b c", tokenizer)
+        res = remote_llm_fa.attribute(
+            inp,
+            "m n o p q"
+        )
+
+        # 5 output tokens, 4 input tokens including sos
+        self.assertEqual(res.seq_attr.shape, (4,))
+        assert res.token_attr is not None
+        self.assertIsNotNone(res.token_attr)
+        token_attr = res.token_attr
+        self.assertEqual(token_attr.shape, (6, 4))
+        self.assertEqual(res.input_tokens, ["", "a", "b", "c"])
+        self.assertEqual(res.output_tokens, ["", "m", "n", "o", "p", "q"])
+
+    def test_remote_llm_attr_with_skip_tensor_target(self) -> None:
+        # In remote mode, we don't need the actual model, this is just a placeholder
+        placeholder_model = torch.nn.Module()
+        placeholder_model.device = self.device
+
+        tokenizer = DummyTokenizer()
+        provider = DummyRemoteLLMProvider(deterministic_logprobs=True)
+        attr_method = FeatureAblation(placeholder_model)
+        remote_llm_fa = RemoteLLMAttribution(
+            attr_method=attr_method,
+            tokenizer=tokenizer,
+            provider=provider,
+        )
+
+        # from TestLLMAttr
+        inp = TextTokenInput("a b c", tokenizer)
+        res = remote_llm_fa.attribute(
+            inp,
+            torch.tensor(tokenizer.encode("m n o p q")),
+            skip_tokens=[0],
+        )
+
+        # 5 output tokens, 4 input tokens including sos
+        self.assertEqual(res.seq_attr.shape, (4,))
+        assert res.token_attr is not None
+        self.assertIsNotNone(res.token_attr)
+        token_attr = res.token_attr
+        self.assertEqual(token_attr.shape, (5, 4))
+        self.assertEqual(res.input_tokens, ["", "a", "b", "c"])
+        self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])

From e6c929bd160e5685241614767d15b6d37dd161a0 Mon Sep 17 00:00:00 2001
From: saichandrapandraju
Date: Thu, 24 Apr 2025 18:55:16 -0400
Subject: [PATCH 2/2] add optional 'model_name' for VLLMProvider and add
 better exception handling

---
 captum/attr/_core/llm_attr.py        |  13 ++-
 captum/attr/_core/remote_provider.py | 152 ++++++++++++++++++++++-----
 2 files changed, 133 insertions(+), 32 deletions(-)

diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py
index 772b06838f..202c2bbdc9 100644
--- a/captum/attr/_core/llm_attr.py
+++ b/captum/attr/_core/llm_attr.py
@@ -975,18 +975,21 @@ def _remote_forward_func(
     ) -> Tensor:
         """
         Forward function for the remote LLM provider.
+
+        Raises:
+            ValueError: If the number of token logprobs doesn't match expected length
         """
-
         perturbed_input = self._format_model_input(inp.to_model_input(perturbed_tensor))
 
         target_str:str = self.tokenizer.decode(target_tokens)
 
         target_token_probs = self.provider.get_logprobs(input_prompt=perturbed_input, target_str=target_str, tokenizer=self.tokenizer)
 
-        assert len(target_token_probs) == target_tokens.size()[0], (
-            f"Number of token logprobs from provider ({len(target_token_probs)}) "
-            f"does not match expected target token length ({target_tokens.size()[0]})"
-        )
+        if len(target_token_probs) != target_tokens.size()[0]:
+            raise ValueError(
+                f"Number of token logprobs from provider ({len(target_token_probs)}) "
+                f"does not match expected target token length ({target_tokens.size()[0]})"
+            )
 
         log_prob_list: List[Tensor] = list(map(torch.tensor, target_token_probs))
 
diff --git a/captum/attr/_core/remote_provider.py b/captum/attr/_core/remote_provider.py
index 9d5a0461a9..149337b962 100644
--- a/captum/attr/_core/remote_provider.py
+++ b/captum/attr/_core/remote_provider.py
@@ -47,47 +47,145 @@ def get_logprobs(
         ...
 
 class VLLMProvider(RemoteLLMProvider):
-    def __init__(self, api_url: str):
-        assert api_url.strip() != "", "API URL is required"
+    def __init__(self, api_url: str, model_name: Optional[str] = None):
+        """
+        Initialize a vLLM provider.
+
+        Args:
+            api_url: The URL of the vLLM API
+            model_name: The name of the model to use. If None, the first model from
+                        the API's model list will be used.
+        Raises:
+            ValueError: If api_url is empty or no models are available from the API
+            ConnectionError: If API connection fails
+        """
+        if not api_url.strip():
+            raise ValueError("API URL is required")
+
 
         self.api_url = api_url
-        self.client = OpenAI(base_url=self.api_url,
+
+        try:
+            self.client = OpenAI(base_url=self.api_url,
                              api_key=os.getenv("OPENAI_API_KEY", "EMPTY")
                              )
-        self.model_name = self.client.models.list().data[0].id
-
+
+            # If model_name is not provided, get the first available model from the API
+            if model_name is None:
+                models = self.client.models.list().data
+                if not models:
+                    raise ValueError("No models available from the vLLM API")
+                self.model_name = models[0].id
+            else:
+                self.model_name = model_name
+
+        except ConnectionError as e:
+            raise ConnectionError(f"Failed to connect to vLLM API: {str(e)}")
+        except Exception as e:
+            raise Exception(f"Unexpected error while initializing vLLM provider: {str(e)}")
 
     def generate(self, prompt: str, **gen_args: Any) -> str:
+        """
+        Generate text using the vLLM API.
+
+        Args:
+            prompt: The input prompt for text generation
+            **gen_args: Additional generation arguments
+
+        Returns:
+            str: The generated text
+
+        Raises:
+            KeyError: If API response is missing expected data
+            ConnectionError: If connection to API fails
+        """
+        # Parameter normalization
-        if not 'max_tokens' in gen_args:
+        if 'max_tokens' not in gen_args:
             gen_args['max_tokens'] = gen_args.pop('max_new_tokens', 25)
         if 'do_sample' in gen_args:
             gen_args.pop('do_sample')
+
+        try:
+            response = self.client.completions.create(
+                model=self.model_name,
+                prompt=prompt,
+                **gen_args
+            )
+            if not hasattr(response, 'choices') or not response.choices:
+                raise KeyError("API response missing expected 'choices' data")
+
+            return response.choices[0].text
 
-        response = self.client.completions.create(
-            model=self.model_name,
-            prompt=prompt,
-            **gen_args
-        )
-
-        return response.choices[0].text
+        except ConnectionError as e:
+            raise ConnectionError(f"Failed to connect to vLLM API: {str(e)}")
+        except Exception as e:
+            raise Exception(f"Unexpected error during text generation: {str(e)}")
 
-    def get_logprobs(self, input_prompt: str, target_str: str, tokenizer: Optional[TokenizerLike] = None) -> List[float]:
-        assert tokenizer is not None, "Tokenizer is required for VLLM provider"
+    def get_logprobs(
+        self,
+        input_prompt: str,
+        target_str: str,
+        tokenizer: Optional[TokenizerLike] = None
+    ) -> List[float]:
+        """
+        Get the log probabilities for all tokens in the target string.
+
+        Args:
+            input_prompt: The input prompt
+            target_str: The target string
+            tokenizer: The tokenizer to use
+
+        Returns:
+            A list of log probabilities corresponding to each token in the target string.
+            For a `target_str` of `t` tokens, this method returns a list of logprobs of length `t`.
+
+        Raises:
+            ValueError: If tokenizer is None, target_str is empty, or the response format is invalid
+            KeyError: If API response is missing expected data
+            IndexError: If response format is unexpected
+            ConnectionError: If connection to API fails
+        """
+        if tokenizer is None:
+            raise ValueError("Tokenizer is required for vLLM provider")
+        if not target_str:
+            raise ValueError("Target string cannot be empty")
 
         num_target_str_tokens = len(tokenizer.encode(target_str, add_special_tokens=False))
 
         prompt = input_prompt + target_str
+
+        try:
+            response = self.client.completions.create(
+                model=self.model_name,
+                prompt=prompt,
+                temperature=0.0,
+                max_tokens=1,
+                extra_body={"prompt_logprobs": 0}
+            )
+
+            if not hasattr(response, 'choices') or not response.choices:
+                raise KeyError("API response missing expected 'choices' data")
+
+            if not hasattr(response.choices[0], 'prompt_logprobs'):
+                raise KeyError("API response missing 'prompt_logprobs' data")
+
+            prompt_logprobs = []
+            try:
+                for probs in response.choices[0].prompt_logprobs[1:]:
+                    if not probs:
+                        raise ValueError("Empty probability data in API response")
+                    prompt_logprobs.append(list(probs.values())[0]['logprob'])
+            except (IndexError, KeyError) as e:
+                raise IndexError(f"Unexpected format in log probability data: {str(e)}")
+
+            if len(prompt_logprobs) < num_target_str_tokens:
+                raise ValueError(f"Not enough logprobs received: expected {num_target_str_tokens}, got {len(prompt_logprobs)}")
+
+            return prompt_logprobs[-num_target_str_tokens:]
 
-        response = self.client.completions.create(
-            model=self.model_name,
-            prompt=prompt,
-            temperature=0.0,
-            max_tokens=1,
-            extra_body={"prompt_logprobs": 0}
-        )
-        prompt_logprobs = []
-        for probs in response.choices[0].prompt_logprobs[1:]:
-            prompt_logprobs.append(list(probs.values())[0]['logprob'])
-
-        return prompt_logprobs[-num_target_str_tokens:]
+        except ConnectionError as e:
+            raise ConnectionError(f"Failed to connect to vLLM API when getting logprobs: {str(e)}")
+        except Exception as e:
+            raise Exception(f"Unexpected error while getting log probabilities: {str(e)}")
+        
\ No newline at end of file
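
Usage sketch (reviewer note, not part of the patch): the snippet below shows how the pieces introduced by these two commits are intended to fit together, using only the APIs defined in the diff. The server URL, the "/v1" path, and the tokenizer name are placeholder assumptions; any vLLM deployment exposing the OpenAI-compatible completions endpoint with `prompt_logprobs` should work, and the tokenizer just needs to match the served model.

    import torch
    # Assumption: a HuggingFace tokenizer that matches the model served by vLLM.
    from transformers import AutoTokenizer

    from captum.attr import FeatureAblation, RemoteLLMAttribution, VLLMProvider
    from captum.attr._utils.interpretable_input import TextTemplateInput

    # Placeholder endpoint and model name -- adjust to your own deployment.
    provider = VLLMProvider(api_url="http://localhost:8000/v1")
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

    # In remote mode the model is never run locally; the attribution method only
    # needs an object with a `device` attribute (mirrors the tests above).
    placeholder_model = torch.nn.Module()
    placeholder_model.device = "cpu"

    llm_attr = RemoteLLMAttribution(
        attr_method=FeatureAblation(placeholder_model),
        tokenizer=tokenizer,
        provider=provider,
    )

    # Attribute the target's log-probability to each template feature.
    inp = TextTemplateInput("{} lives in {}", ["Dave", "Palm Coast"])
    res = llm_attr.attribute(inp, target="Dave lives in Palm Coast.")
    print(res.seq_attr)  # one attribution score per template feature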