From d39be4c0084e22dbd650d6c1e54be2c03b19c3de Mon Sep 17 00:00:00 2001 From: tuantran0910 Date: Fri, 4 Apr 2025 02:03:46 +0700 Subject: [PATCH 01/18] feat: enhance deduplication rules provisioning with provider support and environment configuration --- .../deduplication_rules_provisioning.py | 44 +- keep/api/config.py | 6 - keep/api/consts.py | 3 + keep/providers/providers_service.py | 526 ++++++++++++------ .../test_deduplications_provisioning.py | 227 -------- tests/test_providers_yaml_provisioning.py | 214 ++++++- tests/test_provisioning.py | 225 ++++++++ 7 files changed, 780 insertions(+), 465 deletions(-) delete mode 100644 tests/deduplication/test_deduplications_provisioning.py diff --git a/keep/api/alert_deduplicator/deduplication_rules_provisioning.py b/keep/api/alert_deduplicator/deduplication_rules_provisioning.py index 4b26d3b98c..d89815f662 100644 --- a/keep/api/alert_deduplicator/deduplication_rules_provisioning.py +++ b/keep/api/alert_deduplicator/deduplication_rules_provisioning.py @@ -4,21 +4,24 @@ import keep.api.core.db as db from keep.api.core.config import config -from keep.providers.providers_factory import ProvidersFactory +from keep.api.models.db.provider import Provider logger = logging.getLogger(__name__) -def provision_deduplication_rules(deduplication_rules: dict[str, any], tenant_id: str): +def provision_deduplication_rules( + deduplication_rules: dict[str, any], tenant_id: str, provider: Provider +): """ Provisions deduplication rules for a given tenant. Args: deduplication_rules (dict[str, any]): A dictionary where the keys are rule names and the values are - DeduplicationRuleRequestDto objects. + DeduplicationRuleRequestDto objects. tenant_id (str): The ID of the tenant for which deduplication rules are being provisioned. + provider (Provider): The provider for which the deduplication rules are being provisioned. """ - enrich_with_providers_info(deduplication_rules, tenant_id) + enrich_with_providers_info(deduplication_rules, provider) all_deduplication_rules_from_db = db.get_all_deduplication_rules(tenant_id) provisioned_deduplication_rules = [ @@ -98,46 +101,17 @@ def provision_deduplication_rules(deduplication_rules: dict[str, any], tenant_id ) -def provision_deduplication_rules_from_env(tenant_id: str): - """ - Provisions deduplication rules from environment variables for a given tenant. - This function reads deduplication rules from environment variables, validates them, - and then provisions them into the database. It handles the following: - - Deletes deduplication rules from the database that are not present in the environment variables. - - Updates existing deduplication rules in the database if they are present in the environment variables. - - Creates new deduplication rules in the database if they are not already present. - Args: - tenant_id (str): The ID of the tenant for which deduplication rules are being provisioned. - Raises: - ValueError: If the deduplication rules from the environment variables are invalid. - """ - - deduplication_rules_from_env_dict = get_deduplication_rules_to_provision() - - if not deduplication_rules_from_env_dict: - logger.info("No deduplication rules found in env. Nothing to provision.") - return - - provision_deduplication_rules(deduplication_rules_from_env_dict, tenant_id) - - -def enrich_with_providers_info(deduplication_rules: dict[str, any], tenant_id: str): +def enrich_with_providers_info(deduplication_rules: dict[str, any], provider: Provider): """ Enriches passed deduplication rules with provider ID and type information. Args: deduplication_rules (dict[str, any]): A list of deduplication rules to be enriched. - tenant_id (str): The ID of the tenant for which deduplication rules are being provisioned. + provider (Provider): The provider for which the deduplication rules are being provisioned. """ - installed_providers = ProvidersFactory.get_installed_providers(tenant_id) - installed_providers_dict = { - provider.details.get("name"): provider for provider in installed_providers - } - for rule_name, rule in deduplication_rules.items(): logger.info(f"Enriching deduplication rule: {rule_name}") - provider = installed_providers_dict.get(rule.get("provider_name")) rule["provider_id"] = provider.id rule["provider_type"] = provider.type diff --git a/keep/api/config.py b/keep/api/config.py index 98a16dd4b3..13f136afda 100644 --- a/keep/api/config.py +++ b/keep/api/config.py @@ -2,9 +2,6 @@ import os import keep.api.logging -from keep.api.alert_deduplicator.deduplication_rules_provisioning import ( - provision_deduplication_rules_from_env, -) from keep.api.api import AUTH_TYPE from keep.api.core.db_on_start import migrate_db, try_create_single_tenant from keep.api.core.dependencies import SINGLE_TENANT_UUID @@ -33,9 +30,6 @@ def provision_resources(): logger.info("Workflows provisioned successfully") provision_dashboards(SINGLE_TENANT_UUID) logger.info("Dashboards provisioned successfully") - logger.info("Provisioning deduplication rules") - provision_deduplication_rules_from_env(SINGLE_TENANT_UUID) - logger.info("Deduplication rules provisioned successfully") else: logger.info("Provisioning resources is disabled") diff --git a/keep/api/consts.py b/keep/api/consts.py index d64b494010..0973a8d6eb 100644 --- a/keep/api/consts.py +++ b/keep/api/consts.py @@ -41,6 +41,9 @@ KEEP_ARQ_QUEUE_WORKFLOWS = "workflows" REDIS = os.environ.get("REDIS", "false") == "true" +REDIS_HOST = os.environ.get("REDIS_HOST", "localhost") +REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379)) +REDIS_DB = int(os.environ.get("REDIS_DB", 0)) if REDIS: KEEP_ARQ_TASK_POOL = os.environ.get("KEEP_ARQ_TASK_POOL", KEEP_ARQ_TASK_POOL_ALL) diff --git a/keep/providers/providers_service.py b/keep/providers/providers_service.py index bc889dd109..c18cf0f589 100644 --- a/keep/providers/providers_service.py +++ b/keep/providers/providers_service.py @@ -1,10 +1,12 @@ +import hashlib import json import logging import os import time import uuid -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional +import redis from fastapi import HTTPException from sqlalchemy.exc import IntegrityError from sqlmodel import Session, select @@ -12,6 +14,7 @@ from keep.api.alert_deduplicator.deduplication_rules_provisioning import ( provision_deduplication_rules, ) +from keep.api.consts import REDIS, REDIS_DB, REDIS_HOST, REDIS_PORT from keep.api.core.config import config from keep.api.core.db import ( engine, @@ -28,6 +31,9 @@ from keep.providers.providers_factory import ProvidersFactory from keep.secretmanager.secretmanagerfactory import SecretManagerFactory +DEFAULT_PROVIDER_HASH_STATE_FILE = "/state/{tenant_id}_providers_hash.txt" + + logger = logging.getLogger(__name__) @@ -150,6 +156,8 @@ def install_provider( provisioned: bool = False, validate_scopes: bool = True, pulling_enabled: bool = True, + session: Optional[Session] = None, + commit: bool = True, ) -> Dict[str, Any]: provider_unique_id = uuid.uuid4().hex logger.info( @@ -186,56 +194,63 @@ def install_provider( secret_value=json.dumps(config), ) - with Session(engine) as session: - provider_model = Provider( - id=provider_unique_id, - tenant_id=tenant_id, - name=provider_name, - type=provider_type, - installed_by=installed_by, - installation_time=time.time(), - configuration_key=secret_name, - validatedScopes=validated_scopes, - consumer=provider.is_consumer, - provisioned=provisioned, - pulling_enabled=pulling_enabled, - ) - try: - session.add(provider_model) + session_managed = False + if not session: + session = Session(engine) + session_managed = True + + provider_model = Provider( + id=provider_unique_id, + tenant_id=tenant_id, + name=provider_name, + type=provider_type, + installed_by=installed_by, + installation_time=time.time(), + configuration_key=secret_name, + validatedScopes=validated_scopes, + consumer=provider.is_consumer, + provisioned=provisioned, + pulling_enabled=pulling_enabled, + ) + try: + session.add(provider_model) + if commit: session.commit() - except IntegrityError as e: - if "FOREIGN KEY constraint" in str(e): - raise - try: - # if the provider is already installed, delete the secret - logger.warning( - "Provider already installed, deleting secret", - extra={"error": str(e)}, - ) - secret_manager.delete_secret( - secret_name=secret_name, - ) - logger.warning("Secret deleted") - except Exception: - logger.exception("Failed to delete the secret") - pass - raise HTTPException( - status_code=409, detail="Provider already installed" + except IntegrityError as e: + if "FOREIGN KEY constraint" in str(e): + raise + try: + # if the provider is already installed, delete the secret + logger.warning( + "Provider already installed, deleting secret", + extra={"error": str(e)}, ) + secret_manager.delete_secret( + secret_name=secret_name, + ) + logger.warning("Secret deleted") + except Exception: + logger.exception("Failed to delete the secret") + pass + raise HTTPException(status_code=409, detail="Provider already installed") + finally: + if session_managed: + session.close() - if provider_model.consumer: - try: - event_subscriber = EventSubscriber.get_instance() - event_subscriber.add_consumer(provider) - except Exception: - logger.exception("Failed to register provider as a consumer") + if provider_model.consumer: + try: + event_subscriber = EventSubscriber.get_instance() + event_subscriber.add_consumer(provider) + except Exception: + logger.exception("Failed to register provider as a consumer") - return { - "type": provider_type, - "id": provider_unique_id, - "details": config, - "validatedScopes": validated_scopes, - } + return { + "provider": provider_model, + "type": provider_type, + "id": provider_unique_id, + "details": config, + "validatedScopes": validated_scopes, + } @staticmethod def update_provider( @@ -296,7 +311,11 @@ def update_provider( @staticmethod def delete_provider( - tenant_id: str, provider_id: str, session: Session, allow_provisioned=False + tenant_id: str, + provider_id: str, + session: Session, + allow_provisioned=False, + commit: bool = True, ): provider_model: Provider = session.exec( select(Provider).where( @@ -342,7 +361,8 @@ def delete_provider( logger.exception(msg="Provider deleted but failed to clean up provider") session.delete(provider_model) - session.commit() + if commit: + session.commit() @staticmethod def validate_provider_scopes( @@ -378,6 +398,130 @@ def is_provider_installed(tenant_id: str, provider_name: str) -> bool: provider = get_provider_by_name(tenant_id, provider_name) return provider is not None + @staticmethod + def provision_provider_deduplication_rules( + tenant_id: str, + provider: Provider, + deduplication_rules: Dict[str, Dict[str, Any]], + ): + """ + Provision deduplication rules for a provider. + + Args: + tenant_id (str): The tenant ID. + provider (Provider): The provider to provision the deduplication rules for. + deduplication_rules (Dict[str, Dict[str, Any]]): The deduplication rules to provision. + """ + + # Provision the deduplication rules + deduplication_rules_dict: dict[str, dict] = {} + for rule_name, rule_config in deduplication_rules.items(): + logger.info(f"Provisioning deduplication rule {rule_name}") + rule_config["name"] = rule_name + rule_config["provider_name"] = provider.name + rule_config["provider_type"] = provider.type + deduplication_rules_dict[rule_name] = rule_config + + # Provision deduplication rules + provision_deduplication_rules( + deduplication_rules=deduplication_rules_dict, + tenant_id=tenant_id, + provider=provider, + ) + + @staticmethod + def write_provisioned_hash(tenant_id: str, hash_value: str): + """ + Write the provisioned hash to Redis or file. + + Args: + tenant_id (str): The tenant ID. + hash_value (str): The hash value to write. + """ + if REDIS: + r = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB) + r.set(f"{tenant_id}_providers_hash", hash_value) + logger.info(f"Provisioned hash for tenant {tenant_id} written to Redis!") + else: + with open( + DEFAULT_PROVIDER_HASH_STATE_FILE.format(tenant_id=tenant_id), "w" + ) as f: + f.write(hash_value) + logger.info(f"Provisioned hash for tenant {tenant_id} written to file!") + + @staticmethod + def get_provisioned_hash(tenant_id: str) -> Optional[str]: + """ + Get the provisioned hash from Redis or file. + + Args: + tenant_id (str): The tenant ID. + + Returns: + Optional[str]: The provisioned hash, or None if not found. + """ + previous_hash = None + if REDIS: + try: + with redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB) as r: + previous_hash = r.get(f"{tenant_id}_providers_hash") + if isinstance(previous_hash, bytes): + previous_hash = previous_hash.decode("utf-8").strip() + logger.info( + f"Provisioned hash for tenant {tenant_id}: {previous_hash or 'Not found'}" + ) + except redis.RedisError as e: + logger.warning(f"Redis error for tenant {tenant_id}: {e}") + + if previous_hash is None: + try: + with open( + DEFAULT_PROVIDER_HASH_STATE_FILE.format(tenant_id=tenant_id), + "r", + encoding="utf-8", + ) as f: + previous_hash = f.read().strip() + logger.info( + f"Provisioned hash for tenant {tenant_id} read from file: {previous_hash}" + ) + except FileNotFoundError: + logger.info(f"Provisioned hash file for tenant {tenant_id} not found.") + except Exception as e: + logger.warning( + f"Failed to read hash from file for tenant {tenant_id}: {e}" + ) + + return previous_hash if previous_hash else None + + @staticmethod + def calculate_provider_hash( + provisioned_providers_dir: Optional[str] = None, + provisioned_providers_json: Optional[str] = None, + ) -> str: + """ + Calculate the hash of the provider configurations. + + Args: + provisioned_providers_dir (Optional[str]): Directory containing provider YAML files. + provisioned_providers_json (Optional[str]): JSON string of provider configurations. + + Returns: + str: SHA256 hash of the provider configurations. + """ + if provisioned_providers_json: + providers_data = provisioned_providers_json + elif provisioned_providers_dir: + providers_data = [] + for file in os.listdir(provisioned_providers_dir): + if file.endswith((".yaml", ".yml")): + provider_path = os.path.join(provisioned_providers_dir, file) + with open(provider_path, "r") as yaml_file: + providers_data.append(yaml_file.read()) + else: + providers_data = "" # No providers to provision + + return hashlib.sha256(json.dumps(providers_data).encode("utf-8")).hexdigest() + @staticmethod def provision_providers(tenant_id: str): """ @@ -391,10 +535,6 @@ def provision_providers(tenant_id: str): provisioned_providers_dir = os.environ.get("KEEP_PROVIDERS_DIRECTORY") provisioned_providers_json = os.environ.get("KEEP_PROVIDERS") - if not (provisioned_providers_dir or provisioned_providers_json): - logger.info("No providers for provisioning found") - return - if ( provisioned_providers_dir is not None and provisioned_providers_json is not None @@ -413,147 +553,185 @@ def provision_providers(tenant_id: str): # Get all existing provisioned providers provisioned_providers = get_all_provisioned_providers(tenant_id) - ### Provisioning from env var - if provisioned_providers_json is not None: - # Avoid circular import - from keep.parser.parser import Parser + if not (provisioned_providers_dir or provisioned_providers_json): + if provisioned_providers: + logger.info( + "No providers for provisioning found. Deleting all provisioned providers." + ) + else: + logger.info("No providers for provisioning found. Nothing to do.") + return - parser = Parser() - context_manager = ContextManager(tenant_id=tenant_id) - parser._parse_providers_from_env(context_manager) - env_providers = context_manager.providers_context + # Calculate the hash of the provider configurations + providers_hash = ProvidersService.calculate_provider_hash( + provisioned_providers_dir, provisioned_providers_json + ) - # Un-provisioning other providers. - for provider in provisioned_providers: - if provider.name not in env_providers: - with Session(engine) as session: - try: - logger.info(f"Deleting provider {provider.name}") - ProvidersService.delete_provider( - tenant_id, provider.id, session, allow_provisioned=True - ) - logger.info(f"Provider {provider.name} deleted") - except Exception as e: - logger.exception( - "Failed to delete provisioned provider that does not exist in the env var", - extra={"exception": e}, - ) - continue - - for provider_name, provider_config in env_providers.items(): - logger.info(f"Provisioning provider {provider_name}") - if ProvidersService.is_provider_installed(tenant_id, provider_name): - logger.info(f"Provider {provider_name} already installed") - continue - - logger.info(f"Installing provider {provider_name}") - try: - ProvidersService.install_provider( - tenant_id=tenant_id, - installed_by="system", - provider_id=provider_config["type"], - provider_name=provider_name, - provider_type=provider_config["type"], - provider_config=provider_config["authentication"], - provisioned=True, - validate_scopes=False, - ) - logger.info(f"Provider {provider_name} provisioned successfully") - except Exception as e: - logger.error( - "Error provisioning provider from env var", - extra={"exception": e}, - ) - - ### Provisioning from the directory - if provisioned_providers_dir is not None: - installed_providers = [] - for file in os.listdir(provisioned_providers_dir): - if file.endswith((".yaml", ".yml")): - logger.info(f"Provisioning provider from {file}") - provider_path = os.path.join(provisioned_providers_dir, file) + # Get the previous hash from Redis or file + previous_hash = ProvidersService.get_provisioned_hash(tenant_id) + if providers_hash == previous_hash: + logger.info( + "Provider configurations have not changed. Skipping provisioning." + ) + return + else: + logger.info("Provider configurations have changed. Provisioning providers.") + # Do all the provisioning within a transaction + session = Session(engine) + try: + with session.begin(): + ### We do delete all the provisioned providers and begin provisioning from the beginning. + logger.info( + f"Deleting all provisioned providers for tenant {tenant_id}" + ) + for provisioned_provider in provisioned_providers: try: - with open(provider_path, "r") as yaml_file: - provider_yaml = cyaml.safe_load(yaml_file.read()) - provider_name = provider_yaml["name"] - provider_type = provider_yaml["type"] - provider_config = provider_yaml.get("authentication", {}) - - # Skip if already installed - if ProvidersService.is_provider_installed( - tenant_id, provider_name - ): - logger.info( - f"Provider {provider_name} already installed" - ) - # Add to installed providers list. This is necessary, otherwise the provider - # will be un-provisioned on the process un-provisioning outdated providers. - installed_providers.append(provider_name) - continue + logger.info(f"Deleting provider {provisioned_provider.name}") + ProvidersService.delete_provider( + tenant_id, + provisioned_provider.id, + session, + allow_provisioned=True, + commit=False, + ) + logger.info(f"Provider {provisioned_provider.name} deleted") + except Exception as e: + logger.exception( + "Failed to delete provisioned provider", + extra={"exception": e}, + ) + continue - logger.info(f"Installing provider {provider_name}") - ProvidersService.install_provider( + # Flush the session to ensure all deletions are committed + session.flush() + + ### Provisioning from env var + if provisioned_providers_json is not None: + # Avoid circular import + from keep.parser.parser import Parser + + parser = Parser() + context_manager = ContextManager(tenant_id=tenant_id) + parser._parse_providers_from_env(context_manager) + env_providers = context_manager.providers_context + + for provider_name, provider_config in env_providers.items(): + # We skip checking if the provider is already installed, as it will skip the new configurations + # and we want to update the provisioned provider with the new configuration + logger.info(f"Provisioning provider {provider_name}") + try: + installed_provider_info = ProvidersService.install_provider( tenant_id=tenant_id, installed_by="system", - provider_id=provider_type, + provider_id=provider_config["type"], provider_name=provider_name, - provider_type=provider_type, - provider_config=provider_config, + provider_type=provider_config["type"], + provider_config=provider_config["authentication"], provisioned=True, validate_scopes=False, + session=session, + commit=False, ) + provider = installed_provider_info["provider"] logger.info( f"Provider {provider_name} provisioned successfully" ) - installed_providers.append(provider_name) + except Exception as e: + logger.error( + "Error provisioning provider from env var", + extra={"exception": e}, + ) + + # Flush the provider so that we can provision its deduplication rules + session.flush() + + # Configure deduplication rules + deduplication_rules = provider_config.get( + "deduplication_rules", {} + ) + if deduplication_rules: + logger.info( + f"Provisioning deduplication rules for provider {provider_name}" + ) + ProvidersService.provision_provider_deduplication_rules( + tenant_id=tenant_id, + provider=provider, + deduplication_rules=deduplication_rules, + ) - # Configure deduplication rules - deduplication_rules = provider_yaml.get( - "deduplication_rules", {} + ### Provisioning from the directory + if provisioned_providers_dir is not None: + for file in os.listdir(provisioned_providers_dir): + if file.endswith((".yaml", ".yml")): + logger.info(f"Provisioning provider from {file}") + provider_path = os.path.join( + provisioned_providers_dir, file ) - if deduplication_rules: - logger.info( - f"Provisioning deduplication rules for provider {provider_name}" - ) - deduplication_rules_dict: dict[str, dict] = {} - for ( - rule_name, - rule_config, - ) in deduplication_rules.items(): + try: + with open(provider_path, "r") as yaml_file: + provider_yaml = cyaml.safe_load(yaml_file.read()) + provider_name = provider_yaml["name"] + provider_type = provider_yaml["type"] + provider_config = provider_yaml.get( + "authentication", {} + ) + + # We skip checking if the provider is already installed, as it will skip the new configurations + # and we want to update the provisioned provider with the new configuration + logger.info(f"Installing provider {provider_name}") + installed_provider_info = ( + ProvidersService.install_provider( + tenant_id=tenant_id, + installed_by="system", + provider_id=provider_type, + provider_name=provider_name, + provider_type=provider_type, + provider_config=provider_config, + provisioned=True, + validate_scopes=False, + session=session, + commit=False, + ) + ) + provider = installed_provider_info["provider"] logger.info( - f"Provisioning deduplication rule {rule_name}" + f"Provider {provider_name} provisioned successfully" ) - rule_config["name"] = rule_name - rule_config["provider_name"] = provider_name - rule_config["provider_type"] = provider_type - deduplication_rules_dict[rule_name] = rule_config - - # Provision deduplication rules - provision_deduplication_rules( - deduplication_rules=deduplication_rules_dict, - tenant_id=tenant_id, - ) - except Exception as e: - logger.error( - "Error provisioning provider from directory", - extra={"exception": e}, - ) - # Un-provisioning other providers. - for provider in provisioned_providers: - if provider.name not in installed_providers: - with Session(engine) as session: - logger.info( - f"Deprovisioning provider {provider.name} as its file no longer exists or is outside the providers directory" - ) - ProvidersService.delete_provider( - tenant_id, provider.id, session, allow_provisioned=True - ) - logger.info( - f"Provider {provider.name} deprovisioned successfully" - ) + # Flush the provider so that we can provision its deduplication rules + session.flush() + + # Configure deduplication rules + deduplication_rules = provider_yaml.get( + "deduplication_rules", {} + ) + if deduplication_rules: + logger.info( + f"Provisioning deduplication rules for provider {provider_name}" + ) + ProvidersService.provision_provider_deduplication_rules( + tenant_id=tenant_id, + provider=provider, + deduplication_rules=deduplication_rules, + ) + except Exception as e: + logger.error( + "Error provisioning provider from directory", + extra={"exception": e}, + ) + continue + except Exception as e: + logger.error("Provisioning failed, rolling back", extra={"exception": e}) + session.rollback() + finally: + # Store the hash in Redis or file + try: + ProvidersService.write_provisioned_hash(tenant_id, providers_hash) + except Exception as e: + logger.warning(f"Failed to store hash: {e}") + session.close() @staticmethod def get_provider_logs( diff --git a/tests/deduplication/test_deduplications_provisioning.py b/tests/deduplication/test_deduplications_provisioning.py deleted file mode 100644 index edc0736431..0000000000 --- a/tests/deduplication/test_deduplications_provisioning.py +++ /dev/null @@ -1,227 +0,0 @@ -import json -from uuid import UUID -import pytest -from keep.api.alert_deduplicator.deduplication_rules_provisioning import ( - provision_deduplication_rules_from_env, -) -from unittest.mock import patch -from keep.api.models.db.alert import AlertDeduplicationRule -from keep.api.models.provider import Provider - - -@pytest.fixture -def setup(monkeypatch): - providers_in_env_var = { - "Installed Prometheus provider": { - "type": "prometheus", - "deduplication_rules": { - "provisioned fake existing deduplication rule": { - "description": "new description", - "fingerprint_fields": ["source"], - "full_deduplication": True, - "ignore_fields": ["ignore_field"], - } - } - }, - "Installed Grafana provider": { - "type": "grafana", - "deduplication_rules": { - "fake new deduplication rule": { - "description": "fake new deduplication rule description", - "fingerprint_fields": ["fingerprint"], - "full_deduplication": False, - } - } - }, - } - - deduplication_rules_in_db = [ - AlertDeduplicationRule( - id=UUID("f3a2b76c8430491da71684de9cf257ab"), - tenant_id="fake_tenant_id", - name="provisioned fake existing deduplication rule", - description="provisioned fake existing deduplication rule description", - provider_id="edc4d65d53204cefb511321be98f748e", - provider_type="prometheus", - last_updated_by="system", - created_by="system", - fingerprint_fields=["fingerprint", "source", "service"], - full_deduplication=False, - is_provisioned=True, - ), - AlertDeduplicationRule( - id=UUID("a5d8f32b6c7049efb913c21da7e845fd"), - tenant_id="fake_tenant_id", - name="provisioned fake deduplication rule to delete", - description="fake new deduplication rule description", - provider_id="a1b2c3d4e5f64789ab1234567890abcd", - provider_type="grafana", - last_updated_by="system", - created_by="system", - fingerprint_fields=["fingerprint"], - full_deduplication=False, - is_provisioned=True, - ), - AlertDeduplicationRule( - id=UUID("c7e3d28f95104b6a8f12dc45eb7639fa"), - tenant_id="fake_tenant_id", - name="not provisioned fake deduplication rule", - description="not provisioned fake deduplication rule", - provider_id="a1b2c3d4e5f64789ab1234567890abcd", - provider_type="grafana", - last_updated_by="user", - created_by="user", - fingerprint_fields=["fingerprint"], - full_deduplication=False, - is_provisioned=False, - ), - ] - installed_providers = [ - Provider( - id="edc4d65d53204cefb511321be98f748e", - display_name="Prometheus", - type="prometheus", - details={"name": "Installed Prometheus provider"}, - can_query=True, - can_notify=True, - ), - Provider( - id="p2b2c3d4e5f64789ab1234567890abcd", - display_name="Prometheus", - type="prometheus", - details={"name": "Installed Prometheus provider second"}, - can_query=True, - can_notify=True, - ), - Provider( - id="a1b2c3d4e5f64789ab1234567890abcd", - display_name="Grafana", - type="grafana", - details={"name": "Installed Grafana provider"}, - can_query=True, - can_notify=True, - ) - ] - - linked_providers = [ - Provider( - id="abcda1b2c3d4e5f64789ab1234567890", - display_name="Grafana", - type="grafana", - can_query=True, - can_notify=True, - ) - ] - - with patch( - "keep.api.core.db.get_all_deduplication_rules", - return_value=deduplication_rules_in_db, - ) as mock_get_all, patch( - "keep.api.core.db.delete_deduplication_rule", return_value=None - ) as mock_delete, patch( - "keep.api.core.db.update_deduplication_rule", return_value=None - ) as mock_update, patch( - "keep.api.core.db.create_deduplication_rule", return_value=None - ) as mock_create, patch( - "keep.providers.providers_factory.ProvidersFactory.get_installed_providers", - return_value=installed_providers, - ) as mock_get_providers, patch( - "keep.providers.providers_factory.ProvidersFactory.get_linked_providers", - return_value=linked_providers, - ) as mock_get_linked_providers: - - fake_tenant_id = "fake_tenant_id" - monkeypatch.setenv( - "KEEP_PROVIDERS", json.dumps(providers_in_env_var) - ) - - yield { - "mock_get_all": mock_get_all, - "mock_delete": mock_delete, - "mock_update": mock_update, - "mock_create": mock_create, - "mock_get_providers": mock_get_providers, - "mock_get_linked_providers": mock_get_linked_providers, - "fake_tenant_id": fake_tenant_id, - "providers_in_env_var": providers_in_env_var, - "deduplication_rules_in_db": deduplication_rules_in_db, - "linked_providers": linked_providers, - "installed_providers": installed_providers, - } - - -def test_provisioning_of_new_rule(setup): - """ - Test the provisioning of new deduplication rules from the environment. - """ - provision_deduplication_rules_from_env(setup["fake_tenant_id"]) - setup["mock_create"].assert_called_once_with( - tenant_id=setup["fake_tenant_id"], - name="fake new deduplication rule", - description="fake new deduplication rule description", - provider_id="a1b2c3d4e5f64789ab1234567890abcd", - provider_type="grafana", - created_by="system", - enabled=True, - fingerprint_fields=["fingerprint"], - full_deduplication=False, - ignore_fields=[], - priority=0, - is_provisioned=True, - ) - - -def test_provisioning_of_existing_rule(setup): - """ - Test the provisioning of new deduplication rules from the environment. - """ - provision_deduplication_rules_from_env(setup["fake_tenant_id"]) - setup["mock_update"].assert_called_once_with( - tenant_id=setup["fake_tenant_id"], - rule_id=str(UUID("f3a2b76c8430491da71684de9cf257ab")), - name="provisioned fake existing deduplication rule", - description="new description", - provider_id="edc4d65d53204cefb511321be98f748e", - provider_type="prometheus", - last_updated_by="system", - enabled=True, - fingerprint_fields=["source"], - full_deduplication=True, - ignore_fields=["ignore_field"], - priority=0, - ) - - -def test_deletion_of_provisioned_rule_not_in_env(setup): - """ - Test the provisioning of new deduplication rules from the environment. - """ - provision_deduplication_rules_from_env(setup["fake_tenant_id"]) - setup["mock_delete"].assert_called_once_with( - tenant_id=setup["fake_tenant_id"], - rule_id=str(UUID("a5d8f32b6c7049efb913c21da7e845fd")), - ) - -def test_not_throwing_error_if_env_var_empty(setup, monkeypatch): - monkeypatch.setenv( - "KEEP_PROVIDERS", '' - ) - try: - provision_deduplication_rules_from_env(setup["fake_tenant_id"]) - except Exception as e: - pytest.fail(f"provision_deduplication_rules_from_env raised an exception: {e}") - -def test_not_throwing_error_if_providers_do_not_have_dedup_rules(setup, monkeypatch): - providers_in_env_var = { - "Installed Prometheus provider": { - "type": "prometheus" - } - } - - monkeypatch.setenv( - "KEEP_PROVIDERS", json.dumps(providers_in_env_var) - ) - try: - provision_deduplication_rules_from_env(setup["fake_tenant_id"]) - except Exception as e: - pytest.fail(f"provision_deduplication_rules_from_env raised an exception: {e}") diff --git a/tests/test_providers_yaml_provisioning.py b/tests/test_providers_yaml_provisioning.py index 990e1ccbb5..91aa3287c6 100644 --- a/tests/test_providers_yaml_provisioning.py +++ b/tests/test_providers_yaml_provisioning.py @@ -109,29 +109,6 @@ def test_provision_provider_from_yaml(temp_providers_dir, sample_provider_yaml, assert rule["ignore_fields"] == ["name"] -def test_skip_existing_provider(temp_providers_dir, sample_provider_yaml): - """Test that existing providers are skipped during provisioning""" - # Create a YAML file - provider_file = os.path.join(temp_providers_dir, "test_provider.yaml") - with open(provider_file, "w") as f: - f.write(sample_provider_yaml) - - # Mock environment variables - with patch.dict(os.environ, {"KEEP_PROVIDERS_DIRECTORY": temp_providers_dir}): - # Mock database operations to simulate existing provider - with patch( - "keep.providers.providers_service.ProvidersService.is_provider_installed", - return_value=True, - ), patch( - "keep.providers.providers_service.ProvidersService.install_provider" - ) as mock_install: - # Call the provisioning function - ProvidersService.provision_providers("test-tenant") - - # Verify provider installation was not called - mock_install.assert_not_called() - - def test_invalid_yaml_file(temp_providers_dir): """Test handling of invalid YAML files""" # Create an invalid YAML file @@ -183,3 +160,194 @@ def test_missing_required_fields(temp_providers_dir): # Verify provider installation was not called mock_install.assert_not_called() + + +def test_provider_yaml_with_multiple_deduplication_rules(temp_providers_dir, caplog): + """Test provisioning a provider from YAML file with multiple deduplication rules""" + yaml_content = """ +name: test-victoriametrics +type: victoriametrics +authentication: + VMAlertHost: http://localhost + VMAlertPort: 1234 +deduplication_rules: + rule1: + description: First deduplication rule + fingerprint_fields: + - fingerprint + - source + full_deduplication: true + ignore_fields: + - name + rule2: + description: Second deduplication rule + fingerprint_fields: + - alert_id + - service + full_deduplication: false + ignore_fields: + - lastReceived +""" + # Create a YAML file + provider_file = os.path.join(temp_providers_dir, "test_provider.yaml") + with open(provider_file, "w") as f: + f.write(yaml_content) + + # Mock provider + mock_provider = MagicMock( + type="victoriametrics", + id="test-provider-id", + details={ + "name": "test-victoriametrics", + "authentication": {"VMAlertHost": "http://localhost", "VMAlertPort": 1234}, + }, + validatedScopes={}, + ) + + # Mock environment variables and services + with patch.dict(os.environ, {"KEEP_PROVIDERS_DIRECTORY": temp_providers_dir}): + with patch( + "keep.providers.providers_service.ProvidersService.is_provider_installed", + return_value=False, + ), patch( + "keep.providers.providers_service.ProvidersService.install_provider", + return_value=mock_provider, + ) as mock_install, patch( + "keep.providers.providers_service.provision_deduplication_rules" + ) as mock_provision_rules, patch( + "keep.api.core.db.get_all_provisioned_providers", return_value=[] + ): + # Call the provisioning function + ProvidersService.provision_providers("test-tenant") + + # Verify provider installation + mock_install.assert_called_once() + + # Verify deduplication rules provisioning + mock_provision_rules.assert_called_once() + call_args = mock_provision_rules.call_args[1] + assert call_args["tenant_id"] == "test-tenant" + + rules = call_args["deduplication_rules"] + assert len(rules) == 2 + + rule1 = rules["rule1"] + assert rule1["description"] == "First deduplication rule" + assert rule1["fingerprint_fields"] == ["fingerprint", "source"] + assert rule1["full_deduplication"] is True + assert rule1["ignore_fields"] == ["name"] + + rule2 = rules["rule2"] + assert rule2["description"] == "Second deduplication rule" + assert rule2["fingerprint_fields"] == ["alert_id", "service"] + assert rule2["full_deduplication"] is False + assert rule2["ignore_fields"] == ["lastReceived"] + + +def test_provider_yaml_with_empty_deduplication_rules(temp_providers_dir, caplog): + """Test provisioning a provider from YAML file with empty deduplication rules""" + yaml_content = """ +name: test-victoriametrics +type: victoriametrics +authentication: + VMAlertHost: http://localhost + VMAlertPort: 1234 +deduplication_rules: {} +""" + # Create a YAML file + provider_file = os.path.join(temp_providers_dir, "test_provider.yaml") + with open(provider_file, "w") as f: + f.write(yaml_content) + + # Mock provider + mock_provider = MagicMock( + type="victoriametrics", + id="test-provider-id", + details={ + "name": "test-victoriametrics", + "authentication": {"VMAlertHost": "http://localhost", "VMAlertPort": 1234}, + }, + validatedScopes={}, + ) + + # Mock environment variables and services + with patch.dict(os.environ, {"KEEP_PROVIDERS_DIRECTORY": temp_providers_dir}): + with patch( + "keep.providers.providers_service.ProvidersService.is_provider_installed", + return_value=False, + ), patch( + "keep.providers.providers_service.ProvidersService.install_provider", + return_value=mock_provider, + ) as mock_install, patch( + "keep.providers.providers_service.provision_deduplication_rules" + ) as mock_provision_rules, patch( + "keep.api.core.db.get_all_provisioned_providers", return_value=[] + ): + # Call the provisioning function + ProvidersService.provision_providers("test-tenant") + + # Verify provider installation was called + mock_install.assert_called_once() + + # Verify deduplication rules provisioning was called with empty rules + mock_provision_rules.assert_not_called() + + +def test_provider_yaml_with_invalid_deduplication_rules(temp_providers_dir, caplog): + """Test provisioning a provider from YAML file with invalid deduplication rules""" + yaml_content = """ +name: test-victoriametrics +type: victoriametrics +authentication: + VMAlertHost: http://localhost + VMAlertPort: 1234 +deduplication_rules: + invalid_rule: + # Missing required fields + description: Invalid rule +""" + # Create a YAML file + provider_file = os.path.join(temp_providers_dir, "test_provider.yaml") + with open(provider_file, "w") as f: + f.write(yaml_content) + + # Mock provider + mock_provider = MagicMock( + type="victoriametrics", + id="test-provider-id", + details={ + "name": "test-victoriametrics", + "authentication": {"VMAlertHost": "http://localhost", "VMAlertPort": 1234}, + }, + validatedScopes={}, + ) + + # Mock environment variables and services + with patch.dict(os.environ, {"KEEP_PROVIDERS_DIRECTORY": temp_providers_dir}): + with patch( + "keep.providers.providers_service.ProvidersService.is_provider_installed", + return_value=False, + ), patch( + "keep.providers.providers_service.ProvidersService.install_provider", + return_value=mock_provider, + ) as mock_install, patch( + "keep.providers.providers_service.provision_deduplication_rules" + ) as mock_provision_rules, patch( + "keep.api.core.db.get_all_provisioned_providers", return_value=[] + ): + # Call the provisioning function + ProvidersService.provision_providers("test-tenant") + + # Verify provider installation was called + mock_install.assert_called_once() + + # Verify deduplication rules provisioning was called + mock_provision_rules.assert_called_once() + call_args = mock_provision_rules.call_args[1] + assert call_args["tenant_id"] == "test-tenant" + + # Even invalid rules should be passed through, validation happens in provision_deduplication_rules + assert len(call_args["deduplication_rules"]) == 1 + rule = call_args["deduplication_rules"]["invalid_rule"] + assert rule["description"] == "Invalid rule" + assert "fingerprint_fields" not in rule diff --git a/tests/test_provisioning.py b/tests/test_provisioning.py index e987e29271..0d4ccafd14 100644 --- a/tests/test_provisioning.py +++ b/tests/test_provisioning.py @@ -371,3 +371,228 @@ def test_provision_provider_with_empty_tenant_table(db_session, client, test_app ) db_session.execute(text("PRAGMA foreign_keys = OFF;")) + + +@pytest.mark.parametrize( + "test_app", + [{"AUTH_TYPE": "NOAUTH"}], + indirect=True, +) +def test_no_provisioned_providers_and_unset_env_vars( + monkeypatch, db_session, client, test_app +): + """Test behavior when there are no provisioned providers and env vars are unset""" + # Import necessary modules + from unittest.mock import patch + + from keep.providers.providers_service import ProvidersService + + # Mock get_all_provisioned_providers to return an empty list + with patch( + "keep.providers.providers_service.get_all_provisioned_providers", + return_value=[], + ) as mock_get_providers, patch( + "keep.providers.providers_service.ProvidersService.delete_provider" + ) as mock_delete_provider: + # Call provision_providers without setting any env vars + ProvidersService.provision_providers("test-tenant") + + # Verify get_all_provisioned_providers was called + mock_get_providers.assert_called_once_with("test-tenant") + + # Verify delete_provider was not called since there were no providers to delete + mock_delete_provider.assert_not_called() + + +@pytest.mark.parametrize( + "test_app", + [{"AUTH_TYPE": "NOAUTH"}], + indirect=True, +) +def test_delete_provisioned_providers_when_env_vars_unset( + monkeypatch, db_session, client, test_app +): + """Test deleting provisioned providers when env vars are unset""" + # Import necessary modules + from unittest.mock import ANY, MagicMock, patch + + from keep.providers.providers_service import ProvidersService + + # Create a mock provider + mock_provider = MagicMock(id="test-id", name="test-provider", type="test-type") + + # Mock get_all_provisioned_providers to return our mock provider + with patch( + "keep.providers.providers_service.get_all_provisioned_providers", + return_value=[mock_provider], + ) as mock_get_providers, patch( + "keep.providers.providers_service.ProvidersService.delete_provider" + ) as mock_delete_provider: + # Call provision_providers without setting any env vars + ProvidersService.provision_providers("test-tenant") + + # Verify get_all_provisioned_providers was called + mock_get_providers.assert_called_once_with("test-tenant") + + # Verify delete_provider was called with correct parameters + mock_delete_provider.assert_called_once_with( + "test-tenant", + "test-id", + ANY, # Session object + allow_provisioned=True, + commit=False, + ) + + +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + "KEEP_PROVIDERS": '{"existingProvider":{"type":"victoriametrics","authentication":{"VMAlertHost":"http://localhost","VMAlertPort": 1234}}}', + }, + ], + indirect=True, +) +def test_replace_existing_provisioned_provider( + monkeypatch, db_session, client, test_app +): + """Test that when a new provider is provisioned via KEEP_PROVIDERS without including + the current provisioned provider, it removes the current one and installs the new one + """ + + # First verify the initial provider is installed + response = client.get("/providers", headers={"x-api-key": "someapikey"}) + assert response.status_code == 200 + providers = response.json() + provisioned_providers = [ + p for p in providers.get("installed_providers") if p.get("provisioned") + ] + assert len(provisioned_providers) == 1 + # Provider name is in the details + provider_details = provisioned_providers[0].get("details", {}) + assert provider_details.get("name") == "existingProvider" + assert provisioned_providers[0]["type"] == "victoriametrics" + + # Change environment variable to new provider config that doesn't include the existing one + monkeypatch.setenv( + "KEEP_PROVIDERS", + '{"newProvider":{"type":"prometheus","authentication":{"url":"http://localhost:9090"}}}', + ) + + # Reload the app to apply the new environment changes + importlib.reload(sys.modules["keep.api.api"]) + from keep.api.api import get_app + + app = get_app() + + # Manually trigger the startup event + for event_handler in app.router.on_startup: + asyncio.run(event_handler()) + + # Manually trigger the provision resources + from keep.api.config import provision_resources + + provision_resources() + + client = TestClient(app) + + # Verify that the old provider is gone and new provider is installed + response = client.get("/providers", headers={"x-api-key": "someapikey"}) + assert response.status_code == 200 + providers = response.json() + provisioned_providers = [ + p for p in providers.get("installed_providers") if p.get("provisioned") + ] + assert len(provisioned_providers) == 1 + provider_details = provisioned_providers[0].get("details", {}) + assert provider_details.get("name") == "newProvider" + assert provisioned_providers[0]["type"] == "prometheus" + + +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + "KEEP_PROVIDERS": '{"vm_provider":{"type":"victoriametrics","authentication":{"VMAlertHost":"http://localhost","VMAlertPort":1234},"deduplication_rules":{"rule1":{"description":"First rule","fingerprint_fields":["fingerprint","source"],"ignore_fields":["name"]}}}}', + }, + ], + indirect=True, +) +def test_delete_deduplication_rules_when_reprovisioning( + monkeypatch, db_session, client, test_app +): + """Test that deduplication rules are deleted when reprovisioning a provider without rules""" + + # First verify initial provider and rule are installed + response = client.get("/deduplications", headers={"x-api-key": "someapikey"}) + assert response.status_code == 200 + rules = response.json() + assert len(rules) - 1 == 1 + assert rules[1]["name"] == "rule1" + + # Update provider config without any deduplication rules + monkeypatch.setenv( + "KEEP_PROVIDERS", + '{"vm_provider":{"type":"victoriametrics","authentication":{"VMAlertHost":"http://localhost","VMAlertPort":1234}}}', + ) + + # Reload the app to apply the new environment changes + importlib.reload(sys.modules["keep.api.api"]) + from keep.api.api import get_app + + app = get_app() + + # Manually trigger the startup event + for event_handler in app.router.on_startup: + asyncio.run(event_handler()) + + # Manually trigger the provision resources + from keep.api.config import provision_resources + + provision_resources() + + client = TestClient(app) + + # Verify the rule was deleted + response = client.get("/deduplications", headers={"x-api-key": "someapikey"}) + assert response.status_code == 200 + rules = response.json() + assert len(rules) == 0 + + +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + "KEEP_PROVIDERS": '{"vm_provider":{"type":"victoriametrics","authentication":{"VMAlertHost":"http://localhost","VMAlertPort":1234},"deduplication_rules":{"rule1":{"description":"First rule","fingerprint_fields":["fingerprint","source"]},"rule2":{"description":"Second rule","fingerprint_fields":["alert_id"]}}}}', + }, + ], + indirect=True, +) +def test_provision_provider_with_multiple_deduplication_rules( + db_session, client, test_app +): + """Test provisioning a provider with multiple deduplication rules""" + + # Verify the provider and rules are installed + response = client.get("/deduplications", headers={"x-api-key": "someapikey"}) + assert response.status_code == 200 + rules = response.json() + assert len(rules) - 1 == 2 + + rule1 = next(r for r in rules[1:] if r["name"] == "rule1") + assert rule1["description"] == "First rule" + assert rule1["fingerprint_fields"] == ["fingerprint", "source"] + assert rule1["is_provisioned"] is True + + rule2 = next(r for r in rules if r["name"] == "rule2") + assert rule2["description"] == "Second rule" + assert rule2["fingerprint_fields"] == ["alert_id"] + assert rule2["is_provisioned"] is True + + # Verify both rules are associated with the same provider + assert rule1["provider_type"] == "victoriametrics" + assert rule2["provider_type"] == "victoriametrics" From 4fa513d06ff5739b4bc3aa81ab3d9e750aaedab1 Mon Sep 17 00:00:00 2001 From: tuantran0910 Date: Fri, 4 Apr 2025 09:38:52 +0700 Subject: [PATCH 02/18] test: add unit tests for ProvidersService hash handling with Redis and file fallback --- tests/test_providers_provisioning_caching.py | 144 +++++++++++++++++++ 1 file changed, 144 insertions(+) create mode 100644 tests/test_providers_provisioning_caching.py diff --git a/tests/test_providers_provisioning_caching.py b/tests/test_providers_provisioning_caching.py new file mode 100644 index 0000000000..9e81fc2b75 --- /dev/null +++ b/tests/test_providers_provisioning_caching.py @@ -0,0 +1,144 @@ +import hashlib +import json +from unittest.mock import MagicMock, mock_open, patch + +import pytest +import redis + +from keep.api.consts import REDIS_DB, REDIS_HOST, REDIS_PORT +from keep.providers.providers_service import ProvidersService + + +@pytest.fixture +def tenant_id(): + return "test_tenant" + + +@pytest.fixture +def hash_value(): + return "test_hash" + + +def test_write_provisioned_hash_redis(tenant_id, hash_value): + """Test writing hash to Redis when Redis is enabled""" + with patch("keep.providers.providers_service.REDIS", True), patch( + "redis.Redis" + ) as mock_redis: + mock_redis_instance = MagicMock() + mock_redis.return_value = mock_redis_instance + + ProvidersService.write_provisioned_hash(tenant_id, hash_value) + + mock_redis.assert_called_once_with( + host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB + ) + mock_redis_instance.set.assert_called_once_with( + f"{tenant_id}_providers_hash", hash_value + ) + + +def test_write_provisioned_hash_file(tenant_id, hash_value): + """Test writing hash to file when Redis is disabled""" + file_path = f"/state/{tenant_id}_providers_hash.txt" + + with patch("keep.providers.providers_service.REDIS", False), patch( + "builtins.open", mock_open() + ) as mock_file: + ProvidersService.write_provisioned_hash(tenant_id, hash_value) + + mock_file.assert_called_once_with(file_path, "w") + mock_file().write.assert_called_once_with(hash_value) + + +def test_get_provisioned_hash_redis_success(tenant_id, hash_value): + """Test getting hash from Redis successfully""" + with patch("keep.providers.providers_service.REDIS", True), patch( + "redis.Redis" + ) as mock_redis: + mock_redis_instance = MagicMock() + mock_redis_instance.get.return_value = hash_value.encode() + mock_redis.return_value.__enter__.return_value = mock_redis_instance + + result = ProvidersService.get_provisioned_hash(tenant_id) + + assert result == hash_value + mock_redis_instance.get.assert_called_once_with(f"{tenant_id}_providers_hash") + + +def test_get_provisioned_hash_redis_error(tenant_id, hash_value): + """Test falling back to file when Redis fails""" + with patch("keep.providers.providers_service.REDIS", True), patch( + "redis.Redis" + ) as mock_redis, patch("builtins.open", mock_open(read_data=hash_value)): + mock_redis.return_value.__enter__.side_effect = redis.RedisError("Test error") + + result = ProvidersService.get_provisioned_hash(tenant_id) + + assert result == hash_value + + +def test_get_provisioned_hash_file_success(tenant_id, hash_value): + """Test getting hash from file successfully""" + with patch("keep.providers.providers_service.REDIS", False), patch( + "builtins.open", mock_open(read_data=hash_value) + ): + result = ProvidersService.get_provisioned_hash(tenant_id) + + assert result == hash_value + + +def test_get_provisioned_hash_file_not_found(tenant_id): + """Test handling file not found error""" + with patch("keep.providers.providers_service.REDIS", False), patch( + "builtins.open" + ) as mock_file: + mock_file.side_effect = FileNotFoundError() + + result = ProvidersService.get_provisioned_hash(tenant_id) + + assert result is None + + +def test_calculate_provider_hash_json(): + """Test calculating hash from JSON input""" + json_input = '{"provider": "test"}' + expected_hash = hashlib.sha256(json.dumps(json_input).encode("utf-8")).hexdigest() + + result = ProvidersService.calculate_provider_hash( + provisioned_providers_json=json_input + ) + + assert result == expected_hash + + +def test_calculate_provider_hash_directory(): + """Test calculating hash from directory input""" + test_dir = "/test/providers" + yaml_content = "provider: test" + + with patch("os.listdir") as mock_listdir, patch("os.path.join") as mock_join, patch( + "builtins.open", mock_open(read_data=yaml_content) + ): + mock_listdir.return_value = ["provider1.yaml", "provider2.yml", "other.txt"] + mock_join.side_effect = lambda *args: f"{args[0]}/{args[1]}" + + result = ProvidersService.calculate_provider_hash( + provisioned_providers_dir=test_dir + ) + + expected_data = [yaml_content, yaml_content] # Two YAML files + expected_hash = hashlib.sha256( + json.dumps(expected_data).encode("utf-8") + ).hexdigest() + + assert result == expected_hash + assert mock_listdir.call_count == 1 + assert mock_join.call_count == 2 + + +def test_calculate_provider_hash_no_input(): + """Test calculating hash with no input""" + result = ProvidersService.calculate_provider_hash() + expected_hash = hashlib.sha256(json.dumps("").encode("utf-8")).hexdigest() + + assert result == expected_hash From a14edd1d9dfc0980838a264338c91b12342fcbed Mon Sep 17 00:00:00 2001 From: tuantran0910 Date: Fri, 4 Apr 2025 09:48:59 +0700 Subject: [PATCH 03/18] test: add more unit tests for ProvidersService hash handling with Redis and file --- tests/test_providers_provisioning_caching.py | 127 +++++++++++++++++++ 1 file changed, 127 insertions(+) diff --git a/tests/test_providers_provisioning_caching.py b/tests/test_providers_provisioning_caching.py index 9e81fc2b75..4f2e99ef6a 100644 --- a/tests/test_providers_provisioning_caching.py +++ b/tests/test_providers_provisioning_caching.py @@ -142,3 +142,130 @@ def test_calculate_provider_hash_no_input(): expected_hash = hashlib.sha256(json.dumps("").encode("utf-8")).hexdigest() assert result == expected_hash + + +def test_write_provisioned_hash_redis_enabled(tenant_id, hash_value): + """Test writing hash to Redis when Redis is enabled""" + with patch("keep.providers.providers_service.REDIS", True), patch( + "redis.Redis" + ) as mock_redis: + mock_redis_instance = MagicMock() + mock_redis.return_value = mock_redis_instance + + ProvidersService.write_provisioned_hash(tenant_id, hash_value) + + mock_redis.assert_called_once_with( + host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB + ) + mock_redis_instance.set.assert_called_once_with( + f"{tenant_id}_providers_hash", hash_value + ) + + +def test_write_provisioned_hash_redis_disabled(tenant_id, hash_value): + """Test writing hash to file when Redis is disabled""" + file_path = f"/state/{tenant_id}_providers_hash.txt" + + with patch("keep.providers.providers_service.REDIS", False), patch( + "builtins.open", mock_open() + ) as mock_file: + ProvidersService.write_provisioned_hash(tenant_id, hash_value) + + mock_file.assert_called_once_with(file_path, "w") + mock_file().write.assert_called_once_with(hash_value) + + +def test_get_provisioned_hash_redis_enabled_success(tenant_id, hash_value): + """Test getting hash from Redis successfully when Redis is enabled""" + with patch("keep.providers.providers_service.REDIS", True), patch( + "redis.Redis" + ) as mock_redis: + mock_redis_instance = MagicMock() + mock_redis_instance.get.return_value = hash_value.encode() + mock_redis.return_value.__enter__.return_value = mock_redis_instance + + result = ProvidersService.get_provisioned_hash(tenant_id) + + assert result == hash_value + mock_redis_instance.get.assert_called_once_with(f"{tenant_id}_providers_hash") + + +def test_get_provisioned_hash_redis_enabled_none_value(tenant_id): + """Test getting None from Redis when Redis is enabled but no value exists""" + with patch("keep.providers.providers_service.REDIS", True), patch( + "redis.Redis" + ) as mock_redis: + # Mock Redis returning None + mock_redis_instance = MagicMock() + mock_redis_instance.get.return_value = None + mock_redis.return_value.__enter__.return_value = mock_redis_instance + + result = ProvidersService.get_provisioned_hash(tenant_id) + + assert result is None + mock_redis_instance.get.assert_called_once_with(f"{tenant_id}_providers_hash") + + +def test_get_provisioned_hash_redis_disabled_success(tenant_id, hash_value): + """Test getting hash from file successfully when Redis is disabled""" + with patch("keep.providers.providers_service.REDIS", False), patch( + "builtins.open", mock_open(read_data=hash_value) + ), patch("redis.Redis") as mock_redis: + result = ProvidersService.get_provisioned_hash(tenant_id) + + assert result == hash_value + # Should not try to read from Redis + mock_redis.assert_not_called() + + +def test_get_provisioned_hash_redis_enabled_byte_decoding(tenant_id): + """Test proper decoding of bytes from Redis""" + encoded_hash = b"test_hash_with_whitespace \n" + expected_hash = "test_hash_with_whitespace" + + with patch("keep.providers.providers_service.REDIS", True), patch( + "redis.Redis" + ) as mock_redis: + mock_redis_instance = MagicMock() + mock_redis_instance.get.return_value = encoded_hash + mock_redis.return_value.__enter__.return_value = mock_redis_instance + + result = ProvidersService.get_provisioned_hash(tenant_id) + + assert result == expected_hash + mock_redis_instance.get.assert_called_once_with(f"{tenant_id}_providers_hash") + + +def test_calculate_provider_hash_consistency(tenant_id): + """Test that hash calculation is consistent for the same input""" + + # Test with JSON input + json_input_1 = '{"provider": "test"}' + json_input_2 = '{"provider": "test"}' + + hash_1 = ProvidersService.calculate_provider_hash( + provisioned_providers_json=json_input_1 + ) + hash_2 = ProvidersService.calculate_provider_hash( + provisioned_providers_json=json_input_2 + ) + + assert hash_1 == hash_2 + + # Test with directory input + yaml_content = "provider: test" + + with patch("os.listdir") as mock_listdir, patch("os.path.join") as mock_join, patch( + "builtins.open", mock_open(read_data=yaml_content) + ): + mock_listdir.return_value = ["provider1.yaml"] + mock_join.side_effect = lambda *args: f"{args[0]}/{args[1]}" + + hash_3 = ProvidersService.calculate_provider_hash( + provisioned_providers_dir="/test/dir" + ) + hash_4 = ProvidersService.calculate_provider_hash( + provisioned_providers_dir="/test/dir" + ) + + assert hash_3 == hash_4 From 0e05ce6bb8a0b8e12035c7af471d1e8414df2790 Mon Sep 17 00:00:00 2001 From: tuantran0910 Date: Fri, 4 Apr 2025 10:14:47 +0700 Subject: [PATCH 04/18] fix: update log message for reading provisioned hash from file in ProvidersService --- keep/providers/providers_service.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/keep/providers/providers_service.py b/keep/providers/providers_service.py index c18cf0f589..091d367dbe 100644 --- a/keep/providers/providers_service.py +++ b/keep/providers/providers_service.py @@ -481,9 +481,7 @@ def get_provisioned_hash(tenant_id: str) -> Optional[str]: encoding="utf-8", ) as f: previous_hash = f.read().strip() - logger.info( - f"Provisioned hash for tenant {tenant_id} read from file: {previous_hash}" - ) + logger.info(f"Provisioned hash for tenant {tenant_id} read from file.") except FileNotFoundError: logger.info(f"Provisioned hash file for tenant {tenant_id} not found.") except Exception as e: From fc2fdd23d8762c78c1b194040bf625d238224501 Mon Sep 17 00:00:00 2001 From: tuantran0910 Date: Fri, 4 Apr 2025 10:14:56 +0700 Subject: [PATCH 05/18] refactor: remove redundant deletion of non-existent deduplication rules in provisioning --- .../deduplication_rules_provisioning.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/keep/api/alert_deduplicator/deduplication_rules_provisioning.py b/keep/api/alert_deduplicator/deduplication_rules_provisioning.py index d89815f662..215c50a1f8 100644 --- a/keep/api/alert_deduplicator/deduplication_rules_provisioning.py +++ b/keep/api/alert_deduplicator/deduplication_rules_provisioning.py @@ -32,17 +32,6 @@ def provision_deduplication_rules( } actor = "system" - # delete rules that are not in the env - for provisioned_deduplication_rule in provisioned_deduplication_rules: - if str(provisioned_deduplication_rule.name) not in deduplication_rules: - logger.info( - "Deduplication rule with name '%s' is not in the env, deleting from DB", - provisioned_deduplication_rule.name, - ) - db.delete_deduplication_rule( - rule_id=str(provisioned_deduplication_rule.id), tenant_id=tenant_id - ) - for ( deduplication_rule_name, deduplication_rule_to_provision, From 3ecf11df1446273445552053f16fd7fae915e363 Mon Sep 17 00:00:00 2001 From: tuantran0910 Date: Sat, 5 Apr 2025 10:34:39 +0700 Subject: [PATCH 06/18] refactor: update ProvidersService to use secret manager for hash value instead of only local file --- keep/providers/providers_service.py | 46 +++++----- tests/test_providers_provisioning_caching.py | 96 +++++++++++++------- 2 files changed, 88 insertions(+), 54 deletions(-) diff --git a/keep/providers/providers_service.py b/keep/providers/providers_service.py index 091d367dbe..08d41c14c0 100644 --- a/keep/providers/providers_service.py +++ b/keep/providers/providers_service.py @@ -31,9 +31,6 @@ from keep.providers.providers_factory import ProvidersFactory from keep.secretmanager.secretmanagerfactory import SecretManagerFactory -DEFAULT_PROVIDER_HASH_STATE_FILE = "/state/{tenant_id}_providers_hash.txt" - - logger = logging.getLogger(__name__) @@ -432,7 +429,7 @@ def provision_provider_deduplication_rules( @staticmethod def write_provisioned_hash(tenant_id: str, hash_value: str): """ - Write the provisioned hash to Redis or file. + Write the provisioned hash to Redis or secret manager. Args: tenant_id (str): The tenant ID. @@ -443,16 +440,20 @@ def write_provisioned_hash(tenant_id: str, hash_value: str): r.set(f"{tenant_id}_providers_hash", hash_value) logger.info(f"Provisioned hash for tenant {tenant_id} written to Redis!") else: - with open( - DEFAULT_PROVIDER_HASH_STATE_FILE.format(tenant_id=tenant_id), "w" - ) as f: - f.write(hash_value) - logger.info(f"Provisioned hash for tenant {tenant_id} written to file!") + context_manager = ContextManager(tenant_id=tenant_id) + secret_manager = SecretManagerFactory.get_secret_manager(context_manager) + secret_manager.write_secret( + secret_name=f"{tenant_id}_providers_hash", + secret_value=hash_value, + ) + logger.info( + f"Provisioned hash for tenant {tenant_id} written to secret manager!" + ) @staticmethod def get_provisioned_hash(tenant_id: str) -> Optional[str]: """ - Get the provisioned hash from Redis or file. + Get the provisioned hash from Redis or secret manager. Args: tenant_id (str): The tenant ID. @@ -475,18 +476,19 @@ def get_provisioned_hash(tenant_id: str) -> Optional[str]: if previous_hash is None: try: - with open( - DEFAULT_PROVIDER_HASH_STATE_FILE.format(tenant_id=tenant_id), - "r", - encoding="utf-8", - ) as f: - previous_hash = f.read().strip() - logger.info(f"Provisioned hash for tenant {tenant_id} read from file.") - except FileNotFoundError: - logger.info(f"Provisioned hash file for tenant {tenant_id} not found.") + context_manager = ContextManager(tenant_id=tenant_id) + secret_manager = SecretManagerFactory.get_secret_manager( + context_manager + ) + previous_hash = secret_manager.read_secret( + f"{tenant_id}_providers_hash" + ) + logger.info( + f"Provisioned hash for tenant {tenant_id} read from secret manager." + ) except Exception as e: logger.warning( - f"Failed to read hash from file for tenant {tenant_id}: {e}" + f"Failed to read hash from secret manager for tenant {tenant_id}: {e}" ) return previous_hash if previous_hash else None @@ -565,7 +567,7 @@ def provision_providers(tenant_id: str): provisioned_providers_dir, provisioned_providers_json ) - # Get the previous hash from Redis or file + # Get the previous hash from Redis or secret manager previous_hash = ProvidersService.get_provisioned_hash(tenant_id) if providers_hash == previous_hash: logger.info( @@ -724,7 +726,7 @@ def provision_providers(tenant_id: str): logger.error("Provisioning failed, rolling back", extra={"exception": e}) session.rollback() finally: - # Store the hash in Redis or file + # Store the hash in Redis or secret manager try: ProvidersService.write_provisioned_hash(tenant_id, providers_hash) except Exception as e: diff --git a/tests/test_providers_provisioning_caching.py b/tests/test_providers_provisioning_caching.py index 4f2e99ef6a..d6df0a635c 100644 --- a/tests/test_providers_provisioning_caching.py +++ b/tests/test_providers_provisioning_caching.py @@ -37,17 +37,20 @@ def test_write_provisioned_hash_redis(tenant_id, hash_value): ) -def test_write_provisioned_hash_file(tenant_id, hash_value): - """Test writing hash to file when Redis is disabled""" - file_path = f"/state/{tenant_id}_providers_hash.txt" +def test_write_provisioned_hash_secret_manager(tenant_id, hash_value): + """Test writing hash to secret manager when Redis is disabled""" + mock_secret_manager = MagicMock() with patch("keep.providers.providers_service.REDIS", False), patch( - "builtins.open", mock_open() - ) as mock_file: + "keep.providers.providers_service.SecretManagerFactory.get_secret_manager" + ) as mock_get_secret_manager: + mock_get_secret_manager.return_value = mock_secret_manager + ProvidersService.write_provisioned_hash(tenant_id, hash_value) - mock_file.assert_called_once_with(file_path, "w") - mock_file().write.assert_called_once_with(hash_value) + mock_secret_manager.write_secret.assert_called_once_with( + secret_name=f"{tenant_id}_providers_hash", secret_value=hash_value + ) def test_get_provisioned_hash_redis_success(tenant_id, hash_value): @@ -66,33 +69,53 @@ def test_get_provisioned_hash_redis_success(tenant_id, hash_value): def test_get_provisioned_hash_redis_error(tenant_id, hash_value): - """Test falling back to file when Redis fails""" + """Test falling back to secret manager when Redis fails""" + mock_secret_manager = MagicMock() + mock_secret_manager.read_secret.return_value = hash_value + with patch("keep.providers.providers_service.REDIS", True), patch( "redis.Redis" - ) as mock_redis, patch("builtins.open", mock_open(read_data=hash_value)): + ) as mock_redis, patch( + "keep.providers.providers_service.SecretManagerFactory.get_secret_manager" + ) as mock_get_secret_manager: mock_redis.return_value.__enter__.side_effect = redis.RedisError("Test error") + mock_get_secret_manager.return_value = mock_secret_manager result = ProvidersService.get_provisioned_hash(tenant_id) assert result == hash_value + mock_secret_manager.read_secret.assert_called_once_with( + f"{tenant_id}_providers_hash" + ) + +def test_get_provisioned_hash_secret_manager_success(tenant_id, hash_value): + """Test getting hash from secret manager successfully""" + mock_secret_manager = MagicMock() + mock_secret_manager.read_secret.return_value = hash_value -def test_get_provisioned_hash_file_success(tenant_id, hash_value): - """Test getting hash from file successfully""" with patch("keep.providers.providers_service.REDIS", False), patch( - "builtins.open", mock_open(read_data=hash_value) - ): + "keep.providers.providers_service.SecretManagerFactory.get_secret_manager" + ) as mock_get_secret_manager: + mock_get_secret_manager.return_value = mock_secret_manager + result = ProvidersService.get_provisioned_hash(tenant_id) assert result == hash_value + mock_secret_manager.read_secret.assert_called_once_with( + f"{tenant_id}_providers_hash" + ) + +def test_get_provisioned_hash_secret_manager_error(tenant_id): + """Test handling secret manager error""" + mock_secret_manager = MagicMock() + mock_secret_manager.read_secret.side_effect = Exception("Secret not found") -def test_get_provisioned_hash_file_not_found(tenant_id): - """Test handling file not found error""" with patch("keep.providers.providers_service.REDIS", False), patch( - "builtins.open" - ) as mock_file: - mock_file.side_effect = FileNotFoundError() + "keep.providers.providers_service.SecretManagerFactory.get_secret_manager" + ) as mock_get_secret_manager: + mock_get_secret_manager.return_value = mock_secret_manager result = ProvidersService.get_provisioned_hash(tenant_id) @@ -162,17 +185,20 @@ def test_write_provisioned_hash_redis_enabled(tenant_id, hash_value): ) -def test_write_provisioned_hash_redis_disabled(tenant_id, hash_value): - """Test writing hash to file when Redis is disabled""" - file_path = f"/state/{tenant_id}_providers_hash.txt" +def test_write_provisioned_hash_redis_disabled_secret_manager(tenant_id, hash_value): + """Test writing hash to secret manager when Redis is disabled""" + mock_secret_manager = MagicMock() with patch("keep.providers.providers_service.REDIS", False), patch( - "builtins.open", mock_open() - ) as mock_file: + "keep.providers.providers_service.SecretManagerFactory.get_secret_manager" + ) as mock_get_secret_manager: + mock_get_secret_manager.return_value = mock_secret_manager + ProvidersService.write_provisioned_hash(tenant_id, hash_value) - mock_file.assert_called_once_with(file_path, "w") - mock_file().write.assert_called_once_with(hash_value) + mock_secret_manager.write_secret.assert_called_once_with( + secret_name=f"{tenant_id}_providers_hash", secret_value=hash_value + ) def test_get_provisioned_hash_redis_enabled_success(tenant_id, hash_value): @@ -206,16 +232,22 @@ def test_get_provisioned_hash_redis_enabled_none_value(tenant_id): mock_redis_instance.get.assert_called_once_with(f"{tenant_id}_providers_hash") -def test_get_provisioned_hash_redis_disabled_success(tenant_id, hash_value): - """Test getting hash from file successfully when Redis is disabled""" - with patch("keep.providers.providers_service.REDIS", False), patch( - "builtins.open", mock_open(read_data=hash_value) - ), patch("redis.Redis") as mock_redis: +def test_get_provisioned_hash_redis_preferred(tenant_id, hash_value): + """Test that Redis is preferred over secret manager when Redis works""" + with patch("keep.providers.providers_service.REDIS", True), patch( + "redis.Redis" + ) as mock_redis, patch( + "keep.providers.providers_service.SecretManagerFactory.get_secret_manager" + ) as mock_get_secret_manager: + mock_redis_instance = MagicMock() + mock_redis_instance.get.return_value = hash_value.encode() + mock_redis.return_value.__enter__.return_value = mock_redis_instance + result = ProvidersService.get_provisioned_hash(tenant_id) assert result == hash_value - # Should not try to read from Redis - mock_redis.assert_not_called() + # Should not try to use secret manager when Redis works + mock_get_secret_manager.assert_not_called() def test_get_provisioned_hash_redis_enabled_byte_decoding(tenant_id): From a4b373121926d134850f8cc0ce3e48b28cc4c000 Mon Sep 17 00:00:00 2001 From: tuantran0910 Date: Sat, 5 Apr 2025 10:45:03 +0700 Subject: [PATCH 07/18] docs: update provisioning documentation --- docs/deployment/provision/overview.mdx | 3 ++- docs/deployment/provision/provider.mdx | 32 ++++++++++---------------- 2 files changed, 14 insertions(+), 21 deletions(-) diff --git a/docs/deployment/provision/overview.mdx b/docs/deployment/provision/overview.mdx index be567c4725..ce4d803f5f 100644 --- a/docs/deployment/provision/overview.mdx +++ b/docs/deployment/provision/overview.mdx @@ -27,7 +27,8 @@ Provisioning in Keep is controlled through environment variables and configurati | Provisioning Type | Environment Variable | Purpose | | ---------------------- | ------------------------------ | ------------------------------------------------------------------------- | | **Provider** | `KEEP_PROVIDERS` | JSON string containing provider configurations with deduplication rules | -| **Workflow** | `KEEP_WORKFLOW` | One workflow to provision right from the env variable. | +| **Providers** | `KEEP_PROVIDERS_DIRECTORY` | Directory path containing provider configuration files | +| **Workflow** | `KEEP_WORKFLOW` | One workflow to provision right from the env variable | | **Workflows** | `KEEP_WORKFLOWS_DIRECTORY` | Directory path containing workflow configuration files | | **Dashboard** | `KEEP_DASHBOARDS` | JSON string containing dashboard configurations | diff --git a/docs/deployment/provision/provider.mdx b/docs/deployment/provision/provider.mdx index c73b91cbdf..a6c63f4a4f 100644 --- a/docs/deployment/provision/provider.mdx +++ b/docs/deployment/provision/provider.mdx @@ -102,26 +102,18 @@ To see the full list of supported providers and their detailed configuration opt ### Update Provisioned Providers -#### Using KEEP_PROVIDERS +Keep uses a consistent process for updating provider configurations regardless of whether you use `KEEP_PROVIDERS` or `KEEP_PROVIDERS_DIRECTORY`. -Provider configurations can be updated dynamically by changing the `KEEP_PROVIDERS` environment variable. +#### Provisioning Process -On every restart, Keep reads this environment variable and determines which providers need to be added or removed. +When Keep starts or restarts, it follows these steps to manage provider configurations: -This process allows for flexible management of data sources without requiring manual intervention. By simply updating the `KEEP_PROVIDERS` variable and restarting the application, you can efficiently add new providers, remove existing ones, or modify their configurations. - -The high-level provisioning mechanism: -1. Keep reads the `KEEP_PROVIDERS` value. -2. Keep checks if there are any provisioned providers that are no longer in the `KEEP_PROVIDERS` value, and deletes them. -3. Keep installs all providers from the `KEEP_PROVIDERS` value. - -#### Using KEEP_PROVIDERS_DIRECTORY - -Provider configurations can be updated dynamically by changing the YAML files in the `KEEP_PROVIDERS_DIRECTORY` directory. - -On every restart, Keep reads the YAML files in the `KEEP_PROVIDERS_DIRECTORY` directory and determines which providers need to be added or removed. - -The high-level provisioning mechanism: -1. Keep reads the YAML files in the `KEEP_PROVIDERS_DIRECTORY` directory. -2. Keep checks if there are any provisioned providers that are no longer in the YAML files, and deletes them. -3. Keep installs all providers from the YAML files. +1. **Read Configurations**: Loads provider definitions from either the `KEEP_PROVIDERS` environment variable or YAML files in the `KEEP_PROVIDERS_DIRECTORY`. +2. **Calculate Configuration Hash**: Generates a hash of the current configurations to detect changes. +3. **Check for Changes**: Compares the new hash with the previously stored hash (in Redis or secret manager). +4. **Update When Changed**: If configurations have changed: + - Backup the current state for potential rollback + - Delete all existing provisioned providers + - Provision new providers with their deduplication rules + - If any errors occur during provisioning, automatically rollback to the previous state +5. **Skip When Unchanged**: If configurations haven't changed since the last startup, Keep skips the re-provisioning process to improve startup performance. From add8374e72005446b9c7403c60f1ed380cf88fa0 Mon Sep 17 00:00:00 2001 From: tuantran0910 Date: Sat, 5 Apr 2025 11:28:26 +0700 Subject: [PATCH 08/18] test: update deduplication rules and adjust provisioning tests --- tests/deduplication/test_deduplications.py | 34 ++++++++++++---------- tests/test_provisioning.py | 4 +-- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/tests/deduplication/test_deduplications.py b/tests/deduplication/test_deduplications.py index 1f367868d7..6500474e7c 100644 --- a/tests/deduplication/test_deduplications.py +++ b/tests/deduplication/test_deduplications.py @@ -10,8 +10,8 @@ from keep.api.core.db import get_last_alerts from keep.api.core.dependencies import SINGLE_TENANT_UUID -from keep.api.models.alert import DeduplicationRuleDto, AlertStatus -from keep.api.models.db.alert import AlertDeduplicationRule, AlertDeduplicationEvent, Alert +from keep.api.models.alert import AlertStatus +from keep.api.models.db.alert import Alert, AlertDeduplicationRule from keep.api.utils.enrichment_helpers import convert_db_alerts_to_dto_alerts from keep.providers.providers_factory import ProvidersFactory from tests.fixtures.client import client, setup_api_key, test_app # noqa @@ -359,7 +359,7 @@ def test_custom_deduplication_rule_behaviour(db_session, client, test_app): [ { "AUTH_TYPE": "NOAUTH", - "KEEP_PROVIDERS": '{"keepDatadog":{"type":"datadog","authentication":{"api_key":"1234","app_key": "1234"}}}', + "KEEP_PROVIDERS": '{"keepDatadogCustomRule":{"type":"datadog","authentication":{"api_key":"1234","app_key": "1234"}}}', }, ], indirect=True, @@ -432,7 +432,7 @@ def test_custom_deduplication_rule_2(db_session, client, test_app): [ { "AUTH_TYPE": "NOAUTH", - "KEEP_PROVIDERS": '{"keepDatadog":{"type":"datadog","authentication":{"api_key":"1234","app_key": "1234"}}}', + "KEEP_PROVIDERS": '{"keepDatadogUpdateRule":{"type":"datadog","authentication":{"api_key":"1234","app_key": "1234"}}}', }, ], indirect=True, @@ -557,7 +557,7 @@ def test_update_deduplication_rule_linked_provider(db_session, client, test_app) [ { "AUTH_TYPE": "NOAUTH", - "KEEP_PROVIDERS": '{"keepDatadog":{"type":"datadog","authentication":{"api_key":"1234","app_key": "1234"}}}', + "KEEP_PROVIDERS": '{"keepDatadogDeleteRule":{"type":"datadog","authentication":{"api_key":"1234","app_key": "1234"}}}', }, ], indirect=True, @@ -857,7 +857,9 @@ def test_full_deduplication_last_received(db_session, create_alert): db_session.exec(text("DELETE FROM alertdeduplicationrule")) dedup = AlertDeduplicationRule( name="Test Rule", - fingerprint_fields=["service",], + fingerprint_fields=[ + "service", + ], full_deduplication=True, ignore_fields=["fingerprint", "lastReceived", "id"], is_provisioned=True, @@ -879,30 +881,30 @@ def test_full_deduplication_last_received(db_session, create_alert): None, AlertStatus.FIRING, dt1, - { - "source": ["keep"], - "service": "service" - }, + {"source": ["keep"], "service": "service"}, ) assert db_session.query(Alert).count() == 1 alerts = get_last_alerts(SINGLE_TENANT_UUID) alerts_dto = convert_db_alerts_to_dto_alerts(alerts) - assert alerts_dto[0].lastReceived == dt1.astimezone(pytz.UTC).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" + assert ( + alerts_dto[0].lastReceived + == dt1.astimezone(pytz.UTC).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" + ) create_alert( None, AlertStatus.FIRING, dt2, - { - "source": ["keep"], - "service": "service" - }, + {"source": ["keep"], "service": "service"}, ) assert db_session.query(Alert).count() == 1 alerts = get_last_alerts(SINGLE_TENANT_UUID) alerts_dto = convert_db_alerts_to_dto_alerts(alerts) - assert alerts_dto[0].lastReceived == dt2.astimezone(pytz.UTC).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" + assert ( + alerts_dto[0].lastReceived + == dt2.astimezone(pytz.UTC).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" + ) diff --git a/tests/test_provisioning.py b/tests/test_provisioning.py index 0d4ccafd14..b527597016 100644 --- a/tests/test_provisioning.py +++ b/tests/test_provisioning.py @@ -191,12 +191,12 @@ def test_provision_provider(db_session, client, test_app): def test_reprovision_provider(monkeypatch, db_session, client, test_app): response = client.get("/providers", headers={"x-api-key": "someapikey"}) assert response.status_code == 200 - # 3 workflows and 3 provisioned workflows providers = response.json() provisioned_providers = [ p for p in providers.get("installed_providers") if p.get("provisioned") ] - assert len(provisioned_providers) == 2 + # Skip the re-provisioning when configurations are not changed + assert len(provisioned_providers) == 0 # Step 2: Change environment variables (simulating new provisioning) monkeypatch.setenv( From 7be8b47f0fa8b374ddc36d1d325f95e5b3aa7069 Mon Sep 17 00:00:00 2001 From: tuantran0910 Date: Sat, 5 Apr 2025 11:32:53 +0700 Subject: [PATCH 09/18] test: fix KEEP_PROVIDERS configuration for VictoriaMetrics in provisioning tests --- tests/test_provisioning.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/test_provisioning.py b/tests/test_provisioning.py index b527597016..c5ea522268 100644 --- a/tests/test_provisioning.py +++ b/tests/test_provisioning.py @@ -162,7 +162,7 @@ def test_reprovision_workflow(monkeypatch, db_session, client, test_app): [ { "AUTH_TYPE": "NOAUTH", - "KEEP_PROVIDERS": '{"keepVictoriaMetrics":{"type":"victoriametrics","authentication":{"VMAlertHost":"http://localhost","VMAlertPort": 1234}},"keepClickhouse1":{"type":"clickhouse","authentication":{"host":"http://localhost","port":1234,"username":"keep","password":"keep","database":"keep-db"}}}', + "KEEP_PROVIDERS": '{"keepVictoriaMetrics1":{"type":"victoriametrics","authentication":{"VMAlertHost":"http://localhost","VMAlertPort": 1234}},"keepClickhouse1":{"type":"clickhouse","authentication":{"host":"http://localhost","port":1234,"username":"keep","password":"keep","database":"keep-db"}}}', }, ], indirect=True, @@ -183,7 +183,7 @@ def test_provision_provider(db_session, client, test_app): [ { "AUTH_TYPE": "NOAUTH", - "KEEP_PROVIDERS": '{"keepVictoriaMetrics":{"type":"victoriametrics","authentication":{"VMAlertHost":"http://localhost","VMAlertPort": 1234}},"keepClickhouse1":{"type":"clickhouse","authentication":{"host":"http://localhost","port":1234,"username":"keep","password":"keep","database":"keep-db"}}}', + "KEEP_PROVIDERS": '{"keepVictoriaMetric2":{"type":"victoriametrics","authentication":{"VMAlertHost":"http://localhost","VMAlertPort": 1234}},"keepClickhouse1":{"type":"clickhouse","authentication":{"host":"http://localhost","port":1234,"username":"keep","password":"keep","database":"keep-db"}}}', }, ], indirect=True, @@ -195,8 +195,7 @@ def test_reprovision_provider(monkeypatch, db_session, client, test_app): provisioned_providers = [ p for p in providers.get("installed_providers") if p.get("provisioned") ] - # Skip the re-provisioning when configurations are not changed - assert len(provisioned_providers) == 0 + assert len(provisioned_providers) == 2 # Step 2: Change environment variables (simulating new provisioning) monkeypatch.setenv( From 301a9a9f539f27beb38afe8ab80a423e0caa09f8 Mon Sep 17 00:00:00 2001 From: tuantran0910 Date: Sun, 6 Apr 2025 14:46:14 +0700 Subject: [PATCH 10/18] refactor: simplify session management in ProvidersService by using existed_or_new_session --- keep/providers/providers_service.py | 51 ++++++++++++++--------------- 1 file changed, 24 insertions(+), 27 deletions(-) diff --git a/keep/providers/providers_service.py b/keep/providers/providers_service.py index 08d41c14c0..8d2b1bb147 100644 --- a/keep/providers/providers_service.py +++ b/keep/providers/providers_service.py @@ -18,6 +18,7 @@ from keep.api.core.config import config from keep.api.core.db import ( engine, + existed_or_new_session, get_all_provisioned_providers, get_provider_by_name, get_provider_logs, @@ -191,11 +192,6 @@ def install_provider( secret_value=json.dumps(config), ) - session_managed = False - if not session: - session = Session(engine) - session_managed = True - provider_model = Provider( id=provider_unique_id, tenant_id=tenant_id, @@ -209,30 +205,31 @@ def install_provider( provisioned=provisioned, pulling_enabled=pulling_enabled, ) - try: - session.add(provider_model) - if commit: - session.commit() - except IntegrityError as e: - if "FOREIGN KEY constraint" in str(e): - raise + + with existed_or_new_session(session) as session: try: - # if the provider is already installed, delete the secret - logger.warning( - "Provider already installed, deleting secret", - extra={"error": str(e)}, - ) - secret_manager.delete_secret( - secret_name=secret_name, + session.add(provider_model) + if commit: + session.commit() + except IntegrityError as e: + if "FOREIGN KEY constraint" in str(e): + raise + try: + # if the provider is already installed, delete the secret + logger.warning( + "Provider already installed, deleting secret", + extra={"error": str(e)}, + ) + secret_manager.delete_secret( + secret_name=secret_name, + ) + logger.warning("Secret deleted") + except Exception: + logger.exception("Failed to delete the secret") + pass + raise HTTPException( + status_code=409, detail="Provider already installed" ) - logger.warning("Secret deleted") - except Exception: - logger.exception("Failed to delete the secret") - pass - raise HTTPException(status_code=409, detail="Provider already installed") - finally: - if session_managed: - session.close() if provider_model.consumer: try: From 4b33436e3f87a7c8beff4136ea3e4f5624d92b36 Mon Sep 17 00:00:00 2001 From: tuantran0910 Date: Sat, 12 Apr 2025 20:04:39 +0700 Subject: [PATCH 11/18] refactor: alter logic to provision providers --- keep/api/consts.py | 3 - keep/providers/providers_service.py | 521 ++++++++----------- tests/test_providers_provisioning_caching.py | 303 ----------- 3 files changed, 205 insertions(+), 622 deletions(-) delete mode 100644 tests/test_providers_provisioning_caching.py diff --git a/keep/api/consts.py b/keep/api/consts.py index 0c90c5f293..a3c82c66f4 100644 --- a/keep/api/consts.py +++ b/keep/api/consts.py @@ -41,9 +41,6 @@ KEEP_ARQ_QUEUE_WORKFLOWS = "workflows" REDIS = os.environ.get("REDIS", "false") == "true" -REDIS_HOST = os.environ.get("REDIS_HOST", "localhost") -REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379)) -REDIS_DB = int(os.environ.get("REDIS_DB", 0)) if REDIS: KEEP_ARQ_TASK_POOL = os.environ.get("KEEP_ARQ_TASK_POOL", KEEP_ARQ_TASK_POOL_ALL) diff --git a/keep/providers/providers_service.py b/keep/providers/providers_service.py index 8d2b1bb147..e4b77bb576 100644 --- a/keep/providers/providers_service.py +++ b/keep/providers/providers_service.py @@ -1,4 +1,3 @@ -import hashlib import json import logging import os @@ -6,7 +5,6 @@ import uuid from typing import Any, Dict, List, Optional -import redis from fastapi import HTTPException from sqlalchemy.exc import IntegrityError from sqlmodel import Session, select @@ -14,7 +12,6 @@ from keep.api.alert_deduplicator.deduplication_rules_provisioning import ( provision_deduplication_rules, ) -from keep.api.consts import REDIS, REDIS_DB, REDIS_HOST, REDIS_PORT from keep.api.core.config import config from keep.api.core.db import ( engine, @@ -154,8 +151,6 @@ def install_provider( provisioned: bool = False, validate_scopes: bool = True, pulling_enabled: bool = True, - session: Optional[Session] = None, - commit: bool = True, ) -> Dict[str, Any]: provider_unique_id = uuid.uuid4().hex logger.info( @@ -192,59 +187,41 @@ def install_provider( secret_value=json.dumps(config), ) - provider_model = Provider( - id=provider_unique_id, - tenant_id=tenant_id, - name=provider_name, - type=provider_type, - installed_by=installed_by, - installation_time=time.time(), - configuration_key=secret_name, - validatedScopes=validated_scopes, - consumer=provider.is_consumer, - provisioned=provisioned, - pulling_enabled=pulling_enabled, - ) - - with existed_or_new_session(session) as session: + with Session(engine) as session: + provider_model = Provider( + id=provider_unique_id, + tenant_id=tenant_id, + name=provider_name, + type=provider_type, + installed_by=installed_by, + installation_time=time.time(), + configuration_key=secret_name, + validatedScopes=validated_scopes, + consumer=provider.is_consumer, + provisioned=provisioned, + pulling_enabled=pulling_enabled, + ) try: session.add(provider_model) - if commit: - session.commit() + session.commit() except IntegrityError as e: if "FOREIGN KEY constraint" in str(e): raise + + if provider_model.consumer: try: - # if the provider is already installed, delete the secret - logger.warning( - "Provider already installed, deleting secret", - extra={"error": str(e)}, - ) - secret_manager.delete_secret( - secret_name=secret_name, - ) - logger.warning("Secret deleted") + event_subscriber = EventSubscriber.get_instance() + event_subscriber.add_consumer(provider) except Exception: - logger.exception("Failed to delete the secret") - pass - raise HTTPException( - status_code=409, detail="Provider already installed" - ) + logger.exception("Failed to register provider as a consumer") - if provider_model.consumer: - try: - event_subscriber = EventSubscriber.get_instance() - event_subscriber.add_consumer(provider) - except Exception: - logger.exception("Failed to register provider as a consumer") - - return { - "provider": provider_model, - "type": provider_type, - "id": provider_unique_id, - "details": config, - "validatedScopes": validated_scopes, - } + return { + "provider": provider_model, + "type": provider_type, + "id": provider_unique_id, + "details": config, + "validatedScopes": validated_scopes, + } @staticmethod def update_provider( @@ -252,56 +229,106 @@ def update_provider( provider_id: str, provider_info: Dict[str, Any], updated_by: str, - session: Session, + session: Optional[Session] = None, ) -> Dict[str, Any]: - provider = session.exec( - select(Provider).where( - (Provider.tenant_id == tenant_id) & (Provider.id == provider_id) - ) - ).one_or_none() + with existed_or_new_session(session) as session: + provider = session.exec( + select(Provider).where( + (Provider.tenant_id == tenant_id) & (Provider.id == provider_id) + ) + ).one_or_none() - if not provider: - raise HTTPException(404, detail="Provider not found") + if not provider: + raise HTTPException(404, detail="Provider not found") - if provider.provisioned: - raise HTTPException(403, detail="Cannot update a provisioned provider") + if provider.provisioned: + raise HTTPException(403, detail="Cannot update a provisioned provider") - pulling_enabled = provider_info.pop("pulling_enabled", True) + pulling_enabled = provider_info.pop("pulling_enabled", True) - # if pulling_enabled is "true" or "false" cast it to boolean - if isinstance(pulling_enabled, str): - pulling_enabled = pulling_enabled.lower() == "true" + # if pulling_enabled is "true" or "false" cast it to boolean + if isinstance(pulling_enabled, str): + pulling_enabled = pulling_enabled.lower() == "true" - provider_config = { - "authentication": provider_info, - "name": provider.name, - } + provider_config = { + "authentication": provider_info, + "name": provider.name, + } - context_manager = ContextManager(tenant_id=tenant_id) - try: - provider_instance = ProvidersFactory.get_provider( - context_manager, provider_id, provider.type, provider_config + context_manager = ContextManager(tenant_id=tenant_id) + try: + provider_instance = ProvidersFactory.get_provider( + context_manager, provider_id, provider.type, provider_config + ) + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + validated_scopes = provider_instance.validate_scopes() + + secret_manager = SecretManagerFactory.get_secret_manager(context_manager) + secret_manager.write_secret( + secret_name=provider.configuration_key, + secret_value=json.dumps(provider_config), ) - except Exception as e: - raise HTTPException(status_code=400, detail=str(e)) - validated_scopes = provider_instance.validate_scopes() + provider.installed_by = updated_by + provider.validatedScopes = validated_scopes + provider.pulling_enabled = pulling_enabled + session.commit() - secret_manager = SecretManagerFactory.get_secret_manager(context_manager) - secret_manager.write_secret( - secret_name=provider.configuration_key, - secret_value=json.dumps(provider_config), - ) + return { + "provider": provider, + "details": provider_config, + "validatedScopes": validated_scopes, + } - provider.installed_by = updated_by - provider.validatedScopes = validated_scopes - provider.pulling_enabled = pulling_enabled - session.commit() + @staticmethod + def upsert_provider( + tenant_id: str, + provider_name: str, + provider_type: str, + provider_config: Dict[str, Any], + provisioned: bool = False, + validate_scopes: bool = True, + provisioned_providers_names: List[str] = [], + ) -> Dict[str, Any]: + installed_provider_info = None + try: + # First check if the provider is already installed + # If it is, update it, otherwise install it + if provider_name in provisioned_providers_names: + logger.info( + f"Provider {provider_name} already provisioned, updating..." + ) + provider = get_provider_by_name(tenant_id, provider_name) + installed_provider_info = ProvidersService.update_provider( + tenant_id=tenant_id, + provider_id=provider.id, + provider_info=provider_config, + updated_by="system", + ) + logger.info(f"Provider {provider_name} updated successfully") + else: + logger.info(f"Provider {provider_name} not existing, installing...") + installed_provider_info = ProvidersService.install_provider( + tenant_id=tenant_id, + installed_by="system", + provider_id=provider_type, + provider_name=provider_name, + provider_type=provider_type, + provider_config=provider_config, + provisioned=provisioned, + validate_scopes=validate_scopes, + ) + logger.info(f"Provider {provider_name} provisioned successfully") + except Exception as e: + logger.error( + "Error provisioning provider from env var", + extra={"exception": e}, + ) + raise HTTPException(status_code=400, detail=str(e)) - return { - "details": provider_config, - "validatedScopes": validated_scopes, - } + return installed_provider_info @staticmethod def delete_provider( @@ -309,7 +336,6 @@ def delete_provider( provider_id: str, session: Session, allow_provisioned=False, - commit: bool = True, ): provider_model: Provider = session.exec( select(Provider).where( @@ -355,8 +381,7 @@ def delete_provider( logger.exception(msg="Provider deleted but failed to clean up provider") session.delete(provider_model) - if commit: - session.commit() + session.commit() @staticmethod def validate_provider_scopes( @@ -423,102 +448,6 @@ def provision_provider_deduplication_rules( provider=provider, ) - @staticmethod - def write_provisioned_hash(tenant_id: str, hash_value: str): - """ - Write the provisioned hash to Redis or secret manager. - - Args: - tenant_id (str): The tenant ID. - hash_value (str): The hash value to write. - """ - if REDIS: - r = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB) - r.set(f"{tenant_id}_providers_hash", hash_value) - logger.info(f"Provisioned hash for tenant {tenant_id} written to Redis!") - else: - context_manager = ContextManager(tenant_id=tenant_id) - secret_manager = SecretManagerFactory.get_secret_manager(context_manager) - secret_manager.write_secret( - secret_name=f"{tenant_id}_providers_hash", - secret_value=hash_value, - ) - logger.info( - f"Provisioned hash for tenant {tenant_id} written to secret manager!" - ) - - @staticmethod - def get_provisioned_hash(tenant_id: str) -> Optional[str]: - """ - Get the provisioned hash from Redis or secret manager. - - Args: - tenant_id (str): The tenant ID. - - Returns: - Optional[str]: The provisioned hash, or None if not found. - """ - previous_hash = None - if REDIS: - try: - with redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB) as r: - previous_hash = r.get(f"{tenant_id}_providers_hash") - if isinstance(previous_hash, bytes): - previous_hash = previous_hash.decode("utf-8").strip() - logger.info( - f"Provisioned hash for tenant {tenant_id}: {previous_hash or 'Not found'}" - ) - except redis.RedisError as e: - logger.warning(f"Redis error for tenant {tenant_id}: {e}") - - if previous_hash is None: - try: - context_manager = ContextManager(tenant_id=tenant_id) - secret_manager = SecretManagerFactory.get_secret_manager( - context_manager - ) - previous_hash = secret_manager.read_secret( - f"{tenant_id}_providers_hash" - ) - logger.info( - f"Provisioned hash for tenant {tenant_id} read from secret manager." - ) - except Exception as e: - logger.warning( - f"Failed to read hash from secret manager for tenant {tenant_id}: {e}" - ) - - return previous_hash if previous_hash else None - - @staticmethod - def calculate_provider_hash( - provisioned_providers_dir: Optional[str] = None, - provisioned_providers_json: Optional[str] = None, - ) -> str: - """ - Calculate the hash of the provider configurations. - - Args: - provisioned_providers_dir (Optional[str]): Directory containing provider YAML files. - provisioned_providers_json (Optional[str]): JSON string of provider configurations. - - Returns: - str: SHA256 hash of the provider configurations. - """ - if provisioned_providers_json: - providers_data = provisioned_providers_json - elif provisioned_providers_dir: - providers_data = [] - for file in os.listdir(provisioned_providers_dir): - if file.endswith((".yaml", ".yml")): - provider_path = os.path.join(provisioned_providers_dir, file) - with open(provider_path, "r") as yaml_file: - providers_data.append(yaml_file.read()) - else: - providers_data = "" # No providers to provision - - return hashlib.sha256(json.dumps(providers_data).encode("utf-8")).hexdigest() - @staticmethod def provision_providers(tenant_id: str): """ @@ -549,6 +478,10 @@ def provision_providers(tenant_id: str): # Get all existing provisioned providers provisioned_providers = get_all_provisioned_providers(tenant_id) + provisioned_providers_names = [ + provider.name for provider in provisioned_providers + ] + incoming_providers_names = set() if not (provisioned_providers_dir or provisioned_providers_json): if provisioned_providers: @@ -559,89 +492,90 @@ def provision_providers(tenant_id: str): logger.info("No providers for provisioning found. Nothing to do.") return - # Calculate the hash of the provider configurations - providers_hash = ProvidersService.calculate_provider_hash( - provisioned_providers_dir, provisioned_providers_json - ) + try: + ### Provisioning from env var + if provisioned_providers_json is not None: + # Avoid circular import + from keep.parser.parser import Parser - # Get the previous hash from Redis or secret manager - previous_hash = ProvidersService.get_provisioned_hash(tenant_id) - if providers_hash == previous_hash: - logger.info( - "Provider configurations have not changed. Skipping provisioning." - ) - return - else: - logger.info("Provider configurations have changed. Provisioning providers.") + parser = Parser() + context_manager = ContextManager(tenant_id=tenant_id) + parser._parse_providers_from_env(context_manager) + env_providers = context_manager.providers_context - # Do all the provisioning within a transaction - session = Session(engine) - try: - with session.begin(): - ### We do delete all the provisioned providers and begin provisioning from the beginning. - logger.info( - f"Deleting all provisioned providers for tenant {tenant_id}" - ) - for provisioned_provider in provisioned_providers: + for provider_name, provider_config in env_providers.items(): try: - logger.info(f"Deleting provider {provisioned_provider.name}") - ProvidersService.delete_provider( - tenant_id, - provisioned_provider.id, - session, - allow_provisioned=True, - commit=False, + provider_name = provider_config["name"] + provider_type = provider_config["type"] + provider_config = provider_config["authentication"] + + # Perform upsert operation for the provider + installed_provider_info = ProvidersService.upsert_provider( + tenant_id=tenant_id, + provider_name=provider_name, + provider_type=provider_type, + provider_config=provider_config, + provisioned=True, + validate_scopes=False, + provisioned_providers_names=provisioned_providers_names, ) - logger.info(f"Provider {provisioned_provider.name} deleted") except Exception as e: - logger.exception( - "Failed to delete provisioned provider", + logger.error( + "Error provisioning provider from env var", extra={"exception": e}, ) continue - # Flush the session to ensure all deletions are committed - session.flush() + provider = installed_provider_info["provider"] + incoming_providers_names.add(provider_name) - ### Provisioning from env var - if provisioned_providers_json is not None: - # Avoid circular import - from keep.parser.parser import Parser + # Configure deduplication rules + deduplication_rules = provider_config.get("deduplication_rules", {}) + if deduplication_rules: + logger.info( + f"Provisioning deduplication rules for provider {provider_name}" + ) + ProvidersService.provision_provider_deduplication_rules( + tenant_id=tenant_id, + provider=provider, + deduplication_rules=deduplication_rules, + ) - parser = Parser() - context_manager = ContextManager(tenant_id=tenant_id) - parser._parse_providers_from_env(context_manager) - env_providers = context_manager.providers_context + ### Provisioning from the directory + if provisioned_providers_dir is not None: + for file in os.listdir(provisioned_providers_dir): + if file.endswith((".yaml", ".yml")): + logger.info(f"Provisioning provider from {file}") + provider_path = os.path.join(provisioned_providers_dir, file) - for provider_name, provider_config in env_providers.items(): - # We skip checking if the provider is already installed, as it will skip the new configurations - # and we want to update the provisioned provider with the new configuration - logger.info(f"Provisioning provider {provider_name}") try: - installed_provider_info = ProvidersService.install_provider( + with open(provider_path, "r") as yaml_file: + provider_yaml = cyaml.safe_load(yaml_file.read()) + provider_name = provider_yaml["name"] + provider_type = provider_yaml["type"] + provider_config = provider_yaml.get( + "authentication", {} + ) + + # Perform upsert operation for the provider + installed_provider_info = ProvidersService.upsert_provider( tenant_id=tenant_id, - installed_by="system", - provider_id=provider_config["type"], provider_name=provider_name, - provider_type=provider_config["type"], - provider_config=provider_config["authentication"], + provider_type=provider_type, + provider_config=provider_config, provisioned=True, validate_scopes=False, - session=session, - commit=False, - ) - provider = installed_provider_info["provider"] - logger.info( - f"Provider {provider_name} provisioned successfully" + provisioned_providers_names=provisioned_providers_names, ) except Exception as e: logger.error( - "Error provisioning provider from env var", + "Error provisioning provider from directory", extra={"exception": e}, ) + continue - # Flush the provider so that we can provision its deduplication rules - session.flush() + provider = installed_provider_info["provider"] + incoming_providers_names.add(provider_name) # Configure deduplication rules deduplication_rules = provider_config.get( @@ -657,78 +591,33 @@ def provision_providers(tenant_id: str): deduplication_rules=deduplication_rules, ) - ### Provisioning from the directory - if provisioned_providers_dir is not None: - for file in os.listdir(provisioned_providers_dir): - if file.endswith((".yaml", ".yml")): - logger.info(f"Provisioning provider from {file}") - provider_path = os.path.join( - provisioned_providers_dir, file - ) + # Delete providers that are not in the incoming list + for provider in provisioned_providers: + if provider.name not in incoming_providers_names: + try: + logger.info( + f"Provider {provider.name} not found in incoming provisioned providers, deleting..." + ) + ProvidersService.delete_provider( + tenant_id=tenant_id, + provider_id=provider.id, + session=None, + allow_provisioned=True, + ) + logger.info(f"Provider {provider.name} deleted successfully") + except Exception as e: + logger.error( + f"Error deleting provider {provider.name}", + extra={"exception": e}, + ) + continue - try: - with open(provider_path, "r") as yaml_file: - provider_yaml = cyaml.safe_load(yaml_file.read()) - provider_name = provider_yaml["name"] - provider_type = provider_yaml["type"] - provider_config = provider_yaml.get( - "authentication", {} - ) - - # We skip checking if the provider is already installed, as it will skip the new configurations - # and we want to update the provisioned provider with the new configuration - logger.info(f"Installing provider {provider_name}") - installed_provider_info = ( - ProvidersService.install_provider( - tenant_id=tenant_id, - installed_by="system", - provider_id=provider_type, - provider_name=provider_name, - provider_type=provider_type, - provider_config=provider_config, - provisioned=True, - validate_scopes=False, - session=session, - commit=False, - ) - ) - provider = installed_provider_info["provider"] - logger.info( - f"Provider {provider_name} provisioned successfully" - ) - - # Flush the provider so that we can provision its deduplication rules - session.flush() - - # Configure deduplication rules - deduplication_rules = provider_yaml.get( - "deduplication_rules", {} - ) - if deduplication_rules: - logger.info( - f"Provisioning deduplication rules for provider {provider_name}" - ) - ProvidersService.provision_provider_deduplication_rules( - tenant_id=tenant_id, - provider=provider, - deduplication_rules=deduplication_rules, - ) - except Exception as e: - logger.error( - "Error provisioning provider from directory", - extra={"exception": e}, - ) - continue + logger.info( + "Provisioning completed successfully. Provisioned providers: %s", + incoming_providers_names, + ) except Exception as e: logger.error("Provisioning failed, rolling back", extra={"exception": e}) - session.rollback() - finally: - # Store the hash in Redis or secret manager - try: - ProvidersService.write_provisioned_hash(tenant_id, providers_hash) - except Exception as e: - logger.warning(f"Failed to store hash: {e}") - session.close() @staticmethod def get_provider_logs( diff --git a/tests/test_providers_provisioning_caching.py b/tests/test_providers_provisioning_caching.py deleted file mode 100644 index d6df0a635c..0000000000 --- a/tests/test_providers_provisioning_caching.py +++ /dev/null @@ -1,303 +0,0 @@ -import hashlib -import json -from unittest.mock import MagicMock, mock_open, patch - -import pytest -import redis - -from keep.api.consts import REDIS_DB, REDIS_HOST, REDIS_PORT -from keep.providers.providers_service import ProvidersService - - -@pytest.fixture -def tenant_id(): - return "test_tenant" - - -@pytest.fixture -def hash_value(): - return "test_hash" - - -def test_write_provisioned_hash_redis(tenant_id, hash_value): - """Test writing hash to Redis when Redis is enabled""" - with patch("keep.providers.providers_service.REDIS", True), patch( - "redis.Redis" - ) as mock_redis: - mock_redis_instance = MagicMock() - mock_redis.return_value = mock_redis_instance - - ProvidersService.write_provisioned_hash(tenant_id, hash_value) - - mock_redis.assert_called_once_with( - host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB - ) - mock_redis_instance.set.assert_called_once_with( - f"{tenant_id}_providers_hash", hash_value - ) - - -def test_write_provisioned_hash_secret_manager(tenant_id, hash_value): - """Test writing hash to secret manager when Redis is disabled""" - mock_secret_manager = MagicMock() - - with patch("keep.providers.providers_service.REDIS", False), patch( - "keep.providers.providers_service.SecretManagerFactory.get_secret_manager" - ) as mock_get_secret_manager: - mock_get_secret_manager.return_value = mock_secret_manager - - ProvidersService.write_provisioned_hash(tenant_id, hash_value) - - mock_secret_manager.write_secret.assert_called_once_with( - secret_name=f"{tenant_id}_providers_hash", secret_value=hash_value - ) - - -def test_get_provisioned_hash_redis_success(tenant_id, hash_value): - """Test getting hash from Redis successfully""" - with patch("keep.providers.providers_service.REDIS", True), patch( - "redis.Redis" - ) as mock_redis: - mock_redis_instance = MagicMock() - mock_redis_instance.get.return_value = hash_value.encode() - mock_redis.return_value.__enter__.return_value = mock_redis_instance - - result = ProvidersService.get_provisioned_hash(tenant_id) - - assert result == hash_value - mock_redis_instance.get.assert_called_once_with(f"{tenant_id}_providers_hash") - - -def test_get_provisioned_hash_redis_error(tenant_id, hash_value): - """Test falling back to secret manager when Redis fails""" - mock_secret_manager = MagicMock() - mock_secret_manager.read_secret.return_value = hash_value - - with patch("keep.providers.providers_service.REDIS", True), patch( - "redis.Redis" - ) as mock_redis, patch( - "keep.providers.providers_service.SecretManagerFactory.get_secret_manager" - ) as mock_get_secret_manager: - mock_redis.return_value.__enter__.side_effect = redis.RedisError("Test error") - mock_get_secret_manager.return_value = mock_secret_manager - - result = ProvidersService.get_provisioned_hash(tenant_id) - - assert result == hash_value - mock_secret_manager.read_secret.assert_called_once_with( - f"{tenant_id}_providers_hash" - ) - - -def test_get_provisioned_hash_secret_manager_success(tenant_id, hash_value): - """Test getting hash from secret manager successfully""" - mock_secret_manager = MagicMock() - mock_secret_manager.read_secret.return_value = hash_value - - with patch("keep.providers.providers_service.REDIS", False), patch( - "keep.providers.providers_service.SecretManagerFactory.get_secret_manager" - ) as mock_get_secret_manager: - mock_get_secret_manager.return_value = mock_secret_manager - - result = ProvidersService.get_provisioned_hash(tenant_id) - - assert result == hash_value - mock_secret_manager.read_secret.assert_called_once_with( - f"{tenant_id}_providers_hash" - ) - - -def test_get_provisioned_hash_secret_manager_error(tenant_id): - """Test handling secret manager error""" - mock_secret_manager = MagicMock() - mock_secret_manager.read_secret.side_effect = Exception("Secret not found") - - with patch("keep.providers.providers_service.REDIS", False), patch( - "keep.providers.providers_service.SecretManagerFactory.get_secret_manager" - ) as mock_get_secret_manager: - mock_get_secret_manager.return_value = mock_secret_manager - - result = ProvidersService.get_provisioned_hash(tenant_id) - - assert result is None - - -def test_calculate_provider_hash_json(): - """Test calculating hash from JSON input""" - json_input = '{"provider": "test"}' - expected_hash = hashlib.sha256(json.dumps(json_input).encode("utf-8")).hexdigest() - - result = ProvidersService.calculate_provider_hash( - provisioned_providers_json=json_input - ) - - assert result == expected_hash - - -def test_calculate_provider_hash_directory(): - """Test calculating hash from directory input""" - test_dir = "/test/providers" - yaml_content = "provider: test" - - with patch("os.listdir") as mock_listdir, patch("os.path.join") as mock_join, patch( - "builtins.open", mock_open(read_data=yaml_content) - ): - mock_listdir.return_value = ["provider1.yaml", "provider2.yml", "other.txt"] - mock_join.side_effect = lambda *args: f"{args[0]}/{args[1]}" - - result = ProvidersService.calculate_provider_hash( - provisioned_providers_dir=test_dir - ) - - expected_data = [yaml_content, yaml_content] # Two YAML files - expected_hash = hashlib.sha256( - json.dumps(expected_data).encode("utf-8") - ).hexdigest() - - assert result == expected_hash - assert mock_listdir.call_count == 1 - assert mock_join.call_count == 2 - - -def test_calculate_provider_hash_no_input(): - """Test calculating hash with no input""" - result = ProvidersService.calculate_provider_hash() - expected_hash = hashlib.sha256(json.dumps("").encode("utf-8")).hexdigest() - - assert result == expected_hash - - -def test_write_provisioned_hash_redis_enabled(tenant_id, hash_value): - """Test writing hash to Redis when Redis is enabled""" - with patch("keep.providers.providers_service.REDIS", True), patch( - "redis.Redis" - ) as mock_redis: - mock_redis_instance = MagicMock() - mock_redis.return_value = mock_redis_instance - - ProvidersService.write_provisioned_hash(tenant_id, hash_value) - - mock_redis.assert_called_once_with( - host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB - ) - mock_redis_instance.set.assert_called_once_with( - f"{tenant_id}_providers_hash", hash_value - ) - - -def test_write_provisioned_hash_redis_disabled_secret_manager(tenant_id, hash_value): - """Test writing hash to secret manager when Redis is disabled""" - mock_secret_manager = MagicMock() - - with patch("keep.providers.providers_service.REDIS", False), patch( - "keep.providers.providers_service.SecretManagerFactory.get_secret_manager" - ) as mock_get_secret_manager: - mock_get_secret_manager.return_value = mock_secret_manager - - ProvidersService.write_provisioned_hash(tenant_id, hash_value) - - mock_secret_manager.write_secret.assert_called_once_with( - secret_name=f"{tenant_id}_providers_hash", secret_value=hash_value - ) - - -def test_get_provisioned_hash_redis_enabled_success(tenant_id, hash_value): - """Test getting hash from Redis successfully when Redis is enabled""" - with patch("keep.providers.providers_service.REDIS", True), patch( - "redis.Redis" - ) as mock_redis: - mock_redis_instance = MagicMock() - mock_redis_instance.get.return_value = hash_value.encode() - mock_redis.return_value.__enter__.return_value = mock_redis_instance - - result = ProvidersService.get_provisioned_hash(tenant_id) - - assert result == hash_value - mock_redis_instance.get.assert_called_once_with(f"{tenant_id}_providers_hash") - - -def test_get_provisioned_hash_redis_enabled_none_value(tenant_id): - """Test getting None from Redis when Redis is enabled but no value exists""" - with patch("keep.providers.providers_service.REDIS", True), patch( - "redis.Redis" - ) as mock_redis: - # Mock Redis returning None - mock_redis_instance = MagicMock() - mock_redis_instance.get.return_value = None - mock_redis.return_value.__enter__.return_value = mock_redis_instance - - result = ProvidersService.get_provisioned_hash(tenant_id) - - assert result is None - mock_redis_instance.get.assert_called_once_with(f"{tenant_id}_providers_hash") - - -def test_get_provisioned_hash_redis_preferred(tenant_id, hash_value): - """Test that Redis is preferred over secret manager when Redis works""" - with patch("keep.providers.providers_service.REDIS", True), patch( - "redis.Redis" - ) as mock_redis, patch( - "keep.providers.providers_service.SecretManagerFactory.get_secret_manager" - ) as mock_get_secret_manager: - mock_redis_instance = MagicMock() - mock_redis_instance.get.return_value = hash_value.encode() - mock_redis.return_value.__enter__.return_value = mock_redis_instance - - result = ProvidersService.get_provisioned_hash(tenant_id) - - assert result == hash_value - # Should not try to use secret manager when Redis works - mock_get_secret_manager.assert_not_called() - - -def test_get_provisioned_hash_redis_enabled_byte_decoding(tenant_id): - """Test proper decoding of bytes from Redis""" - encoded_hash = b"test_hash_with_whitespace \n" - expected_hash = "test_hash_with_whitespace" - - with patch("keep.providers.providers_service.REDIS", True), patch( - "redis.Redis" - ) as mock_redis: - mock_redis_instance = MagicMock() - mock_redis_instance.get.return_value = encoded_hash - mock_redis.return_value.__enter__.return_value = mock_redis_instance - - result = ProvidersService.get_provisioned_hash(tenant_id) - - assert result == expected_hash - mock_redis_instance.get.assert_called_once_with(f"{tenant_id}_providers_hash") - - -def test_calculate_provider_hash_consistency(tenant_id): - """Test that hash calculation is consistent for the same input""" - - # Test with JSON input - json_input_1 = '{"provider": "test"}' - json_input_2 = '{"provider": "test"}' - - hash_1 = ProvidersService.calculate_provider_hash( - provisioned_providers_json=json_input_1 - ) - hash_2 = ProvidersService.calculate_provider_hash( - provisioned_providers_json=json_input_2 - ) - - assert hash_1 == hash_2 - - # Test with directory input - yaml_content = "provider: test" - - with patch("os.listdir") as mock_listdir, patch("os.path.join") as mock_join, patch( - "builtins.open", mock_open(read_data=yaml_content) - ): - mock_listdir.return_value = ["provider1.yaml"] - mock_join.side_effect = lambda *args: f"{args[0]}/{args[1]}" - - hash_3 = ProvidersService.calculate_provider_hash( - provisioned_providers_dir="/test/dir" - ) - hash_4 = ProvidersService.calculate_provider_hash( - provisioned_providers_dir="/test/dir" - ) - - assert hash_3 == hash_4 From 68e4cdff2da82b04996dda7c934f4bba2494ad09 Mon Sep 17 00:00:00 2001 From: tuantran0910 Date: Sat, 12 Apr 2025 22:52:09 +0700 Subject: [PATCH 12/18] refactor: enhance deduplication rules provisioning logic --- .../deduplication_rules_provisioning.py | 26 +- keep/api/core/db.py | 43 +++- keep/providers/providers_service.py | 231 +++++++++++------- 3 files changed, 200 insertions(+), 100 deletions(-) diff --git a/keep/api/alert_deduplicator/deduplication_rules_provisioning.py b/keep/api/alert_deduplicator/deduplication_rules_provisioning.py index 215c50a1f8..d1741e4479 100644 --- a/keep/api/alert_deduplicator/deduplication_rules_provisioning.py +++ b/keep/api/alert_deduplicator/deduplication_rules_provisioning.py @@ -21,9 +21,11 @@ def provision_deduplication_rules( tenant_id (str): The ID of the tenant for which deduplication rules are being provisioned. provider (Provider): The provider for which the deduplication rules are being provisioned. """ - enrich_with_providers_info(deduplication_rules, provider) + enrich_with_provider_info(deduplication_rules, provider) - all_deduplication_rules_from_db = db.get_all_deduplication_rules(tenant_id) + all_deduplication_rules_from_db = db.get_all_deduplication_rules_by_provider( + tenant_id, provider.id, provider.type + ) provisioned_deduplication_rules = [ rule for rule in all_deduplication_rules_from_db if rule.is_provisioned ] @@ -89,8 +91,26 @@ def provision_deduplication_rules( is_provisioned=True, ) + logger.info( + "Provisioned deduplication rules %s successfully", + deduplication_rule_name, + ) + + # Delete provisioned deduplication rules that are not provisioned anymore + for rule_name, rule in provisioned_deduplication_rules_from_db_dict.items(): + if rule_name not in deduplication_rules: + logger.info( + "Deduplication rule with name '%s' no longer in configuration, deleting from DB", + rule_name, + ) + db.delete_deduplication_rule(rule_id=str(rule.id), tenant_id=tenant_id) + logger.info( + "Deleted deduplication rule %s successfully", + rule_name, + ) + -def enrich_with_providers_info(deduplication_rules: dict[str, any], provider: Provider): +def enrich_with_provider_info(deduplication_rules: dict[str, any], provider: Provider): """ Enriches passed deduplication rules with provider ID and type information. diff --git a/keep/api/core/db.py b/keep/api/core/db.py index 8595d22972..3c2ee18fbd 100644 --- a/keep/api/core/db.py +++ b/keep/api/core/db.py @@ -40,7 +40,7 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.dialects.sqlite import insert as sqlite_insert from sqlalchemy.exc import IntegrityError, OperationalError -from sqlalchemy.orm import joinedload, subqueryload, foreign +from sqlalchemy.orm import foreign, joinedload, subqueryload from sqlalchemy.orm.exc import StaleDataError from sqlalchemy.sql import exists, expression from sqlmodel import Session, SQLModel, col, or_, select, text @@ -2242,6 +2242,18 @@ def get_all_deduplication_rules(tenant_id): return rules +def get_all_deduplication_rules_by_provider(tenant_id, provider_id, provider_type): + with Session(engine) as session: + rules = session.exec( + select(AlertDeduplicationRule).where( + AlertDeduplicationRule.tenant_id == tenant_id, + AlertDeduplicationRule.provider_id == provider_id, + AlertDeduplicationRule.provider_type == provider_type, + ) + ).all() + return rules + + def get_deduplication_rule_by_id(tenant_id, rule_id: str): rule_uuid = __convert_to_uuid(rule_id) if not rule_uuid: @@ -3679,18 +3691,23 @@ def get_incident_by_id( if isinstance(incident_id, str): incident_id = __convert_to_uuid(incident_id, should_raise=True) with existed_or_new_session(session) as session: - query = session.query( - Incident, - AlertEnrichment, - ).outerjoin( - AlertEnrichment, - and_( - Incident.tenant_id == AlertEnrichment.tenant_id, - cast(col(Incident.id), String) == foreign(AlertEnrichment.alert_fingerprint), - ), - ).filter( - Incident.tenant_id == tenant_id, - Incident.id == incident_id, + query = ( + session.query( + Incident, + AlertEnrichment, + ) + .outerjoin( + AlertEnrichment, + and_( + Incident.tenant_id == AlertEnrichment.tenant_id, + cast(col(Incident.id), String) + == foreign(AlertEnrichment.alert_fingerprint), + ), + ) + .filter( + Incident.tenant_id == tenant_id, + Incident.id == incident_id, + ) ) incident_with_enrichments = query.first() if incident_with_enrichments: diff --git a/keep/providers/providers_service.py b/keep/providers/providers_service.py index 66aae8aa4a..0cbf3f3c4b 100644 --- a/keep/providers/providers_service.py +++ b/keep/providers/providers_service.py @@ -14,8 +14,10 @@ ) from keep.api.core.config import config from keep.api.core.db import ( + delete_deduplication_rule, engine, existed_or_new_session, + get_all_deduplication_rules_by_provider, get_all_provisioned_providers, get_provider_by_name, get_provider_logs, @@ -231,6 +233,7 @@ def update_provider( provider_info: Dict[str, Any], updated_by: str, session: Optional[Session] = None, + allow_provisioned: bool = False, ) -> Dict[str, Any]: with existed_or_new_session(session) as session: provider = session.exec( @@ -242,7 +245,7 @@ def update_provider( if not provider: raise HTTPException(404, detail="Provider not found") - if provider.provisioned: + if provider.provisioned and not allow_provisioned: raise HTTPException(403, detail="Cannot update a provisioned provider") pulling_enabled = provider_info.pop("pulling_enabled", True) @@ -307,6 +310,7 @@ def upsert_provider( provider_id=provider.id, provider_info=provider_config, updated_by="system", + allow_provisioned=True, ) logger.info(f"Provider {provider_name} updated successfully") else: @@ -335,54 +339,71 @@ def upsert_provider( def delete_provider( tenant_id: str, provider_id: str, - session: Session, + session: Optional[Session] = None, allow_provisioned=False, ): - provider_model: Provider = session.exec( - select(Provider).where( - (Provider.tenant_id == tenant_id) & (Provider.id == provider_id) - ) - ).one_or_none() + with existed_or_new_session(session) as session: + provider_model: Optional[Provider] = session.exec( + select(Provider).where( + (Provider.tenant_id == tenant_id) & (Provider.id == provider_id) + ) + ).one_or_none() - if not provider_model: - raise HTTPException(404, detail="Provider not found") + if provider_model is None: + raise HTTPException(404, detail="Provider not found") - if provider_model.provisioned and not allow_provisioned: - raise HTTPException(403, detail="Cannot delete a provisioned provider") + if provider_model.provisioned and not allow_provisioned: + raise HTTPException(403, detail="Cannot delete a provisioned provider") - context_manager = ContextManager(tenant_id=tenant_id) - secret_manager = SecretManagerFactory.get_secret_manager(context_manager) - config = secret_manager.read_secret( - provider_model.configuration_key, is_json=True - ) + # Delete all associated deduplication rules + try: + deduplication_rules = get_all_deduplication_rules_by_provider( + tenant_id, provider_model.id, provider_model.type + ) + for rule in deduplication_rules: + logger.info( + f"Deleting deduplication rule {rule.name} for provider {provider_model.name}" + ) + delete_deduplication_rule(str(rule.id), tenant_id) + except Exception as e: + logger.exception( + "Failed to delete deduplication rules for provider", + extra={"exception": e}, + ) - try: - secret_manager.delete_secret(provider_model.configuration_key) - except Exception: - logger.exception("Failed to delete the provider secret") + context_manager = ContextManager(tenant_id=tenant_id) + secret_manager = SecretManagerFactory.get_secret_manager(context_manager) + config = secret_manager.read_secret( + provider_model.configuration_key, is_json=True + ) - if provider_model.consumer: try: - event_subscriber = EventSubscriber.get_instance() - event_subscriber.remove_consumer(provider_model) + secret_manager.delete_secret(provider_model.configuration_key) except Exception: - logger.exception("Failed to unregister provider as a consumer") + logger.exception("Failed to delete the provider secret") - try: - provider = ProvidersFactory.get_provider( - context_manager, provider_model.id, provider_model.type, config - ) - provider.clean_up() - except NotImplementedError: - logger.info( - "Being deleted provider of type %s does not have a clean_up method", - provider_model.type, - ) - except Exception: - logger.exception(msg="Provider deleted but failed to clean up provider") + if provider_model.consumer: + try: + event_subscriber = EventSubscriber.get_instance() + event_subscriber.remove_consumer(provider_model) + except Exception: + logger.exception("Failed to unregister provider as a consumer") + + try: + provider = ProvidersFactory.get_provider( + context_manager, provider_model.id, provider_model.type, config + ) + provider.clean_up() + except NotImplementedError: + logger.info( + "Being deleted provider of type %s does not have a clean_up method", + provider_model.type, + ) + except Exception: + logger.exception(msg="Provider deleted but failed to clean up provider") - session.delete(provider_model) - session.commit() + session.delete(provider_model) + session.commit() @staticmethod def validate_provider_scopes( @@ -423,6 +444,7 @@ def provision_provider_deduplication_rules( tenant_id: str, provider: Provider, deduplication_rules: Dict[str, Dict[str, Any]], + session: Optional[Session] = None, ): """ Provision deduplication rules for a provider. @@ -431,29 +453,42 @@ def provision_provider_deduplication_rules( tenant_id (str): The tenant ID. provider (Provider): The provider to provision the deduplication rules for. deduplication_rules (Dict[str, Dict[str, Any]]): The deduplication rules to provision. + session (Optional[Session]): SQLAlchemy session to use. """ + with existed_or_new_session(session) as session: + # Ensure provider is attached to the session + if provider not in session: + provider = session.merge(provider) + + # Provision the deduplication rules + deduplication_rules_dict: dict[str, dict] = {} + for rule_name, rule_config in deduplication_rules.items(): + logger.info(f"Provisioning deduplication rule {rule_name}") + rule_config["name"] = rule_name + rule_config["provider_name"] = provider.name + rule_config["provider_type"] = provider.type + deduplication_rules_dict[rule_name] = rule_config - # Provision the deduplication rules - deduplication_rules_dict: dict[str, dict] = {} - for rule_name, rule_config in deduplication_rules.items(): - logger.info(f"Provisioning deduplication rule {rule_name}") - rule_config["name"] = rule_name - rule_config["provider_name"] = provider.name - rule_config["provider_type"] = provider.type - deduplication_rules_dict[rule_name] = rule_config - - # Provision deduplication rules - provision_deduplication_rules( - deduplication_rules=deduplication_rules_dict, - tenant_id=tenant_id, - provider=provider, - ) + try: + # Provision deduplication rules + provision_deduplication_rules( + deduplication_rules=deduplication_rules_dict, + tenant_id=tenant_id, + provider=provider, + ) + except Exception as e: + logger.error( + "Provisioning failed, rolling back", extra={"exception": e} + ) + session.rollback() + raise def install_webhook( tenant_id: str, provider_type: str, provider_id: str, session: Session ) -> bool: context_manager = ContextManager( - tenant_id=tenant_id, workflow_id="" # this is not in a workflow scope + tenant_id=tenant_id, + workflow_id="", # this is not in a workflow scope ) secret_manager = SecretManagerFactory.get_secret_manager(context_manager) provider_secret_name = f"{tenant_id}_{provider_type}_{provider_id}" @@ -566,11 +601,20 @@ def provision_providers(tenant_id: str): parser._parse_providers_from_env(context_manager) env_providers = context_manager.providers_context - for provider_name, provider_config in env_providers.items(): + for provider_name, provider_info in env_providers.items(): + # We need this to avoid failure in upsert operation results in + # the deletion of the old provisioned provider + incoming_providers_names.add(provider_name) + try: - provider_name = provider_config["name"] - provider_type = provider_config["type"] - provider_config = provider_config["authentication"] + provider_type = provider_info.get("type") + if not provider_type: + logger.error( + f"Provider {provider_name} does not have a type" + ) + continue + + provider_config = provider_info.get("authentication", {}) # Perform upsert operation for the provider installed_provider_info = ProvidersService.upsert_provider( @@ -590,19 +634,20 @@ def provision_providers(tenant_id: str): continue provider = installed_provider_info["provider"] - incoming_providers_names.add(provider_name) # Configure deduplication rules - deduplication_rules = provider_config.get("deduplication_rules", {}) - if deduplication_rules: + deduplication_rules = provider_info.get("deduplication_rules") + if deduplication_rules is not None: logger.info( f"Provisioning deduplication rules for provider {provider_name}" ) - ProvidersService.provision_provider_deduplication_rules( - tenant_id=tenant_id, - provider=provider, - deduplication_rules=deduplication_rules, - ) + with Session(engine) as session: + ProvidersService.provision_provider_deduplication_rules( + tenant_id=tenant_id, + provider=provider, + deduplication_rules=deduplication_rules, + session=session, + ) ### Provisioning from the directory if provisioned_providers_dir is not None: @@ -613,10 +658,26 @@ def provision_providers(tenant_id: str): try: with open(provider_path, "r") as yaml_file: - provider_yaml = cyaml.safe_load(yaml_file.read()) - provider_name = provider_yaml["name"] - provider_type = provider_yaml["type"] - provider_config = provider_yaml.get( + provider_info = cyaml.safe_load(yaml_file.read()) + provider_name = provider_info.get("name") + if not provider_name: + logger.error( + f"Provider {provider_path} does not have a name" + ) + continue + + # We need this to avoid failure in upsert operation results in + # the deletion of the old provisioned provider + incoming_providers_names.add(provider_name) + + provider_type = provider_info.get("type") + if not provider_type: + logger.error( + f"Provider {provider_path} does not have a type" + ) + continue + + provider_config = provider_info.get( "authentication", {} ) @@ -638,21 +699,20 @@ def provision_providers(tenant_id: str): continue provider = installed_provider_info["provider"] - incoming_providers_names.add(provider_name) # Configure deduplication rules - deduplication_rules = provider_config.get( - "deduplication_rules", {} - ) - if deduplication_rules: + deduplication_rules = provider_info.get("deduplication_rules") + if deduplication_rules is not None: logger.info( f"Provisioning deduplication rules for provider {provider_name}" ) - ProvidersService.provision_provider_deduplication_rules( - tenant_id=tenant_id, - provider=provider, - deduplication_rules=deduplication_rules, - ) + with Session(engine) as session: + ProvidersService.provision_provider_deduplication_rules( + tenant_id=tenant_id, + provider=provider, + deduplication_rules=deduplication_rules, + session=session, + ) # Delete providers that are not in the incoming list for provider in provisioned_providers: @@ -664,7 +724,6 @@ def provision_providers(tenant_id: str): ProvidersService.delete_provider( tenant_id=tenant_id, provider_id=provider.id, - session=None, allow_provisioned=True, ) logger.info(f"Provider {provider.name} deleted successfully") @@ -676,11 +735,15 @@ def provision_providers(tenant_id: str): continue logger.info( - "Provisioning completed successfully. Provisioned providers: %s", - incoming_providers_names, + "Providers provisioning completed. Provisioned providers: %s", + ( + ", ".join(incoming_providers_names) + if incoming_providers_names + else "None" + ), ) except Exception as e: - logger.error("Provisioning failed, rolling back", extra={"exception": e}) + logger.error("Provisioning failed", extra={"exception": e}) @staticmethod def get_provider_logs( From 456544d8710a7ecdce17d293457357e25561cc02 Mon Sep 17 00:00:00 2001 From: tuantran0910 Date: Sat, 12 Apr 2025 22:52:39 +0700 Subject: [PATCH 13/18] test: add unit tests for provider deletion and deduplication rule handling --- tests/test_providers_service.py | 89 ++++++++++++ tests/test_provisioning.py | 236 ++++++++++++++++++++++++++++++-- 2 files changed, 313 insertions(+), 12 deletions(-) create mode 100644 tests/test_providers_service.py diff --git a/tests/test_providers_service.py b/tests/test_providers_service.py new file mode 100644 index 0000000000..89544a55d7 --- /dev/null +++ b/tests/test_providers_service.py @@ -0,0 +1,89 @@ +from unittest.mock import MagicMock, patch + +import pytest +from sqlmodel import Session + +from keep.providers.providers_service import ProvidersService + + +@pytest.fixture +def mock_db_session(): + session = MagicMock(spec=Session) + return session + + +@pytest.fixture +def mock_provider(): + provider = MagicMock() + provider.id = "test-provider-id" + provider.type = "test-provider-type" + provider.name = "test-provider" + provider.tenant_id = "test-tenant-id" + provider.provisioned = False + return provider + + +@patch("keep.providers.providers_service.ContextManager") +@patch("keep.providers.providers_service.SecretManagerFactory") +@patch("keep.providers.providers_service.ProvidersFactory") +@patch("keep.providers.providers_service.EventSubscriber") +@patch("keep.providers.providers_service.select") +@patch("keep.providers.providers_service.get_all_deduplication_rules_by_provider") +@patch("keep.providers.providers_service.delete_deduplication_rule") +def test_delete_provider_cascade_deletes_deduplication_rules( + mock_delete_deduplication_rule, + mock_get_rules, + mock_select, + mock_event_subscriber, + mock_providers_factory, + mock_secret_manager_factory, + mock_context_manager, + mock_provider, + mock_db_session, +): + # Set up mocks + mock_select_obj = MagicMock() + mock_select.return_value = mock_select_obj + mock_where_obj = MagicMock() + mock_select_obj.where.return_value = mock_where_obj + mock_db_session.exec.return_value.one_or_none.return_value = mock_provider + + # Set up deduplication rules + mock_rule1 = MagicMock() + mock_rule1.id = "rule-id-1" + mock_rule1.name = "test-rule-1" + + mock_rule2 = MagicMock() + mock_rule2.id = "rule-id-2" + mock_rule2.name = "test-rule-2" + + mock_get_rules.return_value = [mock_rule1, mock_rule2] + + # Set up secret manager + mock_secret_manager = MagicMock() + mock_secret_manager_factory.get_secret_manager.return_value = mock_secret_manager + + # Create a provider and mock provider objects + mock_provider_obj = MagicMock() + mock_providers_factory.get_provider.return_value = mock_provider_obj + + # Call delete_provider + ProvidersService.delete_provider( + tenant_id="test-tenant-id", + provider_id="test-provider-id", + session=mock_db_session, + ) + + # Assert deduplication rules were fetched + mock_get_rules.assert_called_once_with( + "test-tenant-id", mock_provider.id, mock_provider.type + ) + + # Assert deduplication rules were deleted + assert mock_delete_deduplication_rule.call_count == 2 + mock_delete_deduplication_rule.assert_any_call("rule-id-1", "test-tenant-id") + mock_delete_deduplication_rule.assert_any_call("rule-id-2", "test-tenant-id") + + # Assert provider was deleted + mock_db_session.delete.assert_called_once_with(mock_provider) + mock_db_session.commit.assert_called_once() diff --git a/tests/test_provisioning.py b/tests/test_provisioning.py index c5ea522268..ca21156fe1 100644 --- a/tests/test_provisioning.py +++ b/tests/test_provisioning.py @@ -387,12 +387,15 @@ def test_no_provisioned_providers_and_unset_env_vars( from keep.providers.providers_service import ProvidersService # Mock get_all_provisioned_providers to return an empty list - with patch( - "keep.providers.providers_service.get_all_provisioned_providers", - return_value=[], - ) as mock_get_providers, patch( - "keep.providers.providers_service.ProvidersService.delete_provider" - ) as mock_delete_provider: + with ( + patch( + "keep.providers.providers_service.get_all_provisioned_providers", + return_value=[], + ) as mock_get_providers, + patch( + "keep.providers.providers_service.ProvidersService.delete_provider" + ) as mock_delete_provider, + ): # Call provision_providers without setting any env vars ProvidersService.provision_providers("test-tenant") @@ -421,12 +424,15 @@ def test_delete_provisioned_providers_when_env_vars_unset( mock_provider = MagicMock(id="test-id", name="test-provider", type="test-type") # Mock get_all_provisioned_providers to return our mock provider - with patch( - "keep.providers.providers_service.get_all_provisioned_providers", - return_value=[mock_provider], - ) as mock_get_providers, patch( - "keep.providers.providers_service.ProvidersService.delete_provider" - ) as mock_delete_provider: + with ( + patch( + "keep.providers.providers_service.get_all_provisioned_providers", + return_value=[mock_provider], + ) as mock_get_providers, + patch( + "keep.providers.providers_service.ProvidersService.delete_provider" + ) as mock_delete_provider, + ): # Call provision_providers without setting any env vars ProvidersService.provision_providers("test-tenant") @@ -595,3 +601,209 @@ def test_provision_provider_with_multiple_deduplication_rules( # Verify both rules are associated with the same provider assert rule1["provider_type"] == "victoriametrics" assert rule2["provider_type"] == "victoriametrics" + + +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + "KEEP_PROVIDERS": '{"vm_provider":{"type":"victoriametrics","authentication":{"VMAlertHost":"http://localhost","VMAlertPort":1234},"deduplication_rules":{"rule1":{"description":"First rule","fingerprint_fields":["fingerprint","source"]},"rule2":{"description":"Second rule","fingerprint_fields":["alert_id"]}}}}', + }, + ], + indirect=True, +) +def test_update_deduplication_rules_when_reprovisioning( + monkeypatch, db_session, client, test_app +): + """Test that old deduplication rules are deleted and new ones are created when reprovisioning a provider with different rules""" + + # First verify initial provider and both rules are installed + response = client.get("/deduplications", headers={"x-api-key": "someapikey"}) + assert response.status_code == 200 + rules = response.json() + assert len(rules) - 1 == 2 # Subtract 1 to exclude the default rule + + rule_names = [r["name"] for r in rules] + assert "rule1" in rule_names + assert "rule2" in rule_names + + # Update provider config with one rule removed and one rule updated and one new rule + monkeypatch.setenv( + "KEEP_PROVIDERS", + '{"vm_provider":{"type":"victoriametrics","authentication":{"VMAlertHost":"http://localhost","VMAlertPort":1234},"deduplication_rules":{"rule1":{"description":"Updated first rule","fingerprint_fields":["fingerprint","source","severity"]},"rule3":{"description":"New rule","fingerprint_fields":["alert_id","group"]}}}}', + ) + + # Reload the app to apply the new environment changes + importlib.reload(sys.modules["keep.api.api"]) + from keep.api.api import get_app + + app = get_app() + + # Manually trigger the startup event + for event_handler in app.router.on_startup: + asyncio.run(event_handler()) + + # Manually trigger the provision resources + from keep.api.config import provision_resources + + provision_resources() + + client = TestClient(app) + + # Verify the rules were updated correctly + response = client.get("/deduplications", headers={"x-api-key": "someapikey"}) + assert response.status_code == 200 + rules = response.json() + + rule_names = [r["name"] for r in rules] + assert "rule1" in rule_names + assert "rule2" not in rule_names # rule2 should be deleted + assert "rule3" in rule_names # rule3 should be added + + # Verify rule1 was updated + rule1 = next(r for r in rules if r["name"] == "rule1") + assert rule1["description"] == "Updated first rule" + assert rule1["fingerprint_fields"] == ["fingerprint", "source", "severity"] + + # Verify rule3 was added + rule3 = next(r for r in rules if r["name"] == "rule3") + assert rule3["description"] == "New rule" + assert rule3["fingerprint_fields"] == ["alert_id", "group"] + + # Verify both rules are associated with the same provider + assert rule1["provider_type"] == "victoriametrics" + assert rule3["provider_type"] == "victoriametrics" + + +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + "KEEP_PROVIDERS": '{"vm_provider":{"type":"victoriametrics","authentication":{"VMAlertHost":"http://localhost","VMAlertPort":1234},"deduplication_rules":{"vm_rule1":{"description":"VM Rule","fingerprint_fields":["fingerprint"]}}}, "opsgenie_provider":{"type":"opsgenie","authentication":{"api_key":"somekey"},"deduplication_rules":{"og_rule1":{"description":"OG Rule","fingerprint_fields":["id"]}}}}', + }, + ], + indirect=True, +) +def test_multiple_providers_with_deduplication_rules( + monkeypatch, db_session, client, test_app +): + """Test that deduplication rules for different providers don't interfere with each other""" + + # First verify both providers and their rules are installed + response = client.get("/deduplications", headers={"x-api-key": "someapikey"}) + assert response.status_code == 200 + rules = response.json() + + rule_names = [r["name"] for r in rules] + assert "vm_rule1" in rule_names + assert "og_rule1" in rule_names + + # Update only the vm_provider, removing its rule and adding a new one + monkeypatch.setenv( + "KEEP_PROVIDERS", + '{"vm_provider":{"type":"victoriametrics","authentication":{"VMAlertHost":"http://localhost","VMAlertPort":1234},"deduplication_rules":{"vm_rule2":{"description":"New VM Rule","fingerprint_fields":["name"]}}}, "opsgenie_provider":{"type":"opsgenie","authentication":{"api_key":"somekey"},"deduplication_rules":{"og_rule1":{"description":"OG Rule","fingerprint_fields":["id"]}}}}', + ) + + # Reload the app to apply the new environment changes + importlib.reload(sys.modules["keep.api.api"]) + from keep.api.api import get_app + + app = get_app() + + # Manually trigger the startup event + for event_handler in app.router.on_startup: + asyncio.run(event_handler()) + + # Manually trigger the provision resources + from keep.api.config import provision_resources + + provision_resources() + + client = TestClient(app) + + # Verify the rules were updated correctly + response = client.get("/deduplications", headers={"x-api-key": "someapikey"}) + assert response.status_code == 200 + rules = response.json() + + rule_names = [r["name"] for r in rules] + assert "vm_rule1" not in rule_names # vm_rule1 should be deleted + assert "vm_rule2" in rule_names # vm_rule2 should be added + assert "og_rule1" in rule_names # og_rule1 should be kept + + # Verify vm_rule2 was added correctly + vm_rule2 = next(r for r in rules if r["name"] == "vm_rule2") + assert vm_rule2["description"] == "New VM Rule" + assert vm_rule2["fingerprint_fields"] == ["name"] + assert vm_rule2["provider_type"] == "victoriametrics" + + # Verify og_rule1 was kept unchanged + og_rule1 = next(r for r in rules if r["name"] == "og_rule1") + assert og_rule1["description"] == "OG Rule" + assert og_rule1["fingerprint_fields"] == ["id"] + assert og_rule1["provider_type"] == "opsgenie" + + +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + "KEEP_PROVIDERS": '{"vm_provider":{"type":"victoriametrics","authentication":{"VMAlertHost":"http://localhost","VMAlertPort":1234},"deduplication_rules":{"vm_rule1":{"description":"VM Rule","fingerprint_fields":["fingerprint"]}}}, "opsgenie_provider":{"type":"opsgenie","authentication":{"api_key":"somekey"},"deduplication_rules":{"og_rule1":{"description":"OG Rule","fingerprint_fields":["id"]}}}}', + }, + ], + indirect=True, +) +def test_deleting_provider_removes_deduplication_rules( + monkeypatch, db_session, client, test_app +): + """Test that when a provider is deleted, its associated deduplication rules are deleted as well""" + + # First verify both providers and their rules are installed + response = client.get("/deduplications", headers={"x-api-key": "someapikey"}) + assert response.status_code == 200 + rules = response.json() + + rule_names = [r["name"] for r in rules] + assert "vm_rule1" in rule_names + assert "og_rule1" in rule_names + + # Remove the opsgenie_provider completely + monkeypatch.setenv( + "KEEP_PROVIDERS", + '{"vm_provider":{"type":"victoriametrics","authentication":{"VMAlertHost":"http://localhost","VMAlertPort":1234},"deduplication_rules":{"vm_rule1":{"description":"VM Rule","fingerprint_fields":["fingerprint"]}}}}', + ) + + # Reload the app to apply the new environment changes + importlib.reload(sys.modules["keep.api.api"]) + from keep.api.api import get_app + + app = get_app() + + # Manually trigger the startup event + for event_handler in app.router.on_startup: + asyncio.run(event_handler()) + + # Manually trigger the provision resources + from keep.api.config import provision_resources + + provision_resources() + + client = TestClient(app) + + # Verify the rules were updated correctly + response = client.get("/deduplications", headers={"x-api-key": "someapikey"}) + assert response.status_code == 200 + rules = response.json() + + rule_names = [r["name"] for r in rules] + assert "vm_rule1" in rule_names # vm_rule1 should still exist + assert "og_rule1" not in rule_names # og_rule1 should be deleted + + # Verify vm_rule1 is unchanged + vm_rule1 = next(r for r in rules if r["name"] == "vm_rule1") + assert vm_rule1["description"] == "VM Rule" + assert vm_rule1["fingerprint_fields"] == ["fingerprint"] + assert vm_rule1["provider_type"] == "victoriametrics" From c149182c961eeb5eaefd0eca66375ae4ad86a150 Mon Sep 17 00:00:00 2001 From: tuantran0910 Date: Sun, 13 Apr 2025 00:07:11 +0700 Subject: [PATCH 14/18] feat: add validate_scopes parameter to ProvidersService and improve deduplication rules logging --- keep/providers/providers_service.py | 51 ++++++++++++++++------------- tests/test_provisioning.py | 34 +++++++++---------- 2 files changed, 44 insertions(+), 41 deletions(-) diff --git a/keep/providers/providers_service.py b/keep/providers/providers_service.py index 0cbf3f3c4b..c58fd3becc 100644 --- a/keep/providers/providers_service.py +++ b/keep/providers/providers_service.py @@ -234,6 +234,7 @@ def update_provider( updated_by: str, session: Optional[Session] = None, allow_provisioned: bool = False, + validate_scopes: bool = True, ) -> Dict[str, Any]: with existed_or_new_session(session) as session: provider = session.exec( @@ -267,7 +268,10 @@ def update_provider( except Exception as e: raise HTTPException(status_code=400, detail=str(e)) - validated_scopes = provider_instance.validate_scopes() + if validate_scopes: + validated_scopes = provider_instance.validate_scopes() + else: + validated_scopes = {} secret_manager = SecretManagerFactory.get_secret_manager(context_manager) secret_manager.write_secret( @@ -311,6 +315,7 @@ def upsert_provider( provider_info=provider_config, updated_by="system", allow_provisioned=True, + validate_scopes=validate_scopes, ) logger.info(f"Provider {provider_name} updated successfully") else: @@ -636,18 +641,17 @@ def provision_providers(tenant_id: str): provider = installed_provider_info["provider"] # Configure deduplication rules - deduplication_rules = provider_info.get("deduplication_rules") - if deduplication_rules is not None: - logger.info( - f"Provisioning deduplication rules for provider {provider_name}" + deduplication_rules = provider_info.get("deduplication_rules", {}) + logger.info( + f"Provisioning deduplication rules for provider {provider_name}" + ) + with Session(engine) as session: + ProvidersService.provision_provider_deduplication_rules( + tenant_id=tenant_id, + provider=provider, + deduplication_rules=deduplication_rules, + session=session, ) - with Session(engine) as session: - ProvidersService.provision_provider_deduplication_rules( - tenant_id=tenant_id, - provider=provider, - deduplication_rules=deduplication_rules, - session=session, - ) ### Provisioning from the directory if provisioned_providers_dir is not None: @@ -701,18 +705,19 @@ def provision_providers(tenant_id: str): provider = installed_provider_info["provider"] # Configure deduplication rules - deduplication_rules = provider_info.get("deduplication_rules") - if deduplication_rules is not None: - logger.info( - f"Provisioning deduplication rules for provider {provider_name}" + deduplication_rules = provider_info.get( + "deduplication_rules", {} + ) + logger.info( + f"Provisioning deduplication rules for provider {provider_name}" + ) + with Session(engine) as session: + ProvidersService.provision_provider_deduplication_rules( + tenant_id=tenant_id, + provider=provider, + deduplication_rules=deduplication_rules, + session=session, ) - with Session(engine) as session: - ProvidersService.provision_provider_deduplication_rules( - tenant_id=tenant_id, - provider=provider, - deduplication_rules=deduplication_rules, - session=session, - ) # Delete providers that are not in the incoming list for provider in provisioned_providers: diff --git a/tests/test_provisioning.py b/tests/test_provisioning.py index ca21156fe1..e173f97b2d 100644 --- a/tests/test_provisioning.py +++ b/tests/test_provisioning.py @@ -416,7 +416,7 @@ def test_delete_provisioned_providers_when_env_vars_unset( ): """Test deleting provisioned providers when env vars are unset""" # Import necessary modules - from unittest.mock import ANY, MagicMock, patch + from unittest.mock import MagicMock, patch from keep.providers.providers_service import ProvidersService @@ -441,11 +441,9 @@ def test_delete_provisioned_providers_when_env_vars_unset( # Verify delete_provider was called with correct parameters mock_delete_provider.assert_called_once_with( - "test-tenant", - "test-id", - ANY, # Session object + tenant_id="test-tenant", + provider_id="test-id", allow_provisioned=True, - commit=False, ) @@ -681,7 +679,7 @@ def test_update_deduplication_rules_when_reprovisioning( [ { "AUTH_TYPE": "NOAUTH", - "KEEP_PROVIDERS": '{"vm_provider":{"type":"victoriametrics","authentication":{"VMAlertHost":"http://localhost","VMAlertPort":1234},"deduplication_rules":{"vm_rule1":{"description":"VM Rule","fingerprint_fields":["fingerprint"]}}}, "opsgenie_provider":{"type":"opsgenie","authentication":{"api_key":"somekey"},"deduplication_rules":{"og_rule1":{"description":"OG Rule","fingerprint_fields":["id"]}}}}', + "KEEP_PROVIDERS": '{"vm_provider":{"type":"victoriametrics","authentication":{"VMAlertHost":"http://localhost","VMAlertPort":1234},"deduplication_rules":{"vm_rule1":{"description":"VM Rule","fingerprint_fields":["fingerprint"]}}}, "pagerduty_provider":{"type":"pagerduty","authentication":{"api_key":"somekey","routing_key":"routingkey123"},"deduplication_rules":{"pd_rule1":{"description":"PD Rule","fingerprint_fields":["id"]}}}}', }, ], indirect=True, @@ -698,12 +696,12 @@ def test_multiple_providers_with_deduplication_rules( rule_names = [r["name"] for r in rules] assert "vm_rule1" in rule_names - assert "og_rule1" in rule_names + assert "pd_rule1" in rule_names # Update only the vm_provider, removing its rule and adding a new one monkeypatch.setenv( "KEEP_PROVIDERS", - '{"vm_provider":{"type":"victoriametrics","authentication":{"VMAlertHost":"http://localhost","VMAlertPort":1234},"deduplication_rules":{"vm_rule2":{"description":"New VM Rule","fingerprint_fields":["name"]}}}, "opsgenie_provider":{"type":"opsgenie","authentication":{"api_key":"somekey"},"deduplication_rules":{"og_rule1":{"description":"OG Rule","fingerprint_fields":["id"]}}}}', + '{"vm_provider":{"type":"victoriametrics","authentication":{"VMAlertHost":"http://localhost","VMAlertPort":1234},"deduplication_rules":{"vm_rule2":{"description":"New VM Rule","fingerprint_fields":["name"]}}}, "pagerduty_provider":{"type":"pagerduty","authentication":{"api_key":"somekey"},"deduplication_rules":{"pd_rule1":{"description":"PD Rule","fingerprint_fields":["id"]}}}}', ) # Reload the app to apply the new environment changes @@ -731,7 +729,7 @@ def test_multiple_providers_with_deduplication_rules( rule_names = [r["name"] for r in rules] assert "vm_rule1" not in rule_names # vm_rule1 should be deleted assert "vm_rule2" in rule_names # vm_rule2 should be added - assert "og_rule1" in rule_names # og_rule1 should be kept + assert "pd_rule1" in rule_names # pd_rule1 should be kept # Verify vm_rule2 was added correctly vm_rule2 = next(r for r in rules if r["name"] == "vm_rule2") @@ -739,11 +737,11 @@ def test_multiple_providers_with_deduplication_rules( assert vm_rule2["fingerprint_fields"] == ["name"] assert vm_rule2["provider_type"] == "victoriametrics" - # Verify og_rule1 was kept unchanged - og_rule1 = next(r for r in rules if r["name"] == "og_rule1") - assert og_rule1["description"] == "OG Rule" - assert og_rule1["fingerprint_fields"] == ["id"] - assert og_rule1["provider_type"] == "opsgenie" + # Verify pd_rule1 was kept unchanged + pd_rule1 = next(r for r in rules if r["name"] == "pd_rule1") + assert pd_rule1["description"] == "PD Rule" + assert pd_rule1["fingerprint_fields"] == ["id"] + assert pd_rule1["provider_type"] == "pagerduty" @pytest.mark.parametrize( @@ -751,7 +749,7 @@ def test_multiple_providers_with_deduplication_rules( [ { "AUTH_TYPE": "NOAUTH", - "KEEP_PROVIDERS": '{"vm_provider":{"type":"victoriametrics","authentication":{"VMAlertHost":"http://localhost","VMAlertPort":1234},"deduplication_rules":{"vm_rule1":{"description":"VM Rule","fingerprint_fields":["fingerprint"]}}}, "opsgenie_provider":{"type":"opsgenie","authentication":{"api_key":"somekey"},"deduplication_rules":{"og_rule1":{"description":"OG Rule","fingerprint_fields":["id"]}}}}', + "KEEP_PROVIDERS": '{"vm_provider":{"type":"victoriametrics","authentication":{"VMAlertHost":"http://localhost","VMAlertPort":1234},"deduplication_rules":{"vm_rule1":{"description":"VM Rule","fingerprint_fields":["fingerprint"]}}}, "pagerduty_provider":{"type":"pagerduty","authentication":{"api_key":"somekey","routing_key":"routingkey123"},"deduplication_rules":{"pd_rule1":{"description":"PD Rule","fingerprint_fields":["id"]}}}}', }, ], indirect=True, @@ -768,9 +766,9 @@ def test_deleting_provider_removes_deduplication_rules( rule_names = [r["name"] for r in rules] assert "vm_rule1" in rule_names - assert "og_rule1" in rule_names + assert "pd_rule1" in rule_names - # Remove the opsgenie_provider completely + # Remove the pagerduty_provider completely monkeypatch.setenv( "KEEP_PROVIDERS", '{"vm_provider":{"type":"victoriametrics","authentication":{"VMAlertHost":"http://localhost","VMAlertPort":1234},"deduplication_rules":{"vm_rule1":{"description":"VM Rule","fingerprint_fields":["fingerprint"]}}}}', @@ -800,7 +798,7 @@ def test_deleting_provider_removes_deduplication_rules( rule_names = [r["name"] for r in rules] assert "vm_rule1" in rule_names # vm_rule1 should still exist - assert "og_rule1" not in rule_names # og_rule1 should be deleted + assert "pd_rule1" not in rule_names # pd_rule1 should be deleted # Verify vm_rule1 is unchanged vm_rule1 = next(r for r in rules if r["name"] == "vm_rule1") From 8ca1f2ce6fd21dff9828d0b7f1adf3def766655a Mon Sep 17 00:00:00 2001 From: tuantran0910 Date: Sun, 13 Apr 2025 00:42:37 +0700 Subject: [PATCH 15/18] refactor: streamline deduplication rules provisioning and improve logging in ProvidersService --- keep/providers/providers_service.py | 71 ++++------ tests/test_providers_yaml_provisioning.py | 163 +++++++++++++--------- 2 files changed, 119 insertions(+), 115 deletions(-) diff --git a/keep/providers/providers_service.py b/keep/providers/providers_service.py index c58fd3becc..492933deb5 100644 --- a/keep/providers/providers_service.py +++ b/keep/providers/providers_service.py @@ -219,7 +219,6 @@ def install_provider( logger.exception("Failed to register provider as a consumer") return { - "provider": provider_model, "type": provider_type, "id": provider_unique_id, "details": config, @@ -285,7 +284,6 @@ def update_provider( session.commit() return { - "provider": provider, "details": provider_config, "validatedScopes": validated_scopes, } @@ -449,44 +447,25 @@ def provision_provider_deduplication_rules( tenant_id: str, provider: Provider, deduplication_rules: Dict[str, Dict[str, Any]], - session: Optional[Session] = None, ): - """ - Provision deduplication rules for a provider. - - Args: - tenant_id (str): The tenant ID. - provider (Provider): The provider to provision the deduplication rules for. - deduplication_rules (Dict[str, Dict[str, Any]]): The deduplication rules to provision. - session (Optional[Session]): SQLAlchemy session to use. - """ - with existed_or_new_session(session) as session: - # Ensure provider is attached to the session - if provider not in session: - provider = session.merge(provider) - - # Provision the deduplication rules - deduplication_rules_dict: dict[str, dict] = {} - for rule_name, rule_config in deduplication_rules.items(): - logger.info(f"Provisioning deduplication rule {rule_name}") - rule_config["name"] = rule_name - rule_config["provider_name"] = provider.name - rule_config["provider_type"] = provider.type - deduplication_rules_dict[rule_name] = rule_config + # Provision the deduplication rules + deduplication_rules_dict: dict[str, dict] = {} + for rule_name, rule_config in deduplication_rules.items(): + logger.info(f"Provisioning deduplication rule {rule_name}") + rule_config["name"] = rule_name + rule_config["provider_name"] = provider.name + rule_config["provider_type"] = provider.type + deduplication_rules_dict[rule_name] = rule_config - try: - # Provision deduplication rules - provision_deduplication_rules( - deduplication_rules=deduplication_rules_dict, - tenant_id=tenant_id, - provider=provider, - ) - except Exception as e: - logger.error( - "Provisioning failed, rolling back", extra={"exception": e} - ) - session.rollback() - raise + try: + # Provision deduplication rules + provision_deduplication_rules( + deduplication_rules=deduplication_rules_dict, + tenant_id=tenant_id, + provider=provider, + ) + except Exception as e: + logger.exception(f"Failed to provision deduplication rules: {e}") def install_webhook( tenant_id: str, provider_type: str, provider_id: str, session: Session @@ -622,7 +601,7 @@ def provision_providers(tenant_id: str): provider_config = provider_info.get("authentication", {}) # Perform upsert operation for the provider - installed_provider_info = ProvidersService.upsert_provider( + ProvidersService.upsert_provider( tenant_id=tenant_id, provider_name=provider_name, provider_type=provider_type, @@ -638,20 +617,18 @@ def provision_providers(tenant_id: str): ) continue - provider = installed_provider_info["provider"] + provider = get_provider_by_name(tenant_id, provider_name) # Configure deduplication rules deduplication_rules = provider_info.get("deduplication_rules", {}) logger.info( f"Provisioning deduplication rules for provider {provider_name}" ) - with Session(engine) as session: - ProvidersService.provision_provider_deduplication_rules( - tenant_id=tenant_id, - provider=provider, - deduplication_rules=deduplication_rules, - session=session, - ) + ProvidersService.provision_provider_deduplication_rules( + tenant_id=tenant_id, + provider=provider, + deduplication_rules=deduplication_rules, + ) ### Provisioning from the directory if provisioned_providers_dir is not None: diff --git a/tests/test_providers_yaml_provisioning.py b/tests/test_providers_yaml_provisioning.py index 91aa3287c6..6e95ddb299 100644 --- a/tests/test_providers_yaml_provisioning.py +++ b/tests/test_providers_yaml_provisioning.py @@ -69,19 +69,23 @@ def test_provision_provider_from_yaml(temp_providers_dir, sample_provider_yaml, # Mock environment variables with patch.dict(os.environ, {"KEEP_PROVIDERS_DIRECTORY": temp_providers_dir}): - with patch( - "keep.providers.providers_service.ProvidersService.is_provider_installed", - return_value=False, - ), patch( - "keep.providers.providers_service.ProvidersService.install_provider", - return_value=mock_provider, - ) as mock_install, patch( - "keep.providers.providers_service.provision_deduplication_rules" - ) as mock_provision_rules, patch( - "keep.api.core.db.get_all_provisioned_providers", return_value=[] - ), patch( - "keep.providers.providers_factory.ProvidersFactory.get_installed_providers", - return_value=[mock_provider], + with ( + patch( + "keep.providers.providers_service.ProvidersService.is_provider_installed", + return_value=False, + ), + patch( + "keep.providers.providers_service.ProvidersService.install_provider", + return_value=mock_provider, + ) as mock_install, + patch( + "keep.providers.providers_service.ProvidersService.provision_provider_deduplication_rules" + ) as mock_provision_provider_rules, + patch("keep.api.core.db.get_all_provisioned_providers", return_value=[]), + patch( + "keep.providers.providers_factory.ProvidersFactory.get_installed_providers", + return_value=[mock_provider], + ), ): # Call the provisioning function ProvidersService.provision_providers("test-tenant") @@ -98,15 +102,11 @@ def test_provision_provider_from_yaml(temp_providers_dir, sample_provider_yaml, } # Verify deduplication rules provisioning was called - mock_provision_rules.assert_called_once() - call_args = mock_provision_rules.call_args[1] + mock_provision_provider_rules.assert_called_once() + call_args = mock_provision_provider_rules.call_args[1] assert call_args["tenant_id"] == "test-tenant" - assert len(call_args["deduplication_rules"]) > 0 - rule = list(call_args["deduplication_rules"].values())[0] - assert rule["description"] == "Test deduplication rule" - assert rule["fingerprint_fields"] == ["fingerprint", "source"] - assert rule["full_deduplication"] is True - assert rule["ignore_fields"] == ["name"] + assert "provider" in call_args + assert "deduplication_rules" in call_args def test_invalid_yaml_file(temp_providers_dir): @@ -119,12 +119,15 @@ def test_invalid_yaml_file(temp_providers_dir): # Mock environment variables with patch.dict(os.environ, {"KEEP_PROVIDERS_DIRECTORY": temp_providers_dir}): # Mock database operations - with patch( - "keep.providers.providers_service.ProvidersService.is_provider_installed", - return_value=False, - ), patch( - "keep.providers.providers_service.ProvidersService.install_provider" - ) as mock_install: + with ( + patch( + "keep.providers.providers_service.ProvidersService.is_provider_installed", + return_value=False, + ), + patch( + "keep.providers.providers_service.ProvidersService.install_provider" + ) as mock_install, + ): # Call the provisioning function ProvidersService.provision_providers("test-tenant") @@ -149,12 +152,15 @@ def test_missing_required_fields(temp_providers_dir): # Mock environment variables with patch.dict(os.environ, {"KEEP_PROVIDERS_DIRECTORY": temp_providers_dir}): # Mock database operations - with patch( - "keep.providers.providers_service.ProvidersService.is_provider_installed", - return_value=False, - ), patch( - "keep.providers.providers_service.ProvidersService.install_provider" - ) as mock_install: + with ( + patch( + "keep.providers.providers_service.ProvidersService.is_provider_installed", + return_value=False, + ), + patch( + "keep.providers.providers_service.ProvidersService.install_provider" + ) as mock_install, + ): # Call the provisioning function ProvidersService.provision_providers("test-tenant") @@ -206,16 +212,19 @@ def test_provider_yaml_with_multiple_deduplication_rules(temp_providers_dir, cap # Mock environment variables and services with patch.dict(os.environ, {"KEEP_PROVIDERS_DIRECTORY": temp_providers_dir}): - with patch( - "keep.providers.providers_service.ProvidersService.is_provider_installed", - return_value=False, - ), patch( - "keep.providers.providers_service.ProvidersService.install_provider", - return_value=mock_provider, - ) as mock_install, patch( - "keep.providers.providers_service.provision_deduplication_rules" - ) as mock_provision_rules, patch( - "keep.api.core.db.get_all_provisioned_providers", return_value=[] + with ( + patch( + "keep.providers.providers_service.ProvidersService.is_provider_installed", + return_value=False, + ), + patch( + "keep.providers.providers_service.ProvidersService.install_provider", + return_value=mock_provider, + ) as mock_install, + patch( + "keep.providers.providers_service.ProvidersService.provision_provider_deduplication_rules" + ) as mock_provision_provider_rules, + patch("keep.api.core.db.get_all_provisioned_providers", return_value=[]), ): # Call the provisioning function ProvidersService.provision_providers("test-tenant") @@ -224,8 +233,8 @@ def test_provider_yaml_with_multiple_deduplication_rules(temp_providers_dir, cap mock_install.assert_called_once() # Verify deduplication rules provisioning - mock_provision_rules.assert_called_once() - call_args = mock_provision_rules.call_args[1] + mock_provision_provider_rules.assert_called_once() + call_args = mock_provision_provider_rules.call_args[1] assert call_args["tenant_id"] == "test-tenant" rules = call_args["deduplication_rules"] @@ -272,16 +281,19 @@ def test_provider_yaml_with_empty_deduplication_rules(temp_providers_dir, caplog # Mock environment variables and services with patch.dict(os.environ, {"KEEP_PROVIDERS_DIRECTORY": temp_providers_dir}): - with patch( - "keep.providers.providers_service.ProvidersService.is_provider_installed", - return_value=False, - ), patch( - "keep.providers.providers_service.ProvidersService.install_provider", - return_value=mock_provider, - ) as mock_install, patch( - "keep.providers.providers_service.provision_deduplication_rules" - ) as mock_provision_rules, patch( - "keep.api.core.db.get_all_provisioned_providers", return_value=[] + with ( + patch( + "keep.providers.providers_service.ProvidersService.is_provider_installed", + return_value=False, + ), + patch( + "keep.providers.providers_service.ProvidersService.install_provider", + return_value=mock_provider, + ) as mock_install, + patch( + "keep.providers.providers_service.ProvidersService.provision_provider_deduplication_rules" + ) as mock_provision_provider_rules, + patch("keep.api.core.db.get_all_provisioned_providers", return_value=[]), ): # Call the provisioning function ProvidersService.provision_providers("test-tenant") @@ -290,7 +302,10 @@ def test_provider_yaml_with_empty_deduplication_rules(temp_providers_dir, caplog mock_install.assert_called_once() # Verify deduplication rules provisioning was called with empty rules - mock_provision_rules.assert_not_called() + mock_provision_provider_rules.assert_called_once() + call_args = mock_provision_provider_rules.call_args[1] + assert call_args["tenant_id"] == "test-tenant" + assert call_args["deduplication_rules"] == {} def test_provider_yaml_with_invalid_deduplication_rules(temp_providers_dir, caplog): @@ -324,16 +339,28 @@ def test_provider_yaml_with_invalid_deduplication_rules(temp_providers_dir, capl # Mock environment variables and services with patch.dict(os.environ, {"KEEP_PROVIDERS_DIRECTORY": temp_providers_dir}): - with patch( - "keep.providers.providers_service.ProvidersService.is_provider_installed", - return_value=False, - ), patch( - "keep.providers.providers_service.ProvidersService.install_provider", - return_value=mock_provider, - ) as mock_install, patch( - "keep.providers.providers_service.provision_deduplication_rules" - ) as mock_provision_rules, patch( - "keep.api.core.db.get_all_provisioned_providers", return_value=[] + with ( + patch( + "keep.providers.providers_service.ProvidersService.is_provider_installed", + return_value=False, + ), + patch( + "keep.providers.providers_service.ProvidersService.install_provider", + return_value=mock_provider, + ) as mock_install, + patch( + "keep.providers.providers_service.ProvidersService.provision_provider_deduplication_rules" + ) as mock_provision_provider_rules, + patch("keep.api.core.db.get_all_provisioned_providers", return_value=[]), + patch( + "sqlmodel.Session", + MagicMock( + return_value=MagicMock( + __enter__=MagicMock(return_value=MagicMock()), + __exit__=MagicMock(), + ) + ), + ), ): # Call the provisioning function ProvidersService.provision_providers("test-tenant") @@ -342,8 +369,8 @@ def test_provider_yaml_with_invalid_deduplication_rules(temp_providers_dir, capl mock_install.assert_called_once() # Verify deduplication rules provisioning was called - mock_provision_rules.assert_called_once() - call_args = mock_provision_rules.call_args[1] + mock_provision_provider_rules.assert_called_once() + call_args = mock_provision_provider_rules.call_args[1] assert call_args["tenant_id"] == "test-tenant" # Even invalid rules should be passed through, validation happens in provision_deduplication_rules From 22620e8558360284330f2138a47b0b69d9e601f0 Mon Sep 17 00:00:00 2001 From: tuantran0910 Date: Sun, 13 Apr 2025 00:54:58 +0700 Subject: [PATCH 16/18] delete: remove obsolete unit tests for provider deletion and deduplication rule handling --- tests/test_providers_service.py | 89 --------------------------------- 1 file changed, 89 deletions(-) delete mode 100644 tests/test_providers_service.py diff --git a/tests/test_providers_service.py b/tests/test_providers_service.py deleted file mode 100644 index 89544a55d7..0000000000 --- a/tests/test_providers_service.py +++ /dev/null @@ -1,89 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest -from sqlmodel import Session - -from keep.providers.providers_service import ProvidersService - - -@pytest.fixture -def mock_db_session(): - session = MagicMock(spec=Session) - return session - - -@pytest.fixture -def mock_provider(): - provider = MagicMock() - provider.id = "test-provider-id" - provider.type = "test-provider-type" - provider.name = "test-provider" - provider.tenant_id = "test-tenant-id" - provider.provisioned = False - return provider - - -@patch("keep.providers.providers_service.ContextManager") -@patch("keep.providers.providers_service.SecretManagerFactory") -@patch("keep.providers.providers_service.ProvidersFactory") -@patch("keep.providers.providers_service.EventSubscriber") -@patch("keep.providers.providers_service.select") -@patch("keep.providers.providers_service.get_all_deduplication_rules_by_provider") -@patch("keep.providers.providers_service.delete_deduplication_rule") -def test_delete_provider_cascade_deletes_deduplication_rules( - mock_delete_deduplication_rule, - mock_get_rules, - mock_select, - mock_event_subscriber, - mock_providers_factory, - mock_secret_manager_factory, - mock_context_manager, - mock_provider, - mock_db_session, -): - # Set up mocks - mock_select_obj = MagicMock() - mock_select.return_value = mock_select_obj - mock_where_obj = MagicMock() - mock_select_obj.where.return_value = mock_where_obj - mock_db_session.exec.return_value.one_or_none.return_value = mock_provider - - # Set up deduplication rules - mock_rule1 = MagicMock() - mock_rule1.id = "rule-id-1" - mock_rule1.name = "test-rule-1" - - mock_rule2 = MagicMock() - mock_rule2.id = "rule-id-2" - mock_rule2.name = "test-rule-2" - - mock_get_rules.return_value = [mock_rule1, mock_rule2] - - # Set up secret manager - mock_secret_manager = MagicMock() - mock_secret_manager_factory.get_secret_manager.return_value = mock_secret_manager - - # Create a provider and mock provider objects - mock_provider_obj = MagicMock() - mock_providers_factory.get_provider.return_value = mock_provider_obj - - # Call delete_provider - ProvidersService.delete_provider( - tenant_id="test-tenant-id", - provider_id="test-provider-id", - session=mock_db_session, - ) - - # Assert deduplication rules were fetched - mock_get_rules.assert_called_once_with( - "test-tenant-id", mock_provider.id, mock_provider.type - ) - - # Assert deduplication rules were deleted - assert mock_delete_deduplication_rule.call_count == 2 - mock_delete_deduplication_rule.assert_any_call("rule-id-1", "test-tenant-id") - mock_delete_deduplication_rule.assert_any_call("rule-id-2", "test-tenant-id") - - # Assert provider was deleted - mock_db_session.delete.assert_called_once_with(mock_provider) - mock_db_session.commit.assert_called_once() From 5c6ba43d6e8ccafbf6dd10e639e6f719622a4efe Mon Sep 17 00:00:00 2001 From: tuantran0910 Date: Sun, 13 Apr 2025 00:59:59 +0700 Subject: [PATCH 17/18] refactor: update provider provisioning logic to improve configuration handling and rollback mechanisms --- docs/deployment/provision/provider.mdx | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/docs/deployment/provision/provider.mdx b/docs/deployment/provision/provider.mdx index a6c63f4a4f..9173a1db6b 100644 --- a/docs/deployment/provision/provider.mdx +++ b/docs/deployment/provision/provider.mdx @@ -99,7 +99,6 @@ deduplication_rules: Keep supports a wide range of provider types. Each provider type has its own specific configuration requirements. To see the full list of supported providers and their detailed configuration options, please refer to our comprehensive provider documentation. - ### Update Provisioned Providers Keep uses a consistent process for updating provider configurations regardless of whether you use `KEEP_PROVIDERS` or `KEEP_PROVIDERS_DIRECTORY`. @@ -109,11 +108,6 @@ Keep uses a consistent process for updating provider configurations regardless o When Keep starts or restarts, it follows these steps to manage provider configurations: 1. **Read Configurations**: Loads provider definitions from either the `KEEP_PROVIDERS` environment variable or YAML files in the `KEEP_PROVIDERS_DIRECTORY`. -2. **Calculate Configuration Hash**: Generates a hash of the current configurations to detect changes. -3. **Check for Changes**: Compares the new hash with the previously stored hash (in Redis or secret manager). -4. **Update When Changed**: If configurations have changed: - - Backup the current state for potential rollback - - Delete all existing provisioned providers - - Provision new providers with their deduplication rules - - If any errors occur during provisioning, automatically rollback to the previous state -5. **Skip When Unchanged**: If configurations haven't changed since the last startup, Keep skips the re-provisioning process to improve startup performance. +2. **Create New Providers**: Installs any providers listed in the configuration that are not already present. +3. **Update When Changed**: If an existing provider's configuration has changed, Keep reapplies the configuration, including deduplication rules. If errors occur during this update, changes are automatically rolled back. +4. **Delete Providers**: Deletes any currently installed providers that are not found in the loaded configuration. From 012e14b07c30997136327b1994127fa7874221e6 Mon Sep 17 00:00:00 2001 From: tuantran0910 Date: Sun, 13 Apr 2025 01:15:14 +0700 Subject: [PATCH 18/18] fix: update KEEP_PROVIDERS key in deduplication tests for consistency --- tests/deduplication/test_deduplications.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/deduplication/test_deduplications.py b/tests/deduplication/test_deduplications.py index 6500474e7c..f95117bdf7 100644 --- a/tests/deduplication/test_deduplications.py +++ b/tests/deduplication/test_deduplications.py @@ -359,7 +359,7 @@ def test_custom_deduplication_rule_behaviour(db_session, client, test_app): [ { "AUTH_TYPE": "NOAUTH", - "KEEP_PROVIDERS": '{"keepDatadogCustomRule":{"type":"datadog","authentication":{"api_key":"1234","app_key": "1234"}}}', + "KEEP_PROVIDERS": '{"keepDatadog":{"type":"datadog","authentication":{"api_key":"1234","app_key": "1234"}}}', }, ], indirect=True, @@ -432,7 +432,7 @@ def test_custom_deduplication_rule_2(db_session, client, test_app): [ { "AUTH_TYPE": "NOAUTH", - "KEEP_PROVIDERS": '{"keepDatadogUpdateRule":{"type":"datadog","authentication":{"api_key":"1234","app_key": "1234"}}}', + "KEEP_PROVIDERS": '{"keepDatadog":{"type":"datadog","authentication":{"api_key":"1234","app_key": "1234"}}}', }, ], indirect=True, @@ -557,7 +557,7 @@ def test_update_deduplication_rule_linked_provider(db_session, client, test_app) [ { "AUTH_TYPE": "NOAUTH", - "KEEP_PROVIDERS": '{"keepDatadogDeleteRule":{"type":"datadog","authentication":{"api_key":"1234","app_key": "1234"}}}', + "KEEP_PROVIDERS": '{"keepDatadog":{"type":"datadog","authentication":{"api_key":"1234","app_key": "1234"}}}', }, ], indirect=True,