From 3d461c50be205b70159b2a280bb5db7fd98d2158 Mon Sep 17 00:00:00 2001 From: Vlada Dusek Date: Tue, 11 Mar 2025 14:45:14 +0100 Subject: [PATCH 01/22] First step - get rid of collections and unused methods --- src/crawlee/storage_clients/_base/__init__.py | 6 +- .../storage_clients/_base/_dataset_client.py | 29 +--- .../_base/_dataset_collection_client.py | 59 --------- .../_base/_key_value_store_client.py | 15 --- .../_key_value_store_collection_client.py | 59 --------- .../_base/_request_queue_client.py | 31 ++--- .../_base/_request_queue_collection_client.py | 59 --------- .../storage_clients/_base/_storage_client.py | 15 --- src/crawlee/storage_clients/_base/_types.py | 9 -- .../storage_clients/_memory/__init__.py | 6 - .../_memory/_dataset_client.py | 111 +++++----------- .../_memory/_dataset_collection_client.py | 62 --------- .../_memory/_key_value_store_client.py | 45 +------ .../_key_value_store_collection_client.py | 62 --------- .../_memory/_memory_storage_client.py | 15 --- .../_memory/_request_queue_client.py | 125 ++++++------------ .../_request_queue_collection_client.py | 62 --------- src/crawlee/storages/_creation_management.py | 18 +-- src/crawlee/storages/_dataset.py | 1 - src/crawlee/storages/_request_queue.py | 1 - 20 files changed, 96 insertions(+), 694 deletions(-) delete mode 100644 src/crawlee/storage_clients/_base/_dataset_collection_client.py delete mode 100644 src/crawlee/storage_clients/_base/_key_value_store_collection_client.py delete mode 100644 src/crawlee/storage_clients/_base/_request_queue_collection_client.py delete mode 100644 src/crawlee/storage_clients/_memory/_dataset_collection_client.py delete mode 100644 src/crawlee/storage_clients/_memory/_key_value_store_collection_client.py delete mode 100644 src/crawlee/storage_clients/_memory/_request_queue_collection_client.py diff --git a/src/crawlee/storage_clients/_base/__init__.py b/src/crawlee/storage_clients/_base/__init__.py index 5194da8768..ae8151e15f 100644 --- a/src/crawlee/storage_clients/_base/__init__.py +++ b/src/crawlee/storage_clients/_base/__init__.py @@ -1,11 +1,8 @@ from ._dataset_client import DatasetClient -from ._dataset_collection_client import DatasetCollectionClient from ._key_value_store_client import KeyValueStoreClient -from ._key_value_store_collection_client import KeyValueStoreCollectionClient from ._request_queue_client import RequestQueueClient -from ._request_queue_collection_client import RequestQueueCollectionClient from ._storage_client import StorageClient -from ._types import ResourceClient, ResourceCollectionClient +from ._types import ResourceClient __all__ = [ 'DatasetClient', @@ -15,6 +12,5 @@ 'RequestQueueClient', 'RequestQueueCollectionClient', 'ResourceClient', - 'ResourceCollectionClient', 'StorageClient', ] diff --git a/src/crawlee/storage_clients/_base/_dataset_client.py b/src/crawlee/storage_clients/_base/_dataset_client.py index d8495b2dd0..02beb0c6d5 100644 --- a/src/crawlee/storage_clients/_base/_dataset_client.py +++ b/src/crawlee/storage_clients/_base/_dataset_client.py @@ -35,24 +35,17 @@ async def get(self) -> DatasetMetadata | None: """ @abstractmethod - async def update( - self, - *, - name: str | None = None, - ) -> DatasetMetadata: - """Update the dataset metadata. + async def delete(self) -> None: + """Permanently delete the dataset managed by this client.""" - Args: - name: New new name for the dataset. + @abstractmethod + async def push_items(self, items: JsonSerializable) -> None: + """Push items to the dataset. 
- Returns: - An object reflecting the updated dataset metadata. + Args: + items: The items which to push in the dataset. They must be JSON serializable. """ - @abstractmethod - async def delete(self) -> None: - """Permanently delete the dataset managed by this client.""" - @abstractmethod async def list_items( self, @@ -221,11 +214,3 @@ async def stream_items( Yields: The dataset items in a streaming response. """ - - @abstractmethod - async def push_items(self, items: JsonSerializable) -> None: - """Push items to the dataset. - - Args: - items: The items which to push in the dataset. They must be JSON serializable. - """ diff --git a/src/crawlee/storage_clients/_base/_dataset_collection_client.py b/src/crawlee/storage_clients/_base/_dataset_collection_client.py deleted file mode 100644 index 8530655c8c..0000000000 --- a/src/crawlee/storage_clients/_base/_dataset_collection_client.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -from crawlee._utils.docs import docs_group - -if TYPE_CHECKING: - from crawlee.storage_clients.models import DatasetListPage, DatasetMetadata - - -@docs_group('Abstract classes') -class DatasetCollectionClient(ABC): - """An abstract class for dataset collection clients. - - This collection client handles operations that involve multiple instances of a given resource type. - """ - - @abstractmethod - async def get_or_create( - self, - *, - id: str | None = None, - name: str | None = None, - schema: dict | None = None, - ) -> DatasetMetadata: - """Retrieve an existing dataset by its name or ID, or create a new one if it does not exist. - - Args: - id: Optional ID of the dataset to retrieve or create. If provided, the method will attempt - to find a dataset with the ID. - name: Optional name of the dataset resource to retrieve or create. If provided, the method will - attempt to find a dataset with this name. - schema: Optional schema for the dataset resource to be created. - - Returns: - Metadata object containing the information of the retrieved or created dataset. - """ - - @abstractmethod - async def list( - self, - *, - unnamed: bool = False, - limit: int | None = None, - offset: int | None = None, - desc: bool = False, - ) -> DatasetListPage: - """List the available datasets. - - Args: - unnamed: Whether to list only the unnamed datasets. - limit: Maximum number of datasets to return. - offset: Number of datasets to skip from the beginning of the list. - desc: Whether to sort the datasets in descending order. - - Returns: - The list of available datasets matching the specified filters. - """ diff --git a/src/crawlee/storage_clients/_base/_key_value_store_client.py b/src/crawlee/storage_clients/_base/_key_value_store_client.py index 6a5d141be6..91f73993b0 100644 --- a/src/crawlee/storage_clients/_base/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_base/_key_value_store_client.py @@ -29,21 +29,6 @@ async def get(self) -> KeyValueStoreMetadata | None: An object containing the key-value store's details, or None if the key-value store does not exist. """ - @abstractmethod - async def update( - self, - *, - name: str | None = None, - ) -> KeyValueStoreMetadata: - """Update the key-value store metadata. - - Args: - name: New new name for the key-value store. - - Returns: - An object reflecting the updated key-value store metadata. 
- """ - @abstractmethod async def delete(self) -> None: """Permanently delete the key-value store managed by this client.""" diff --git a/src/crawlee/storage_clients/_base/_key_value_store_collection_client.py b/src/crawlee/storage_clients/_base/_key_value_store_collection_client.py deleted file mode 100644 index b447cf49b1..0000000000 --- a/src/crawlee/storage_clients/_base/_key_value_store_collection_client.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -from crawlee._utils.docs import docs_group - -if TYPE_CHECKING: - from crawlee.storage_clients.models import KeyValueStoreListPage, KeyValueStoreMetadata - - -@docs_group('Abstract classes') -class KeyValueStoreCollectionClient(ABC): - """An abstract class for key-value store collection clients. - - This collection client handles operations that involve multiple instances of a given resource type. - """ - - @abstractmethod - async def get_or_create( - self, - *, - id: str | None = None, - name: str | None = None, - schema: dict | None = None, - ) -> KeyValueStoreMetadata: - """Retrieve an existing key-value store by its name or ID, or create a new one if it does not exist. - - Args: - id: Optional ID of the key-value store to retrieve or create. If provided, the method will attempt - to find a key-value store with the ID. - name: Optional name of the key-value store resource to retrieve or create. If provided, the method will - attempt to find a key-value store with this name. - schema: Optional schema for the key-value store resource to be created. - - Returns: - Metadata object containing the information of the retrieved or created key-value store. - """ - - @abstractmethod - async def list( - self, - *, - unnamed: bool = False, - limit: int | None = None, - offset: int | None = None, - desc: bool = False, - ) -> KeyValueStoreListPage: - """List the available key-value stores. - - Args: - unnamed: Whether to list only the unnamed key-value stores. - limit: Maximum number of key-value stores to return. - offset: Number of key-value stores to skip from the beginning of the list. - desc: Whether to sort the key-value stores in descending order. - - Returns: - The list of available key-value stores matching the specified filters. - """ diff --git a/src/crawlee/storage_clients/_base/_request_queue_client.py b/src/crawlee/storage_clients/_base/_request_queue_client.py index 06b180801a..f43766461c 100644 --- a/src/crawlee/storage_clients/_base/_request_queue_client.py +++ b/src/crawlee/storage_clients/_base/_request_queue_client.py @@ -35,21 +35,6 @@ async def get(self) -> RequestQueueMetadata | None: An object containing the request queue's details, or None if the request queue does not exist. """ - @abstractmethod - async def update( - self, - *, - name: str | None = None, - ) -> RequestQueueMetadata: - """Update the request queue metadata. - - Args: - name: New new name for the request queue. - - Returns: - An object reflecting the updated request queue metadata. - """ - @abstractmethod async def delete(self) -> None: """Permanently delete the request queue managed by this client.""" @@ -150,6 +135,14 @@ async def delete_request(self, request_id: str) -> None: request_id: ID of the request to delete. """ + @abstractmethod + async def batch_delete_requests(self, requests: list[Request]) -> BatchRequestsOperationResponse: + """Delete given requests from the queue. + + Args: + requests: The requests to delete from the queue. 
+ """ + @abstractmethod async def prolong_request_lock( self, @@ -179,11 +172,3 @@ async def delete_request_lock( request_id: ID of the request to delete the lock. forefront: Whether to put the request in the beginning or the end of the queue after the lock is deleted. """ - - @abstractmethod - async def batch_delete_requests(self, requests: list[Request]) -> BatchRequestsOperationResponse: - """Delete given requests from the queue. - - Args: - requests: The requests to delete from the queue. - """ diff --git a/src/crawlee/storage_clients/_base/_request_queue_collection_client.py b/src/crawlee/storage_clients/_base/_request_queue_collection_client.py deleted file mode 100644 index 7de876c344..0000000000 --- a/src/crawlee/storage_clients/_base/_request_queue_collection_client.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -from crawlee._utils.docs import docs_group - -if TYPE_CHECKING: - from crawlee.storage_clients.models import RequestQueueListPage, RequestQueueMetadata - - -@docs_group('Abstract classes') -class RequestQueueCollectionClient(ABC): - """An abstract class for request queue collection clients. - - This collection client handles operations that involve multiple instances of a given resource type. - """ - - @abstractmethod - async def get_or_create( - self, - *, - id: str | None = None, - name: str | None = None, - schema: dict | None = None, - ) -> RequestQueueMetadata: - """Retrieve an existing request queue by its name or ID, or create a new one if it does not exist. - - Args: - id: Optional ID of the request queue to retrieve or create. If provided, the method will attempt - to find a request queue with the ID. - name: Optional name of the request queue resource to retrieve or create. If provided, the method will - attempt to find a request queue with this name. - schema: Optional schema for the request queue resource to be created. - - Returns: - Metadata object containing the information of the retrieved or created request queue. - """ - - @abstractmethod - async def list( - self, - *, - unnamed: bool = False, - limit: int | None = None, - offset: int | None = None, - desc: bool = False, - ) -> RequestQueueListPage: - """List the available request queues. - - Args: - unnamed: Whether to list only the unnamed request queues. - limit: Maximum number of request queues to return. - offset: Number of request queues to skip from the beginning of the list. - desc: Whether to sort the request queues in descending order. - - Returns: - The list of available request queues matching the specified filters. 
- """ diff --git a/src/crawlee/storage_clients/_base/_storage_client.py b/src/crawlee/storage_clients/_base/_storage_client.py index 4f022cf30a..de5d229443 100644 --- a/src/crawlee/storage_clients/_base/_storage_client.py +++ b/src/crawlee/storage_clients/_base/_storage_client.py @@ -9,11 +9,8 @@ if TYPE_CHECKING: from ._dataset_client import DatasetClient - from ._dataset_collection_client import DatasetCollectionClient from ._key_value_store_client import KeyValueStoreClient - from ._key_value_store_collection_client import KeyValueStoreCollectionClient from ._request_queue_client import RequestQueueClient - from ._request_queue_collection_client import RequestQueueCollectionClient @docs_group('Abstract classes') @@ -28,26 +25,14 @@ class StorageClient(ABC): def dataset(self, id: str) -> DatasetClient: """Get a subclient for a specific dataset by its ID.""" - @abstractmethod - def datasets(self) -> DatasetCollectionClient: - """Get a subclient for dataset collection operations.""" - @abstractmethod def key_value_store(self, id: str) -> KeyValueStoreClient: """Get a subclient for a specific key-value store by its ID.""" - @abstractmethod - def key_value_stores(self) -> KeyValueStoreCollectionClient: - """Get a subclient for key-value store collection operations.""" - @abstractmethod def request_queue(self, id: str) -> RequestQueueClient: """Get a subclient for a specific request queue by its ID.""" - @abstractmethod - def request_queues(self) -> RequestQueueCollectionClient: - """Get a subclient for request queue collection operations.""" - @abstractmethod async def purge_on_start(self) -> None: """Perform a purge of the default storages. diff --git a/src/crawlee/storage_clients/_base/_types.py b/src/crawlee/storage_clients/_base/_types.py index a5cf1325f5..f644fe5410 100644 --- a/src/crawlee/storage_clients/_base/_types.py +++ b/src/crawlee/storage_clients/_base/_types.py @@ -3,20 +3,11 @@ from typing import Union from ._dataset_client import DatasetClient -from ._dataset_collection_client import DatasetCollectionClient from ._key_value_store_client import KeyValueStoreClient -from ._key_value_store_collection_client import KeyValueStoreCollectionClient from ._request_queue_client import RequestQueueClient -from ._request_queue_collection_client import RequestQueueCollectionClient ResourceClient = Union[ DatasetClient, KeyValueStoreClient, RequestQueueClient, ] - -ResourceCollectionClient = Union[ - DatasetCollectionClient, - KeyValueStoreCollectionClient, - RequestQueueCollectionClient, -] diff --git a/src/crawlee/storage_clients/_memory/__init__.py b/src/crawlee/storage_clients/_memory/__init__.py index 09912e124d..355797673d 100644 --- a/src/crawlee/storage_clients/_memory/__init__.py +++ b/src/crawlee/storage_clients/_memory/__init__.py @@ -1,17 +1,11 @@ from ._dataset_client import DatasetClient -from ._dataset_collection_client import DatasetCollectionClient from ._key_value_store_client import KeyValueStoreClient -from ._key_value_store_collection_client import KeyValueStoreCollectionClient from ._memory_storage_client import MemoryStorageClient from ._request_queue_client import RequestQueueClient -from ._request_queue_collection_client import RequestQueueCollectionClient __all__ = [ 'DatasetClient', - 'DatasetCollectionClient', 'KeyValueStoreClient', - 'KeyValueStoreCollectionClient', 'MemoryStorageClient', 'RequestQueueClient', - 'RequestQueueCollectionClient', ] diff --git a/src/crawlee/storage_clients/_memory/_dataset_client.py 
b/src/crawlee/storage_clients/_memory/_dataset_client.py index 50c8c7c8d4..40ad5c2b13 100644 --- a/src/crawlee/storage_clients/_memory/_dataset_client.py +++ b/src/crawlee/storage_clients/_memory/_dataset_client.py @@ -12,8 +12,8 @@ from crawlee._types import StorageTypes from crawlee._utils.crypto import crypto_random_object_id -from crawlee._utils.data_processing import raise_on_duplicate_storage, raise_on_non_existing_storage -from crawlee._utils.file import force_rename, json_dumps +from crawlee._utils.data_processing import raise_on_non_existing_storage +from crawlee._utils.file import json_dumps from crawlee.storage_clients._base import DatasetClient as BaseDatasetClient from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata @@ -96,7 +96,25 @@ async def get(self) -> DatasetMetadata | None: return None @override - async def update(self, *, name: str | None = None) -> DatasetMetadata: + async def delete(self) -> None: + dataset = next( + (dataset for dataset in self._memory_storage_client.datasets_handled if dataset.id == self.id), None + ) + + if dataset is not None: + async with dataset.file_operation_lock: + self._memory_storage_client.datasets_handled.remove(dataset) + dataset.item_count = 0 + dataset.dataset_entries.clear() + + if os.path.exists(dataset.resource_directory): + await asyncio.to_thread(shutil.rmtree, dataset.resource_directory) + + @override + async def push_items( + self, + items: JsonSerializable, + ) -> None: # Check by id existing_dataset_by_id = find_or_create_client_by_id_or_name_inner( resource_client_class=DatasetClient, @@ -108,48 +126,26 @@ async def update(self, *, name: str | None = None) -> DatasetMetadata: if existing_dataset_by_id is None: raise_on_non_existing_storage(StorageTypes.DATASET, self.id) - # Skip if no changes - if name is None: - return existing_dataset_by_id.resource_info - - async with existing_dataset_by_id.file_operation_lock: - # Check that name is not in use already - existing_dataset_by_name = next( - ( - dataset - for dataset in self._memory_storage_client.datasets_handled - if dataset.name and dataset.name.lower() == name.lower() - ), - None, - ) + normalized = self._normalize_items(items) - if existing_dataset_by_name is not None: - raise_on_duplicate_storage(StorageTypes.DATASET, 'name', name) + added_ids: list[str] = [] + for entry in normalized: + existing_dataset_by_id.item_count += 1 + idx = self._generate_local_entry_name(existing_dataset_by_id.item_count) - previous_dir = existing_dataset_by_id.resource_directory - existing_dataset_by_id.name = name + existing_dataset_by_id.dataset_entries[idx] = entry + added_ids.append(idx) - await force_rename(previous_dir, existing_dataset_by_id.resource_directory) + data_entries = [(id, existing_dataset_by_id.dataset_entries[id]) for id in added_ids] - # Update timestamps + async with existing_dataset_by_id.file_operation_lock: await existing_dataset_by_id.update_timestamps(has_been_modified=True) - return existing_dataset_by_id.resource_info - - @override - async def delete(self) -> None: - dataset = next( - (dataset for dataset in self._memory_storage_client.datasets_handled if dataset.id == self.id), None - ) - - if dataset is not None: - async with dataset.file_operation_lock: - self._memory_storage_client.datasets_handled.remove(dataset) - dataset.item_count = 0 - dataset.dataset_entries.clear() - - if os.path.exists(dataset.resource_directory): - await asyncio.to_thread(shutil.rmtree, dataset.resource_directory) + await 
self._persist_dataset_items_to_disk( + data=data_entries, + entity_directory=existing_dataset_by_id.resource_directory, + persist_storage=self._memory_storage_client.persist_storage, + ) @override async def list_items( @@ -288,43 +284,6 @@ async def stream_items( ) -> AbstractAsyncContextManager[Response | None]: raise NotImplementedError('This method is not supported in memory storage.') - @override - async def push_items( - self, - items: JsonSerializable, - ) -> None: - # Check by id - existing_dataset_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=DatasetClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_dataset_by_id is None: - raise_on_non_existing_storage(StorageTypes.DATASET, self.id) - - normalized = self._normalize_items(items) - - added_ids: list[str] = [] - for entry in normalized: - existing_dataset_by_id.item_count += 1 - idx = self._generate_local_entry_name(existing_dataset_by_id.item_count) - - existing_dataset_by_id.dataset_entries[idx] = entry - added_ids.append(idx) - - data_entries = [(id, existing_dataset_by_id.dataset_entries[id]) for id in added_ids] - - async with existing_dataset_by_id.file_operation_lock: - await existing_dataset_by_id.update_timestamps(has_been_modified=True) - - await self._persist_dataset_items_to_disk( - data=data_entries, - entity_directory=existing_dataset_by_id.resource_directory, - persist_storage=self._memory_storage_client.persist_storage, - ) - async def _persist_dataset_items_to_disk( self, *, diff --git a/src/crawlee/storage_clients/_memory/_dataset_collection_client.py b/src/crawlee/storage_clients/_memory/_dataset_collection_client.py deleted file mode 100644 index 9e32b4086b..0000000000 --- a/src/crawlee/storage_clients/_memory/_dataset_collection_client.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from typing_extensions import override - -from crawlee.storage_clients._base import DatasetCollectionClient as BaseDatasetCollectionClient -from crawlee.storage_clients.models import DatasetListPage, DatasetMetadata - -from ._creation_management import get_or_create_inner -from ._dataset_client import DatasetClient - -if TYPE_CHECKING: - from ._memory_storage_client import MemoryStorageClient - - -class DatasetCollectionClient(BaseDatasetCollectionClient): - """Subclient for manipulating datasets.""" - - def __init__(self, *, memory_storage_client: MemoryStorageClient) -> None: - self._memory_storage_client = memory_storage_client - - @property - def _storage_client_cache(self) -> list[DatasetClient]: - return self._memory_storage_client.datasets_handled - - @override - async def get_or_create( - self, - *, - name: str | None = None, - schema: dict | None = None, - id: str | None = None, - ) -> DatasetMetadata: - resource_client = await get_or_create_inner( - memory_storage_client=self._memory_storage_client, - storage_client_cache=self._storage_client_cache, - resource_client_class=DatasetClient, - name=name, - id=id, - ) - return resource_client.resource_info - - @override - async def list( - self, - *, - unnamed: bool = False, - limit: int | None = None, - offset: int | None = None, - desc: bool = False, - ) -> DatasetListPage: - items = [storage.resource_info for storage in self._storage_client_cache] - - return DatasetListPage( - total=len(items), - count=len(items), - offset=0, - limit=len(items), - desc=False, - items=sorted(items, key=lambda item: item.created_at), - ) diff 
--git a/src/crawlee/storage_clients/_memory/_key_value_store_client.py b/src/crawlee/storage_clients/_memory/_key_value_store_client.py index ab9def0f06..e7f18fb175 100644 --- a/src/crawlee/storage_clients/_memory/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_memory/_key_value_store_client.py @@ -12,8 +12,8 @@ from crawlee._types import StorageTypes from crawlee._utils.crypto import crypto_random_object_id -from crawlee._utils.data_processing import maybe_parse_body, raise_on_duplicate_storage, raise_on_non_existing_storage -from crawlee._utils.file import determine_file_extension, force_remove, force_rename, is_file_or_bytes, json_dumps +from crawlee._utils.data_processing import maybe_parse_body, raise_on_non_existing_storage +from crawlee._utils.file import determine_file_extension, force_remove, is_file_or_bytes, json_dumps from crawlee.storage_clients._base import KeyValueStoreClient as BaseKeyValueStoreClient from crawlee.storage_clients.models import ( KeyValueStoreKeyInfo, @@ -92,47 +92,6 @@ async def get(self) -> KeyValueStoreMetadata | None: return None - @override - async def update(self, *, name: str | None = None) -> KeyValueStoreMetadata: - # Check by id - existing_store_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=KeyValueStoreClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_store_by_id is None: - raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self.id) - - # Skip if no changes - if name is None: - return existing_store_by_id.resource_info - - async with existing_store_by_id.file_operation_lock: - # Check that name is not in use already - existing_store_by_name = next( - ( - store - for store in self._memory_storage_client.key_value_stores_handled - if store.name and store.name.lower() == name.lower() - ), - None, - ) - - if existing_store_by_name is not None: - raise_on_duplicate_storage(StorageTypes.KEY_VALUE_STORE, 'name', name) - - previous_dir = existing_store_by_id.resource_directory - existing_store_by_id.name = name - - await force_rename(previous_dir, existing_store_by_id.resource_directory) - - # Update timestamps - await existing_store_by_id.update_timestamps(has_been_modified=True) - - return existing_store_by_id.resource_info - @override async def delete(self) -> None: store = next( diff --git a/src/crawlee/storage_clients/_memory/_key_value_store_collection_client.py b/src/crawlee/storage_clients/_memory/_key_value_store_collection_client.py deleted file mode 100644 index 939780449f..0000000000 --- a/src/crawlee/storage_clients/_memory/_key_value_store_collection_client.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from typing_extensions import override - -from crawlee.storage_clients._base import KeyValueStoreCollectionClient as BaseKeyValueStoreCollectionClient -from crawlee.storage_clients.models import KeyValueStoreListPage, KeyValueStoreMetadata - -from ._creation_management import get_or_create_inner -from ._key_value_store_client import KeyValueStoreClient - -if TYPE_CHECKING: - from ._memory_storage_client import MemoryStorageClient - - -class KeyValueStoreCollectionClient(BaseKeyValueStoreCollectionClient): - """Subclient for manipulating key-value stores.""" - - def __init__(self, *, memory_storage_client: MemoryStorageClient) -> None: - self._memory_storage_client = memory_storage_client - - @property - def _storage_client_cache(self) -> list[KeyValueStoreClient]: - 
return self._memory_storage_client.key_value_stores_handled - - @override - async def get_or_create( - self, - *, - name: str | None = None, - schema: dict | None = None, - id: str | None = None, - ) -> KeyValueStoreMetadata: - resource_client = await get_or_create_inner( - memory_storage_client=self._memory_storage_client, - storage_client_cache=self._storage_client_cache, - resource_client_class=KeyValueStoreClient, - name=name, - id=id, - ) - return resource_client.resource_info - - @override - async def list( - self, - *, - unnamed: bool = False, - limit: int | None = None, - offset: int | None = None, - desc: bool = False, - ) -> KeyValueStoreListPage: - items = [storage.resource_info for storage in self._storage_client_cache] - - return KeyValueStoreListPage( - total=len(items), - count=len(items), - offset=0, - limit=len(items), - desc=False, - items=sorted(items, key=lambda item: item.created_at), - ) diff --git a/src/crawlee/storage_clients/_memory/_memory_storage_client.py b/src/crawlee/storage_clients/_memory/_memory_storage_client.py index 8000f41274..ec60f51145 100644 --- a/src/crawlee/storage_clients/_memory/_memory_storage_client.py +++ b/src/crawlee/storage_clients/_memory/_memory_storage_client.py @@ -15,11 +15,8 @@ from crawlee.storage_clients import StorageClient from ._dataset_client import DatasetClient -from ._dataset_collection_client import DatasetCollectionClient from ._key_value_store_client import KeyValueStoreClient -from ._key_value_store_collection_client import KeyValueStoreCollectionClient from ._request_queue_client import RequestQueueClient -from ._request_queue_collection_client import RequestQueueCollectionClient if TYPE_CHECKING: from crawlee.storage_clients._base import ResourceClient @@ -146,26 +143,14 @@ def request_queues_directory(self) -> str: def dataset(self, id: str) -> DatasetClient: return DatasetClient(memory_storage_client=self, id=id) - @override - def datasets(self) -> DatasetCollectionClient: - return DatasetCollectionClient(memory_storage_client=self) - @override def key_value_store(self, id: str) -> KeyValueStoreClient: return KeyValueStoreClient(memory_storage_client=self, id=id) - @override - def key_value_stores(self) -> KeyValueStoreCollectionClient: - return KeyValueStoreCollectionClient(memory_storage_client=self) - @override def request_queue(self, id: str) -> RequestQueueClient: return RequestQueueClient(memory_storage_client=self, id=id) - @override - def request_queues(self) -> RequestQueueCollectionClient: - return RequestQueueCollectionClient(memory_storage_client=self) - @override async def purge_on_start(self) -> None: # Optimistic, non-blocking check diff --git a/src/crawlee/storage_clients/_memory/_request_queue_client.py b/src/crawlee/storage_clients/_memory/_request_queue_client.py index 0031e54abd..687260d91d 100644 --- a/src/crawlee/storage_clients/_memory/_request_queue_client.py +++ b/src/crawlee/storage_clients/_memory/_request_queue_client.py @@ -13,8 +13,8 @@ from crawlee._types import StorageTypes from crawlee._utils.crypto import crypto_random_object_id -from crawlee._utils.data_processing import raise_on_duplicate_storage, raise_on_non_existing_storage -from crawlee._utils.file import force_remove, force_rename, json_dumps +from crawlee._utils.data_processing import raise_on_non_existing_storage +from crawlee._utils.file import force_remove, json_dumps from crawlee._utils.requests import unique_key_to_request_id from crawlee.storage_clients._base import RequestQueueClient as BaseRequestQueueClient from 
crawlee.storage_clients.models import ( @@ -113,47 +113,6 @@ async def get(self) -> RequestQueueMetadata | None: return None - @override - async def update(self, *, name: str | None = None) -> RequestQueueMetadata: - # Check by id - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) - - # Skip if no changes - if name is None: - return existing_queue_by_id.resource_info - - async with existing_queue_by_id.file_operation_lock: - # Check that name is not in use already - existing_queue_by_name = next( - ( - queue - for queue in self._memory_storage_client.request_queues_handled - if queue.name and queue.name.lower() == name.lower() - ), - None, - ) - - if existing_queue_by_name is not None: - raise_on_duplicate_storage(StorageTypes.REQUEST_QUEUE, 'name', name) - - previous_dir = existing_queue_by_id.resource_directory - existing_queue_by_id.name = name - - await force_rename(previous_dir, existing_queue_by_id.resource_directory) - - # Update timestamps - await existing_queue_by_id.update_timestamps(has_been_modified=True) - - return existing_queue_by_id.resource_info - @override async def delete(self) -> None: queue = next( @@ -292,6 +251,42 @@ async def add_request( was_already_handled=False, ) + @override + async def batch_add_requests( + self, + requests: Sequence[Request], + *, + forefront: bool = False, + ) -> BatchRequestsOperationResponse: + processed_requests = list[ProcessedRequest]() + unprocessed_requests = list[UnprocessedRequest]() + + for request in requests: + try: + processed_request = await self.add_request(request, forefront=forefront) + processed_requests.append( + ProcessedRequest( + id=processed_request.id, + unique_key=processed_request.unique_key, + was_already_present=processed_request.was_already_present, + was_already_handled=processed_request.was_already_handled, + ) + ) + except Exception as exc: # noqa: PERF203 + logger.warning(f'Error adding request to the queue: {exc}') + unprocessed_requests.append( + UnprocessedRequest( + unique_key=request.unique_key, + url=request.url, + method=request.method, + ) + ) + + return BatchRequestsOperationResponse( + processed_requests=processed_requests, + unprocessed_requests=unprocessed_requests, + ) + @override async def get_request(self, request_id: str) -> Request | None: existing_queue_by_id = find_or_create_client_by_id_or_name_inner( @@ -397,6 +392,10 @@ async def delete_request(self, request_id: str) -> None: request_id=request_id, ) + @override + async def batch_delete_requests(self, requests: list[Request]) -> BatchRequestsOperationResponse: + raise NotImplementedError('This method is not supported in memory storage.') + @override async def prolong_request_lock( self, @@ -426,46 +425,6 @@ async def delete_request_lock( existing_queue_by_id._in_progress.discard(request_id) # noqa: SLF001 - @override - async def batch_add_requests( - self, - requests: Sequence[Request], - *, - forefront: bool = False, - ) -> BatchRequestsOperationResponse: - processed_requests = list[ProcessedRequest]() - unprocessed_requests = list[UnprocessedRequest]() - - for request in requests: - try: - processed_request = await self.add_request(request, forefront=forefront) - processed_requests.append( - ProcessedRequest( - id=processed_request.id, - unique_key=processed_request.unique_key, - 
was_already_present=processed_request.was_already_present, - was_already_handled=processed_request.was_already_handled, - ) - ) - except Exception as exc: # noqa: PERF203 - logger.warning(f'Error adding request to the queue: {exc}') - unprocessed_requests.append( - UnprocessedRequest( - unique_key=request.unique_key, - url=request.url, - method=request.method, - ) - ) - - return BatchRequestsOperationResponse( - processed_requests=processed_requests, - unprocessed_requests=unprocessed_requests, - ) - - @override - async def batch_delete_requests(self, requests: list[Request]) -> BatchRequestsOperationResponse: - raise NotImplementedError('This method is not supported in memory storage.') - async def update_timestamps(self, *, has_been_modified: bool) -> None: """Update the timestamps of the request queue.""" self._accessed_at = datetime.now(timezone.utc) diff --git a/src/crawlee/storage_clients/_memory/_request_queue_collection_client.py b/src/crawlee/storage_clients/_memory/_request_queue_collection_client.py deleted file mode 100644 index 2f2df2be89..0000000000 --- a/src/crawlee/storage_clients/_memory/_request_queue_collection_client.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from typing_extensions import override - -from crawlee.storage_clients._base import RequestQueueCollectionClient as BaseRequestQueueCollectionClient -from crawlee.storage_clients.models import RequestQueueListPage, RequestQueueMetadata - -from ._creation_management import get_or_create_inner -from ._request_queue_client import RequestQueueClient - -if TYPE_CHECKING: - from ._memory_storage_client import MemoryStorageClient - - -class RequestQueueCollectionClient(BaseRequestQueueCollectionClient): - """Subclient for manipulating request queues.""" - - def __init__(self, *, memory_storage_client: MemoryStorageClient) -> None: - self._memory_storage_client = memory_storage_client - - @property - def _storage_client_cache(self) -> list[RequestQueueClient]: - return self._memory_storage_client.request_queues_handled - - @override - async def get_or_create( - self, - *, - name: str | None = None, - schema: dict | None = None, - id: str | None = None, - ) -> RequestQueueMetadata: - resource_client = await get_or_create_inner( - memory_storage_client=self._memory_storage_client, - storage_client_cache=self._storage_client_cache, - resource_client_class=RequestQueueClient, - name=name, - id=id, - ) - return resource_client.resource_info - - @override - async def list( - self, - *, - unnamed: bool = False, - limit: int | None = None, - offset: int | None = None, - desc: bool = False, - ) -> RequestQueueListPage: - items = [storage.resource_info for storage in self._storage_client_cache] - - return RequestQueueListPage( - total=len(items), - count=len(items), - offset=0, - limit=len(items), - desc=False, - items=sorted(items, key=lambda item: item.created_at), - ) diff --git a/src/crawlee/storages/_creation_management.py b/src/crawlee/storages/_creation_management.py index d7356a98b5..0ba1f0739e 100644 --- a/src/crawlee/storages/_creation_management.py +++ b/src/crawlee/storages/_creation_management.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: from crawlee.configuration import Configuration - from crawlee.storage_clients._base import ResourceClient, ResourceCollectionClient, StorageClient + from crawlee.storage_clients._base import ResourceClient, StorageClient TResource = TypeVar('TResource', Dataset, KeyValueStore, RequestQueue) @@ -208,19 +208,3 @@ def 
_get_resource_client( return storage_client.request_queue(id) raise ValueError(f'Unknown storage class label: {storage_class.__name__}') - - -def _get_resource_collection_client( - storage_class: type, - storage_client: StorageClient, -) -> ResourceCollectionClient: - if issubclass(storage_class, Dataset): - return storage_client.datasets() - - if issubclass(storage_class, KeyValueStore): - return storage_client.key_value_stores() - - if issubclass(storage_class, RequestQueue): - return storage_client.request_queues() - - raise ValueError(f'Unknown storage class: {storage_class.__name__}') diff --git a/src/crawlee/storages/_dataset.py b/src/crawlee/storages/_dataset.py index 7cb58ae817..c19c28d58a 100644 --- a/src/crawlee/storages/_dataset.py +++ b/src/crawlee/storages/_dataset.py @@ -205,7 +205,6 @@ def __init__(self, id: str, name: str | None, storage_client: StorageClient) -> # Get resource clients from the storage client. self._resource_client = storage_client.dataset(self._id) - self._resource_collection_client = storage_client.datasets() @classmethod def from_storage_object(cls, storage_client: StorageClient, storage_object: StorageMetadata) -> Dataset: diff --git a/src/crawlee/storages/_request_queue.py b/src/crawlee/storages/_request_queue.py index b3274ccc81..6ef58c047a 100644 --- a/src/crawlee/storages/_request_queue.py +++ b/src/crawlee/storages/_request_queue.py @@ -96,7 +96,6 @@ def __init__( # Get resource clients from storage client self._resource_client = storage_client.request_queue(self._id) - self._resource_collection_client = storage_client.request_queues() self._request_lock_time = timedelta(minutes=3) self._queue_paused_for_migration = False From c0810256d9154ff504c6cfdbb57d1872aaade6dd Mon Sep 17 00:00:00 2001 From: Vlada Dusek Date: Fri, 14 Mar 2025 17:04:23 +0100 Subject: [PATCH 02/22] Update of dataset and its clients --- pyproject.toml | 1 - src/crawlee/_service_locator.py | 8 +- src/crawlee/_types.py | 4 - .../_browserforge_adapter.py | 6 +- src/crawlee/storage_clients/__init__.py | 9 +- src/crawlee/storage_clients/_base/__init__.py | 5 - .../storage_clients/_base/_dataset_client.py | 214 +++---- .../storage_clients/_base/_storage_client.py | 43 +- src/crawlee/storage_clients/_base/_types.py | 13 - .../storage_clients/_file_system/__init__.py | 3 + .../_file_system/_dataset_client.py | 396 +++++++++++++ .../_file_system/_key_value_store.py | 11 + .../_file_system/_request_queue.py | 11 + .../_file_system/_storage_client.py | 13 + .../storage_clients/_file_system/_utils.py | 23 + .../storage_clients/_file_system/py.typed | 0 .../storage_clients/_memory/__init__.py | 12 +- .../_memory/_creation_management.py | 429 -------------- .../_memory/_dataset_client.py | 420 +++++--------- .../_memory/_key_value_store_client.py | 379 +----------- .../_memory/_memory_storage_client.py | 343 ----------- .../_memory/_request_queue_client.py | 512 +---------------- .../_memory/_storage_client.py | 13 + src/crawlee/storages/_creation_management.py | 8 +- src/crawlee/storages/_dataset.py | 540 +++++++----------- src/crawlee/storages/_request_queue.py | 11 +- src/crawlee/storages/_types.py | 159 ++++++ .../test_adaptive_playwright_crawler.py | 2 +- .../_memory/test_memory_storage_client.py | 2 +- tests/unit/storages/test_dataset.py | 2 +- 30 files changed, 1052 insertions(+), 2540 deletions(-) delete mode 100644 src/crawlee/storage_clients/_base/_types.py create mode 100644 src/crawlee/storage_clients/_file_system/__init__.py create mode 100644 
src/crawlee/storage_clients/_file_system/_dataset_client.py create mode 100644 src/crawlee/storage_clients/_file_system/_key_value_store.py create mode 100644 src/crawlee/storage_clients/_file_system/_request_queue.py create mode 100644 src/crawlee/storage_clients/_file_system/_storage_client.py create mode 100644 src/crawlee/storage_clients/_file_system/_utils.py create mode 100644 src/crawlee/storage_clients/_file_system/py.typed delete mode 100644 src/crawlee/storage_clients/_memory/_creation_management.py delete mode 100644 src/crawlee/storage_clients/_memory/_memory_storage_client.py create mode 100644 src/crawlee/storage_clients/_memory/_storage_client.py create mode 100644 src/crawlee/storages/_types.py diff --git a/pyproject.toml b/pyproject.toml index bf857af4f7..ece3fe7956 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -143,7 +143,6 @@ ignore = [ "PLR0911", # Too many return statements "PLR0913", # Too many arguments in function definition "PLR0915", # Too many statements - "PTH", # flake8-use-pathlib "PYI034", # `__aenter__` methods in classes like `{name}` usually return `self` at runtime "PYI036", # The second argument in `__aexit__` should be annotated with `object` or `BaseException | None` "S102", # Use of `exec` detected diff --git a/src/crawlee/_service_locator.py b/src/crawlee/_service_locator.py index 31bc36c63c..08b25c35cd 100644 --- a/src/crawlee/_service_locator.py +++ b/src/crawlee/_service_locator.py @@ -77,13 +77,9 @@ def set_event_manager(self, event_manager: EventManager) -> None: def get_storage_client(self) -> StorageClient: """Get the storage client.""" if self._storage_client is None: - from crawlee.storage_clients import MemoryStorageClient + from crawlee.storage_clients import file_system_storage_client - self._storage_client = ( - MemoryStorageClient.from_config(config=self._configuration) - if self._configuration - else MemoryStorageClient.from_config() - ) + self._storage_client = file_system_storage_client self._storage_client_was_retrieved = True return self._storage_client diff --git a/src/crawlee/_types.py b/src/crawlee/_types.py index c68ae63df9..9b6cb0f2e7 100644 --- a/src/crawlee/_types.py +++ b/src/crawlee/_types.py @@ -275,10 +275,6 @@ async def push_data( **kwargs: Unpack[PushDataKwargs], ) -> None: """Track a call to the `push_data` context helper.""" - from crawlee.storages._dataset import Dataset - - await Dataset.check_and_serialize(data) - self.push_data_calls.append( PushDataFunctionCall( data=data, diff --git a/src/crawlee/fingerprint_suite/_browserforge_adapter.py b/src/crawlee/fingerprint_suite/_browserforge_adapter.py index d64ddd59f0..11f9f82d79 100644 --- a/src/crawlee/fingerprint_suite/_browserforge_adapter.py +++ b/src/crawlee/fingerprint_suite/_browserforge_adapter.py @@ -1,10 +1,10 @@ from __future__ import annotations -import os.path from collections.abc import Iterable from copy import deepcopy from functools import reduce from operator import or_ +from pathlib import Path from typing import TYPE_CHECKING, Any, Literal from browserforge.bayesian_network import extract_json @@ -253,9 +253,9 @@ def generate(self, browser_type: SupportedBrowserType = 'chromium') -> dict[str, def get_available_header_network() -> dict: """Get header network that contains possible header values.""" - if os.path.isfile(DATA_DIR / 'header-network.zip'): + if Path(DATA_DIR / 'header-network.zip').is_file(): return extract_json(DATA_DIR / 'header-network.zip') - if os.path.isfile(DATA_DIR / 'header-network-definition.zip'): + if Path(DATA_DIR / 
'header-network-definition.zip').is_file(): return extract_json(DATA_DIR / 'header-network-definition.zip') raise FileNotFoundError('Missing header-network file.') diff --git a/src/crawlee/storage_clients/__init__.py b/src/crawlee/storage_clients/__init__.py index 66d352d7a7..848c160c37 100644 --- a/src/crawlee/storage_clients/__init__.py +++ b/src/crawlee/storage_clients/__init__.py @@ -1,4 +1,9 @@ from ._base import StorageClient -from ._memory import MemoryStorageClient +from ._file_system import file_system_storage_client +from ._memory import memory_storage_client -__all__ = ['MemoryStorageClient', 'StorageClient'] +__all__ = [ + 'StorageClient', + 'file_system_storage_client', + 'memory_storage_client' +] diff --git a/src/crawlee/storage_clients/_base/__init__.py b/src/crawlee/storage_clients/_base/__init__.py index ae8151e15f..73298560da 100644 --- a/src/crawlee/storage_clients/_base/__init__.py +++ b/src/crawlee/storage_clients/_base/__init__.py @@ -2,15 +2,10 @@ from ._key_value_store_client import KeyValueStoreClient from ._request_queue_client import RequestQueueClient from ._storage_client import StorageClient -from ._types import ResourceClient __all__ = [ 'DatasetClient', - 'DatasetCollectionClient', 'KeyValueStoreClient', - 'KeyValueStoreCollectionClient', 'RequestQueueClient', - 'RequestQueueCollectionClient', - 'ResourceClient', 'StorageClient', ] diff --git a/src/crawlee/storage_clients/_base/_dataset_client.py b/src/crawlee/storage_clients/_base/_dataset_client.py index 02beb0c6d5..d68bb35c84 100644 --- a/src/crawlee/storage_clients/_base/_dataset_client.py +++ b/src/crawlee/storage_clients/_base/_dataset_client.py @@ -7,12 +7,11 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator - from contextlib import AbstractAsyncContextManager + from datetime import datetime + from pathlib import Path + from typing import Any - from httpx import Response - - from crawlee._types import JsonSerializable - from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata + from crawlee.storage_clients.models import DatasetItemsListPage @docs_group('Abstract classes') @@ -23,35 +22,78 @@ class DatasetClient(ABC): client, like a memory storage client. """ - _LIST_ITEMS_LIMIT = 999_999_999_999 - """This is what API returns in the x-apify-pagination-limit header when no limit query parameter is used.""" + @property + @abstractmethod + def id(self) -> str: + """The ID of the dataset.""" + + @property + @abstractmethod + def name(self) -> str | None: + """The name of the dataset.""" + + @property + @abstractmethod + def created_at(self) -> datetime: + """The time at which the dataset was created.""" + + @property + @abstractmethod + def accessed_at(self) -> datetime: + """The time at which the dataset was last accessed.""" + @property @abstractmethod - async def get(self) -> DatasetMetadata | None: - """Get metadata about the dataset being managed by this client. + def modified_at(self) -> datetime: + """The time at which the dataset was last modified.""" + + @property + @abstractmethod + def item_count(self) -> int: + """The number of items in the dataset.""" + + @classmethod + @abstractmethod + async def open( + cls, + id: str | None, + name: str | None, + storage_dir: Path, + ) -> DatasetClient: + """Open existing or create a new dataset client. + + If a dataset with the given name already exists, the appropriate dataset client is returned. + Otherwise, a new dataset is created and client for it is returned. + + Args: + id: The ID of the dataset. 
+ name: The name of the dataset. + storage_dir: The path to the storage directory. If the client persists data, it should use this directory. Returns: - An object containing the dataset's details, or None if the dataset does not exist. + A dataset client. """ @abstractmethod - async def delete(self) -> None: - """Permanently delete the dataset managed by this client.""" + async def drop(self) -> None: + """Drop the whole dataset and remove all its items. + + The backend method for the `Dataset.drop` call. + """ @abstractmethod - async def push_items(self, items: JsonSerializable) -> None: - """Push items to the dataset. + async def push_data(self, data: list[Any] | dict[str, Any]) -> None: + """Push data to the dataset. - Args: - items: The items which to push in the dataset. They must be JSON serializable. + The backend method for the `Dataset.push_data` call. """ @abstractmethod - async def list_items( + async def get_data( self, *, - offset: int | None = 0, - limit: int | None = _LIST_ITEMS_LIMIT, + offset: int = 0, + limit: int | None = 999_999_999_999, clean: bool = False, desc: bool = False, fields: list[str] | None = None, @@ -62,31 +104,13 @@ async def list_items( flatten: list[str] | None = None, view: str | None = None, ) -> DatasetItemsListPage: - """Retrieve a paginated list of items from a dataset based on various filtering parameters. - - This method provides the flexibility to filter, sort, and modify the appearance of dataset items - when listed. Each parameter modifies the result set according to its purpose. The method also - supports pagination through 'offset' and 'limit' parameters. - - Args: - offset: The number of initial items to skip. - limit: The maximum number of items to return. - clean: If True, removes empty items and hidden fields, equivalent to 'skip_hidden' and 'skip_empty'. - desc: If True, items are returned in descending order, i.e., newest first. - fields: Specifies a subset of fields to include in each item. - omit: Specifies a subset of fields to exclude from each item. - unwind: Specifies a field that should be unwound. If it's an array, each element becomes a separate record. - skip_empty: If True, omits items that are empty after other filters have been applied. - skip_hidden: If True, omits fields starting with the '#' character. - flatten: A list of fields to flatten in each item. - view: The specific view of the dataset to use when retrieving items. + """Get data from the dataset. - Returns: - An object with filtered, sorted, and paginated dataset items plus pagination details. + The backend method for the `Dataset.get_data` call. """ @abstractmethod - async def iterate_items( + async def iterate( self, *, offset: int = 0, @@ -99,118 +123,12 @@ async def iterate_items( skip_empty: bool = False, skip_hidden: bool = False, ) -> AsyncIterator[dict]: - """Iterate over items in the dataset according to specified filters and sorting. - - This method allows for asynchronously iterating through dataset items while applying various filters such as - skipping empty items, hiding specific fields, and sorting. It supports pagination via `offset` and `limit` - parameters, and can modify the appearance of dataset items using `fields`, `omit`, `unwind`, `skip_empty`, and - `skip_hidden` parameters. + """Iterate over the dataset. - Args: - offset: The number of initial items to skip. - limit: The maximum number of items to iterate over. None means no limit. - clean: If True, removes empty items and hidden fields, equivalent to 'skip_hidden' and 'skip_empty'. 
- desc: If set to True, items are returned in descending order, i.e., newest first. - fields: Specifies a subset of fields to include in each item. - omit: Specifies a subset of fields to exclude from each item. - unwind: Specifies a field that should be unwound into separate items. - skip_empty: If set to True, omits items that are empty after other filters have been applied. - skip_hidden: If set to True, omits fields starting with the '#' character from the output. - - Yields: - An asynchronous iterator of dictionary objects, each representing a dataset item after applying - the specified filters and transformations. + The backend method for the `Dataset.iterate` call. """ # This syntax is to make mypy properly work with abstract AsyncIterator. # https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators raise NotImplementedError if False: # type: ignore[unreachable] yield 0 - - @abstractmethod - async def get_items_as_bytes( - self, - *, - item_format: str = 'json', - offset: int | None = None, - limit: int | None = None, - desc: bool = False, - clean: bool = False, - bom: bool = False, - delimiter: str | None = None, - fields: list[str] | None = None, - omit: list[str] | None = None, - unwind: str | None = None, - skip_empty: bool = False, - skip_header_row: bool = False, - skip_hidden: bool = False, - xml_root: str | None = None, - xml_row: str | None = None, - flatten: list[str] | None = None, - ) -> bytes: - """Retrieve dataset items as bytes. - - Args: - item_format: Output format (e.g., 'json', 'csv'); default is 'json'. - offset: Number of items to skip; default is 0. - limit: Max number of items to return; no default limit. - desc: If True, results are returned in descending order. - clean: If True, filters out empty items and hidden fields. - bom: Include or exclude UTF-8 BOM; default behavior varies by format. - delimiter: Delimiter character for CSV; default is ','. - fields: List of fields to include in the results. - omit: List of fields to omit from the results. - unwind: Unwinds a field into separate records. - skip_empty: If True, skips empty items in the output. - skip_header_row: If True, skips the header row in CSV. - skip_hidden: If True, skips hidden fields in the output. - xml_root: Root element name for XML output; default is 'items'. - xml_row: Element name for each item in XML output; default is 'item'. - flatten: List of fields to flatten. - - Returns: - The dataset items as raw bytes. - """ - - @abstractmethod - async def stream_items( - self, - *, - item_format: str = 'json', - offset: int | None = None, - limit: int | None = None, - desc: bool = False, - clean: bool = False, - bom: bool = False, - delimiter: str | None = None, - fields: list[str] | None = None, - omit: list[str] | None = None, - unwind: str | None = None, - skip_empty: bool = False, - skip_header_row: bool = False, - skip_hidden: bool = False, - xml_root: str | None = None, - xml_row: str | None = None, - ) -> AbstractAsyncContextManager[Response | None]: - """Retrieve dataset items as a streaming response. - - Args: - item_format: Output format, options include json, jsonl, csv, html, xlsx, xml, rss; default is json. - offset: Number of items to skip at the start; default is 0. - limit: Maximum number of items to return; no default limit. - desc: If True, reverses the order of results. - clean: If True, filters out empty items and hidden fields. - bom: Include or exclude UTF-8 BOM; varies by format. - delimiter: Delimiter for CSV files; default is ','. 
- fields: List of fields to include in the output. - omit: List of fields to omit from the output. - unwind: Unwinds a field into separate records. - skip_empty: If True, empty items are omitted. - skip_header_row: If True, skips the header row in CSV. - skip_hidden: If True, hides fields starting with the # character. - xml_root: Custom root element name for XML output; default is 'items'. - xml_row: Custom element name for each item in XML; default is 'item'. - - Yields: - The dataset items in a streaming response. - """ diff --git a/src/crawlee/storage_clients/_base/_storage_client.py b/src/crawlee/storage_clients/_base/_storage_client.py index de5d229443..fc7e2e4d97 100644 --- a/src/crawlee/storage_clients/_base/_storage_client.py +++ b/src/crawlee/storage_clients/_base/_storage_client.py @@ -1,47 +1,16 @@ -# Inspiration: https://github.com/apify/crawlee/blob/v3.8.2/packages/types/src/storages.ts#L314:L328 - from __future__ import annotations -from abc import ABC, abstractmethod +from dataclasses import dataclass from typing import TYPE_CHECKING -from crawlee._utils.docs import docs_group - if TYPE_CHECKING: from ._dataset_client import DatasetClient from ._key_value_store_client import KeyValueStoreClient from ._request_queue_client import RequestQueueClient -@docs_group('Abstract classes') -class StorageClient(ABC): - """Defines an abstract base for storage clients. - - It offers interfaces to get subclients for interacting with storage resources like datasets, key-value stores, - and request queues. - """ - - @abstractmethod - def dataset(self, id: str) -> DatasetClient: - """Get a subclient for a specific dataset by its ID.""" - - @abstractmethod - def key_value_store(self, id: str) -> KeyValueStoreClient: - """Get a subclient for a specific key-value store by its ID.""" - - @abstractmethod - def request_queue(self, id: str) -> RequestQueueClient: - """Get a subclient for a specific request queue by its ID.""" - - @abstractmethod - async def purge_on_start(self) -> None: - """Perform a purge of the default storages. - - This method ensures that the purge is executed only once during the lifetime of the instance. - It is primarily used to clean up residual data from previous runs to maintain a clean state. - If the storage client does not support purging, leave it empty. 
- """ - - def get_rate_limit_errors(self) -> dict[int, int]: - """Return statistics about rate limit errors encountered by the HTTP client in storage client.""" - return {} +@dataclass +class StorageClient: + dataset_client_class: type[DatasetClient] + key_value_store_client_class: type[KeyValueStoreClient] + request_queue_client_class: type[RequestQueueClient] diff --git a/src/crawlee/storage_clients/_base/_types.py b/src/crawlee/storage_clients/_base/_types.py deleted file mode 100644 index f644fe5410..0000000000 --- a/src/crawlee/storage_clients/_base/_types.py +++ /dev/null @@ -1,13 +0,0 @@ -from __future__ import annotations - -from typing import Union - -from ._dataset_client import DatasetClient -from ._key_value_store_client import KeyValueStoreClient -from ._request_queue_client import RequestQueueClient - -ResourceClient = Union[ - DatasetClient, - KeyValueStoreClient, - RequestQueueClient, -] diff --git a/src/crawlee/storage_clients/_file_system/__init__.py b/src/crawlee/storage_clients/_file_system/__init__.py new file mode 100644 index 0000000000..3aa67ad6dc --- /dev/null +++ b/src/crawlee/storage_clients/_file_system/__init__.py @@ -0,0 +1,3 @@ +from ._storage_client import file_system_storage_client + +__all__ = ['file_system_storage_client'] diff --git a/src/crawlee/storage_clients/_file_system/_dataset_client.py b/src/crawlee/storage_clients/_file_system/_dataset_client.py new file mode 100644 index 0000000000..b05340cae2 --- /dev/null +++ b/src/crawlee/storage_clients/_file_system/_dataset_client.py @@ -0,0 +1,396 @@ +from __future__ import annotations + +import asyncio +import json +import shutil +from datetime import datetime, timezone +from logging import getLogger +from typing import TYPE_CHECKING + +from pydantic import ValidationError +from typing_extensions import override + +from crawlee._consts import METADATA_FILENAME +from crawlee._utils.crypto import crypto_random_object_id +from crawlee.storage_clients._base import DatasetClient +from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata + +from ._utils import json_dumps + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + from pathlib import Path + from typing import Any + +logger = getLogger(__name__) + + +class FileSystemDatasetClient(DatasetClient): + """A file system storage implementation of the dataset client. + + This client stores dataset items as individual JSON files in a subdirectory. + The metadata of the dataset (timestamps, item count, etc.) is stored in a metadata file. + """ + + _DEFAULT_NAME = 'default' + """The name of the unnamed dataset.""" + + _STORAGE_SUBDIR = 'datasets' + """The name of the subdirectory where datasets are stored.""" + + _LOCAL_ENTRY_NAME_DIGITS = 9 + """Number of digits used for the file names (e.g., 000000019.json).""" + + def __init__( + self, + *, + id: str, + name: str, + created_at: datetime, + accessed_at: datetime, + modified_at: datetime, + item_count: int, + storage_dir: Path, + ) -> None: + """Initialize a new instance. + + Preferably use the `FileSystemDatasetClient.open` class method to create a new instance. + """ + self._id = id + self._name = name + self._created_at = created_at + self._accessed_at = accessed_at + self._modified_at = modified_at + self._item_count = item_count + self._storage_dir = storage_dir + + # Internal attributes. 
+ self._lock = asyncio.Lock() + """A lock to ensure that only one file operation is performed at a time.""" + + @override + @property + def id(self) -> str: + return self._id + + @override + @property + def name(self) -> str | None: + return self._name + + @override + @property + def created_at(self) -> datetime: + return self._created_at + + @override + @property + def accessed_at(self) -> datetime: + return self._accessed_at + + @override + @property + def modified_at(self) -> datetime: + return self._modified_at + + @override + @property + def item_count(self) -> int: + return self._item_count + + @property + def _path_to_dataset(self) -> Path: + """The full path to the dataset directory.""" + return self._storage_dir / self._STORAGE_SUBDIR / self._name + + @property + def _path_to_metadata(self) -> Path: + """The full path to the dataset metadata file.""" + return self._path_to_dataset / METADATA_FILENAME + + @override + @classmethod + async def open( + cls, + id: str | None, + name: str | None, + storage_dir: Path, + ) -> FileSystemDatasetClient: + """Open an existing dataset client or create a new one if it does not exist. + + If the dataset directory exists, this method reconstructs the client from the metadata file. + Otherwise, a new dataset client is created with a new unique ID. + + Args: + id: The dataset ID. + name: The dataset name; if not provided, defaults to the default name. + storage_dir: The base directory for storage. + + Returns: + A new instance of the file system dataset client. + """ + name = name or cls._DEFAULT_NAME + dataset_path = storage_dir / cls._STORAGE_SUBDIR / name + metadata_path = dataset_path / METADATA_FILENAME + + # If the dataset directory exists, reconstruct the client from the metadata file. + if dataset_path.exists(): + # If metadata file is missing, raise an error. + if not metadata_path.exists(): + raise ValueError(f'Metadata file not found for dataset "{name}"') + + file = await asyncio.to_thread(open, metadata_path) + try: + file_content = json.load(file) + finally: + await asyncio.to_thread(file.close) + try: + metadata = DatasetMetadata(**file_content) + except ValidationError as exc: + raise ValueError(f'Invalid metadata file for dataset "{name}"') from exc + + client = cls( + id=metadata.id, + name=name, + created_at=metadata.created_at, + accessed_at=metadata.accessed_at, + modified_at=metadata.modified_at, + item_count=metadata.item_count, + storage_dir=storage_dir, + ) + + await client._update_metadata(update_accessed_at=True) + + # Otherwise, create a new dataset client. + else: + client = cls( + id=crypto_random_object_id(), + name=name, + created_at=datetime.now(timezone.utc), + accessed_at=datetime.now(timezone.utc), + modified_at=datetime.now(timezone.utc), + item_count=0, + storage_dir=storage_dir, + ) + await client._update_metadata() + + return client + + @override + async def drop(self) -> None: + # If the dataset directory exists, remove it recursively. + if self._path_to_dataset.exists(): + async with self._lock: + await asyncio.to_thread(shutil.rmtree, self._path_to_dataset) + + @override + async def push_data(self, data: list[Any] | dict[str, Any]) -> None: + # If data is a list, push each item individually. 
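# --- Usage sketch, not applied by this patch, for the file-system dataset client
# being added in this file. It assumes the client is driven directly through its
# internal module path; the name and path below are illustrative only. Items
# pushed this way can be read back with get_data()/iterate(), defined further
# down in this file.
import asyncio
from pathlib import Path

from crawlee.storage_clients._file_system._dataset_client import FileSystemDatasetClient


async def main() -> None:
    client = await FileSystemDatasetClient.open(id=None, name='example', storage_dir=Path('./storage'))
    await client.push_data({'url': 'https://crawlee.dev'})
    await client.push_data([{'rank': i} for i in range(3)])
    await client.drop()  # removes the whole dataset directory


asyncio.run(main())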
+ if isinstance(data, list): + for item in data: + await self._push_item(item) + else: + await self._push_item(data) + + await self._update_metadata(update_accessed_at=True, update_modified_at=True) + + @override + async def get_data( + self, + *, + offset: int = 0, + limit: int | None = 999_999_999_999, + clean: bool = False, + desc: bool = False, + fields: list[str] | None = None, + omit: list[str] | None = None, + unwind: str | None = None, + skip_empty: bool = False, + skip_hidden: bool = False, + flatten: list[str] | None = None, + view: str | None = None, + ) -> DatasetItemsListPage: + # Check for unsupported arguments and log a warning if found. + unsupported_args = [clean, fields, omit, unwind, skip_hidden, flatten, view] + invalid = [arg for arg in unsupported_args if arg not in (False, None)] + if invalid: + logger.warning( + f'The arguments {invalid} of iterate_items are not supported by the {self.__class__.__name__} client.' + ) + + # If the dataset directory does not exist, log a warning and return an empty page. + if not self._path_to_dataset.exists(): + logger.warning(f'Dataset directory not found: {self._path_to_dataset}') + return DatasetItemsListPage( + count=0, + offset=offset, + limit=limit or 0, + total=0, + desc=desc, + items=[], + ) + + # Get the list of sorted data files. + data_files = await self._get_sorted_data_files() + total = len(data_files) + + # Reverse the order if descending order is requested. + if desc: + data_files.reverse() + + # Apply offset and limit slicing. + selected_files = data_files[offset:] + if limit is not None: + selected_files = selected_files[:limit] + + # Read and parse each data file. + items = [] + for file_path in selected_files: + try: + file_content = await asyncio.to_thread(file_path.read_text, encoding='utf-8') + item = json.loads(file_content) + except Exception: + logger.exception(f'Error reading {file_path}, skipping the item.') + continue + + # Skip empty items if requested. + if skip_empty and not item: + continue + + items.append(item) + + await self._update_metadata(update_accessed_at=True) + + # Return a paginated list page of dataset items. + return DatasetItemsListPage( + count=len(items), + offset=offset, + limit=limit or total - offset, + total=total, + desc=desc, + items=items, + ) + + @override + async def iterate( + self, + *, + offset: int = 0, + limit: int | None = None, + clean: bool = False, + desc: bool = False, + fields: list[str] | None = None, + omit: list[str] | None = None, + unwind: str | None = None, + skip_empty: bool = False, + skip_hidden: bool = False, + ) -> AsyncIterator[dict]: + # Check for unsupported arguments and log a warning if found. + unsupported_args = [clean, fields, omit, unwind, skip_hidden] + invalid = [arg for arg in unsupported_args if arg not in (False, None)] + if invalid: + logger.warning( + f'The arguments {invalid} of iterate_items are not supported by the {self.__class__.__name__} client.' + ) + + # If the dataset directory does not exist, log a warning and return immediately. + if not self._path_to_dataset.exists(): + logger.warning(f'Dataset directory not found: {self._path_to_dataset}') + return + + # Get the list of sorted data files. + data_files = await self._get_sorted_data_files() + + # Reverse the order if descending order is requested. + if desc: + data_files.reverse() + + # Apply offset and limit slicing. 
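# --- Aside, not applied by this patch: the `asyncio.to_thread` pattern used
# throughout this client to keep blocking pathlib I/O off the event loop.
# A minimal self-contained sketch with an illustrative path argument.
import asyncio
import json
from pathlib import Path


async def read_json(path: Path) -> dict:
    # Path.read_text() blocks, so it runs in a worker thread.
    text = await asyncio.to_thread(path.read_text, encoding='utf-8')
    return json.loads(text)


async def write_json(path: Path, data: dict) -> None:
    await asyncio.to_thread(path.parent.mkdir, parents=True, exist_ok=True)
    await asyncio.to_thread(path.write_text, json.dumps(data), encoding='utf-8')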
+ selected_files = data_files[offset:] + if limit is not None: + selected_files = selected_files[:limit] + + # Iterate over each data file, reading and yielding its parsed content. + for file_path in selected_files: + try: + file_content = await asyncio.to_thread(file_path.read_text, encoding='utf-8') + item = json.loads(file_content) + except Exception: + logger.exception(f'Error reading {file_path}, skipping the item.') + continue + + # Skip empty items if requested. + if skip_empty and not item: + continue + + yield item + + await self._update_metadata(update_accessed_at=True) + + async def _update_metadata( + self, + *, + update_accessed_at: bool = False, + update_modified_at: bool = False, + ) -> None: + """Update the dataset metadata file with current information. + + Args: + update_accessed_at: If True, update the `accessed_at` timestamp to the current time. + update_modified_at: If True, update the `modified_at` timestamp to the current time. + """ + now = datetime.now(timezone.utc) + metadata = DatasetMetadata( + id=self._id, + name=self._name, + created_at=self._created_at, + accessed_at=now if update_accessed_at else self._accessed_at, + modified_at=now if update_modified_at else self._modified_at, + item_count=self._item_count, + ) + + # Ensure the parent directory for the metadata file exists. + await asyncio.to_thread(self._path_to_metadata.parent.mkdir, parents=True, exist_ok=True) + + # Dump the serialized metadata to the file. + data = await json_dumps(metadata.model_dump()) + await asyncio.to_thread(self._path_to_metadata.write_text, data, encoding='utf-8') + + async def _push_item(self, item: dict[str, Any]) -> None: + """Push a single item to the dataset. + + This method increments the item count, writes the item as a JSON file with a zero-padded filename, + and updates the metadata. + """ + # Acquire the lock to perform file operations safely. + async with self._lock: + self._item_count += 1 + # Generate the filename for the new item using zero-padded numbering. + filename = f'{str(self._item_count).zfill(self._LOCAL_ENTRY_NAME_DIGITS)}.json' + file_path = self._path_to_dataset / filename + + # Ensure the dataset directory exists. + await asyncio.to_thread(self._path_to_dataset.mkdir, parents=True, exist_ok=True) + + # Dump the serialized item to the file. + data = await json_dumps(item) + await asyncio.to_thread(file_path.write_text, data, encoding='utf-8') + + async def _get_sorted_data_files(self) -> list[Path]: + """Retrieve and return a sorted list of data files in the dataset directory. + + The files are sorted numerically based on the filename (without extension). + The metadata file is excluded. + """ + # Retrieve and sort all JSON files in the dataset directory numerically. + files = await asyncio.to_thread( + sorted, + self._path_to_dataset.glob('*.json'), + key=lambda f: int(f.stem) if f.stem.isdigit() else 0, + ) + + # Remove the metadata file from the list if present. 
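# --- Aside, not applied by this patch: why the zero-padded entry names produced
# by _push_item and sorted by _get_sorted_data_files keep items in insertion
# order. The 9-digit width matches _LOCAL_ENTRY_NAME_DIGITS.
names = [f'{str(i).zfill(9)}.json' for i in (1, 2, 10, 100)]
assert sorted(names) == names, 'lexicographic order matches numeric order'
assert [int(n.split('.')[0]) for n in names] == [1, 2, 10, 100]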
+ if self._path_to_metadata in files: + files.remove(self._path_to_metadata) + + return files diff --git a/src/crawlee/storage_clients/_file_system/_key_value_store.py b/src/crawlee/storage_clients/_file_system/_key_value_store.py new file mode 100644 index 0000000000..8bf0815b3a --- /dev/null +++ b/src/crawlee/storage_clients/_file_system/_key_value_store.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from logging import getLogger + +from crawlee.storage_clients._base import KeyValueStoreClient + +logger = getLogger(__name__) + + +class FileSystemKeyValueStoreClient(KeyValueStoreClient): + pass diff --git a/src/crawlee/storage_clients/_file_system/_request_queue.py b/src/crawlee/storage_clients/_file_system/_request_queue.py new file mode 100644 index 0000000000..f8a6bfe88e --- /dev/null +++ b/src/crawlee/storage_clients/_file_system/_request_queue.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from logging import getLogger + +from crawlee.storage_clients._base import RequestQueueClient + +logger = getLogger(__name__) + + +class FileSystemRequestQueueClient(RequestQueueClient): + pass diff --git a/src/crawlee/storage_clients/_file_system/_storage_client.py b/src/crawlee/storage_clients/_file_system/_storage_client.py new file mode 100644 index 0000000000..248d07b6f6 --- /dev/null +++ b/src/crawlee/storage_clients/_file_system/_storage_client.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from crawlee.storage_clients._base import StorageClient + +from ._dataset_client import FileSystemDatasetClient +from ._key_value_store import FileSystemKeyValueStoreClient +from ._request_queue import FileSystemRequestQueueClient + +file_system_storage_client = StorageClient( + dataset_client_class=FileSystemDatasetClient, + key_value_store_client_class=FileSystemKeyValueStoreClient, + request_queue_client_class=FileSystemRequestQueueClient, +) diff --git a/src/crawlee/storage_clients/_file_system/_utils.py b/src/crawlee/storage_clients/_file_system/_utils.py new file mode 100644 index 0000000000..5ad9121448 --- /dev/null +++ b/src/crawlee/storage_clients/_file_system/_utils.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import asyncio +import json +from logging import getLogger +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any + +logger = getLogger(__name__) + + +async def json_dumps(obj: Any) -> str: + """Serialize an object to a JSON-formatted string with specific settings. + + Args: + obj: The object to serialize. + + Returns: + A string containing the JSON representation of the input object. 
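# --- Usage sketch, not applied by this patch, for the json_dumps helper defined
# above: serialization is offloaded to a worker thread, and `default=str` means
# values such as datetimes are stringified instead of raising.
import asyncio
from datetime import datetime, timezone

from crawlee.storage_clients._file_system._utils import json_dumps


async def main() -> None:
    payload = {'id': 1, 'created_at': datetime.now(timezone.utc)}
    text = await json_dumps(payload)
    print(text)


asyncio.run(main())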
+ """ + return await asyncio.to_thread(json.dumps, obj, ensure_ascii=False, indent=2, default=str) diff --git a/src/crawlee/storage_clients/_file_system/py.typed b/src/crawlee/storage_clients/_file_system/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/crawlee/storage_clients/_memory/__init__.py b/src/crawlee/storage_clients/_memory/__init__.py index 355797673d..2463d516c2 100644 --- a/src/crawlee/storage_clients/_memory/__init__.py +++ b/src/crawlee/storage_clients/_memory/__init__.py @@ -1,11 +1,3 @@ -from ._dataset_client import DatasetClient -from ._key_value_store_client import KeyValueStoreClient -from ._memory_storage_client import MemoryStorageClient -from ._request_queue_client import RequestQueueClient +from ._storage_client import memory_storage_client -__all__ = [ - 'DatasetClient', - 'KeyValueStoreClient', - 'MemoryStorageClient', - 'RequestQueueClient', -] +__all__ = ['memory_storage_client'] diff --git a/src/crawlee/storage_clients/_memory/_creation_management.py b/src/crawlee/storage_clients/_memory/_creation_management.py deleted file mode 100644 index f6d4fc1c91..0000000000 --- a/src/crawlee/storage_clients/_memory/_creation_management.py +++ /dev/null @@ -1,429 +0,0 @@ -from __future__ import annotations - -import asyncio -import json -import mimetypes -import os -import pathlib -from datetime import datetime, timezone -from logging import getLogger -from typing import TYPE_CHECKING - -from crawlee._consts import METADATA_FILENAME -from crawlee._utils.data_processing import maybe_parse_body -from crawlee._utils.file import json_dumps -from crawlee.storage_clients.models import ( - DatasetMetadata, - InternalRequest, - KeyValueStoreMetadata, - KeyValueStoreRecord, - KeyValueStoreRecordMetadata, - RequestQueueMetadata, -) - -if TYPE_CHECKING: - from ._dataset_client import DatasetClient - from ._key_value_store_client import KeyValueStoreClient - from ._memory_storage_client import MemoryStorageClient, TResourceClient - from ._request_queue_client import RequestQueueClient - -logger = getLogger(__name__) - - -async def persist_metadata_if_enabled(*, data: dict, entity_directory: str, write_metadata: bool) -> None: - """Update or writes metadata to a specified directory. - - The function writes a given metadata dictionary to a JSON file within a specified directory. - The writing process is skipped if `write_metadata` is False. Before writing, it ensures that - the target directory exists, creating it if necessary. - - Args: - data: A dictionary containing metadata to be written. - entity_directory: The directory path where the metadata file should be stored. - write_metadata: A boolean flag indicating whether the metadata should be written to file. - """ - # Skip metadata write; ensure directory exists first - if not write_metadata: - return - - # Ensure the directory for the entity exists - await asyncio.to_thread(os.makedirs, entity_directory, exist_ok=True) - - # Write the metadata to the file - file_path = os.path.join(entity_directory, METADATA_FILENAME) - f = await asyncio.to_thread(open, file_path, mode='wb') - try: - s = await json_dumps(data) - await asyncio.to_thread(f.write, s.encode('utf-8')) - finally: - await asyncio.to_thread(f.close) - - -def find_or_create_client_by_id_or_name_inner( - resource_client_class: type[TResourceClient], - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, -) -> TResourceClient | None: - """Locate or create a new storage client based on the given ID or name. 
- - This method attempts to find a storage client in the memory cache first. If not found, - it tries to locate a storage directory by name. If still not found, it searches through - storage directories for a matching ID or name in their metadata. If none exists, and the - specified ID is 'default', it checks for a default storage directory. If a storage client - is found or created, it is added to the memory cache. If no storage client can be located or - created, the method returns None. - - Args: - resource_client_class: The class of the resource client. - memory_storage_client: The memory storage client used to store and retrieve storage clients. - id: The unique identifier for the storage client. - name: The name of the storage client. - - Raises: - ValueError: If both id and name are None. - - Returns: - The found or created storage client, or None if no client could be found or created. - """ - from ._dataset_client import DatasetClient - from ._key_value_store_client import KeyValueStoreClient - from ._request_queue_client import RequestQueueClient - - if id is None and name is None: - raise ValueError('Either id or name must be specified.') - - # First check memory cache - found = memory_storage_client.get_cached_resource_client(resource_client_class, id, name) - - if found is not None: - return found - - storage_path = _determine_storage_path(resource_client_class, memory_storage_client, id, name) - - if not storage_path: - return None - - # Create from directory if storage path is found - if issubclass(resource_client_class, DatasetClient): - resource_client = create_dataset_from_directory(storage_path, memory_storage_client, id, name) - elif issubclass(resource_client_class, KeyValueStoreClient): - resource_client = create_kvs_from_directory(storage_path, memory_storage_client, id, name) - elif issubclass(resource_client_class, RequestQueueClient): - resource_client = create_rq_from_directory(storage_path, memory_storage_client, id, name) - else: - raise TypeError('Invalid resource client class.') - - memory_storage_client.add_resource_client_to_cache(resource_client) - return resource_client - - -async def get_or_create_inner( - *, - memory_storage_client: MemoryStorageClient, - storage_client_cache: list[TResourceClient], - resource_client_class: type[TResourceClient], - name: str | None = None, - id: str | None = None, -) -> TResourceClient: - """Retrieve a named storage, or create a new one when it doesn't exist. - - Args: - memory_storage_client: The memory storage client. - storage_client_cache: The cache of storage clients. - resource_client_class: The class of the storage to retrieve or create. - name: The name of the storage to retrieve or create. - id: ID of the storage to retrieve or create. - - Returns: - The retrieved or newly-created storage. 
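# --- Illustrative sketch, not applied by this patch: the find-or-create helpers
# removed in this file are superseded by the plain `StorageClient` container
# introduced earlier in the diff (_base/_storage_client.py). How the storages
# layer will consume that container is not shown in this patch, so the wiring
# below is an assumption, demonstrated with the file-system bundle.
import asyncio
from pathlib import Path

from crawlee.storage_clients._file_system import file_system_storage_client


async def main() -> None:
    dataset_client_cls = file_system_storage_client.dataset_client_class
    client = await dataset_client_cls.open(id=None, name='default', storage_dir=Path('./storage'))
    await client.push_data({'hello': 'world'})


asyncio.run(main())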
- """ - # If the name or id is provided, try to find the dataset in the cache - if name or id: - found = find_or_create_client_by_id_or_name_inner( - resource_client_class=resource_client_class, - memory_storage_client=memory_storage_client, - name=name, - id=id, - ) - if found: - return found - - # Otherwise, create a new one and add it to the cache - resource_client = resource_client_class( - id=id, - name=name, - memory_storage_client=memory_storage_client, - ) - - storage_client_cache.append(resource_client) - - # Write to the disk - await persist_metadata_if_enabled( - data=resource_client.resource_info.model_dump(), - entity_directory=resource_client.resource_directory, - write_metadata=memory_storage_client.write_metadata, - ) - - return resource_client - - -def create_dataset_from_directory( - storage_directory: str, - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, -) -> DatasetClient: - from ._dataset_client import DatasetClient - - item_count = 0 - has_seen_metadata_file = False - created_at = datetime.now(timezone.utc) - accessed_at = datetime.now(timezone.utc) - modified_at = datetime.now(timezone.utc) - - # Load metadata if it exists - metadata_filepath = os.path.join(storage_directory, METADATA_FILENAME) - - if os.path.exists(metadata_filepath): - has_seen_metadata_file = True - with open(metadata_filepath, encoding='utf-8') as f: - json_content = json.load(f) - resource_info = DatasetMetadata(**json_content) - - id = resource_info.id - name = resource_info.name - item_count = resource_info.item_count - created_at = resource_info.created_at - accessed_at = resource_info.accessed_at - modified_at = resource_info.modified_at - - # Load dataset entries - entries: dict[str, dict] = {} - - for entry in os.scandir(storage_directory): - if entry.is_file(): - if entry.name == METADATA_FILENAME: - has_seen_metadata_file = True - continue - - with open(os.path.join(storage_directory, entry.name), encoding='utf-8') as f: - entry_content = json.load(f) - - entry_name = entry.name.split('.')[0] - entries[entry_name] = entry_content - - if not has_seen_metadata_file: - item_count += 1 - - # Create new dataset client - new_client = DatasetClient( - memory_storage_client=memory_storage_client, - id=id, - name=name, - created_at=created_at, - accessed_at=accessed_at, - modified_at=modified_at, - item_count=item_count, - ) - - new_client.dataset_entries.update(entries) - return new_client - - -def create_kvs_from_directory( - storage_directory: str, - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, -) -> KeyValueStoreClient: - from ._key_value_store_client import KeyValueStoreClient - - created_at = datetime.now(timezone.utc) - accessed_at = datetime.now(timezone.utc) - modified_at = datetime.now(timezone.utc) - - # Load metadata if it exists - metadata_filepath = os.path.join(storage_directory, METADATA_FILENAME) - - if os.path.exists(metadata_filepath): - with open(metadata_filepath, encoding='utf-8') as f: - json_content = json.load(f) - resource_info = KeyValueStoreMetadata(**json_content) - - id = resource_info.id - name = resource_info.name - created_at = resource_info.created_at - accessed_at = resource_info.accessed_at - modified_at = resource_info.modified_at - - # Create new KVS client - new_client = KeyValueStoreClient( - memory_storage_client=memory_storage_client, - id=id, - name=name, - accessed_at=accessed_at, - created_at=created_at, - modified_at=modified_at, - ) - - # Scan the KVS 
folder, check each entry in there and parse it as a store record - for entry in os.scandir(storage_directory): - if not entry.is_file(): - continue - - # Ignore metadata files on their own - if entry.name.endswith(METADATA_FILENAME): - continue - - # Try checking if this file has a metadata file associated with it - record_metadata = None - record_metadata_filepath = os.path.join(storage_directory, f'{entry.name}.__metadata__.json') - - if os.path.exists(record_metadata_filepath): - with open(record_metadata_filepath, encoding='utf-8') as metadata_file: - try: - json_content = json.load(metadata_file) - record_metadata = KeyValueStoreRecordMetadata(**json_content) - - except Exception: - logger.warning( - f'Metadata of key-value store entry "{entry.name}" for store {name or id} could ' - 'not be parsed. The metadata file will be ignored.', - exc_info=True, - ) - - if not record_metadata: - content_type, _ = mimetypes.guess_type(entry.name) - if content_type is None: - content_type = 'application/octet-stream' - - record_metadata = KeyValueStoreRecordMetadata( - key=pathlib.Path(entry.name).stem, - content_type=content_type, - ) - - with open(os.path.join(storage_directory, entry.name), 'rb') as f: - file_content = f.read() - - try: - maybe_parse_body(file_content, record_metadata.content_type) - except Exception: - record_metadata.content_type = 'application/octet-stream' - logger.warning( - f'Key-value store entry "{record_metadata.key}" for store {name or id} could not be parsed.' - 'The entry will be assumed as binary.', - exc_info=True, - ) - - new_client.records[record_metadata.key] = KeyValueStoreRecord( - key=record_metadata.key, - content_type=record_metadata.content_type, - filename=entry.name, - value=file_content, - ) - - return new_client - - -def create_rq_from_directory( - storage_directory: str, - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, -) -> RequestQueueClient: - from ._request_queue_client import RequestQueueClient - - created_at = datetime.now(timezone.utc) - accessed_at = datetime.now(timezone.utc) - modified_at = datetime.now(timezone.utc) - handled_request_count = 0 - pending_request_count = 0 - - # Load metadata if it exists - metadata_filepath = os.path.join(storage_directory, METADATA_FILENAME) - - if os.path.exists(metadata_filepath): - with open(metadata_filepath, encoding='utf-8') as f: - json_content = json.load(f) - resource_info = RequestQueueMetadata(**json_content) - - id = resource_info.id - name = resource_info.name - created_at = resource_info.created_at - accessed_at = resource_info.accessed_at - modified_at = resource_info.modified_at - handled_request_count = resource_info.handled_request_count - pending_request_count = resource_info.pending_request_count - - # Load request entries - entries: dict[str, InternalRequest] = {} - - for entry in os.scandir(storage_directory): - if entry.is_file(): - if entry.name == METADATA_FILENAME: - continue - - with open(os.path.join(storage_directory, entry.name), encoding='utf-8') as f: - content = json.load(f) - - request = InternalRequest(**content) - - entries[request.id] = request - - # Create new RQ client - new_client = RequestQueueClient( - memory_storage_client=memory_storage_client, - id=id, - name=name, - accessed_at=accessed_at, - created_at=created_at, - modified_at=modified_at, - handled_request_count=handled_request_count, - pending_request_count=pending_request_count, - ) - - new_client.requests.update(entries) - return new_client - - -def 
_determine_storage_path( - resource_client_class: type[TResourceClient], - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, -) -> str | None: - storages_dir = memory_storage_client._get_storage_dir(resource_client_class) # noqa: SLF001 - default_id = memory_storage_client._get_default_storage_id(resource_client_class) # noqa: SLF001 - - # Try to find by name directly from directories - if name: - possible_storage_path = os.path.join(storages_dir, name) - if os.access(possible_storage_path, os.F_OK): - return possible_storage_path - - # If not found, try finding by metadata - if os.access(storages_dir, os.F_OK): - for entry in os.scandir(storages_dir): - if entry.is_dir(): - metadata_path = os.path.join(entry.path, METADATA_FILENAME) - if os.access(metadata_path, os.F_OK): - with open(metadata_path, encoding='utf-8') as metadata_file: - try: - metadata = json.load(metadata_file) - if (id and metadata.get('id') == id) or (name and metadata.get('name') == name): - return entry.path - except Exception: - logger.warning( - f'Metadata of store entry "{entry.name}" for store {name or id} could not be parsed. ' - 'The metadata file will be ignored.', - exc_info=True, - ) - - # Check for default storage directory as a last resort - if id == default_id: - possible_storage_path = os.path.join(storages_dir, default_id) - if os.access(possible_storage_path, os.F_OK): - return possible_storage_path - - return None diff --git a/src/crawlee/storage_clients/_memory/_dataset_client.py b/src/crawlee/storage_clients/_memory/_dataset_client.py index 40ad5c2b13..279be563c9 100644 --- a/src/crawlee/storage_clients/_memory/_dataset_client.py +++ b/src/crawlee/storage_clients/_memory/_dataset_client.py @@ -1,158 +1,126 @@ from __future__ import annotations -import asyncio -import json -import os -import shutil from datetime import datetime, timezone from logging import getLogger from typing import TYPE_CHECKING, Any from typing_extensions import override -from crawlee._types import StorageTypes from crawlee._utils.crypto import crypto_random_object_id -from crawlee._utils.data_processing import raise_on_non_existing_storage -from crawlee._utils.file import json_dumps -from crawlee.storage_clients._base import DatasetClient as BaseDatasetClient -from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata - -from ._creation_management import find_or_create_client_by_id_or_name_inner +from crawlee.storage_clients._base import DatasetClient +from crawlee.storage_clients.models import DatasetItemsListPage if TYPE_CHECKING: from collections.abc import AsyncIterator - from contextlib import AbstractAsyncContextManager - - from httpx import Response - - from crawlee._types import JsonSerializable - from crawlee.storage_clients import MemoryStorageClient + from pathlib import Path logger = getLogger(__name__) -class DatasetClient(BaseDatasetClient): - """Subclient for manipulating a single dataset.""" +class MemoryDatasetClient(DatasetClient): + """A memory implementation of the dataset client. - _LIST_ITEMS_LIMIT = 999_999_999_999 - """This is what API returns in the x-apify-pagination-limit header when no limit query parameter is used.""" + This client stores dataset items in memory using a dictionary. + No data is persisted to the file system. 
+ """ - _LOCAL_ENTRY_NAME_DIGITS = 9 - """Number of characters of the dataset item file names, e.g.: 000000019.json - 9 digits.""" + _DEFAULT_NAME = 'default' + """The default name for the dataset when no name is provided.""" def __init__( self, *, - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, - created_at: datetime | None = None, - accessed_at: datetime | None = None, - modified_at: datetime | None = None, - item_count: int = 0, + id: str, + name: str, + created_at: datetime, + accessed_at: datetime, + modified_at: datetime, + item_count: int, ) -> None: - self._memory_storage_client = memory_storage_client - self.id = id or crypto_random_object_id() - self.name = name - self._created_at = created_at or datetime.now(timezone.utc) - self._accessed_at = accessed_at or datetime.now(timezone.utc) - self._modified_at = modified_at or datetime.now(timezone.utc) + """Initialize a new instance of the memory-only dataset client. - self.dataset_entries: dict[str, dict] = {} - self.file_operation_lock = asyncio.Lock() - self.item_count = item_count + Preferably use the `MemoryDatasetClient.open` class method to create a new instance. + """ + self._id = id + self._name = name + self._created_at = created_at + self._accessed_at = accessed_at + self._modified_at = modified_at + self._item_count = item_count - @property - def resource_info(self) -> DatasetMetadata: - """Get the resource info for the dataset client.""" - return DatasetMetadata( - id=self.id, - name=self.name, - accessed_at=self._accessed_at, - created_at=self._created_at, - modified_at=self._modified_at, - item_count=self.item_count, - ) + # Dictionary to hold dataset items; keys are zero-padded strings. + self._records = list[dict[str, Any]]() + @override @property - def resource_directory(self) -> str: - """Get the resource directory for the client.""" - return os.path.join(self._memory_storage_client.datasets_directory, self.name or self.id) + def id(self) -> str: + return self._id @override - async def get(self) -> DatasetMetadata | None: - found = find_or_create_client_by_id_or_name_inner( - resource_client_class=DatasetClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if found: - async with found.file_operation_lock: - await found.update_timestamps(has_been_modified=False) - return found.resource_info + @property + def name(self) -> str | None: + return self._name - return None + @override + @property + def created_at(self) -> datetime: + return self._created_at @override - async def delete(self) -> None: - dataset = next( - (dataset for dataset in self._memory_storage_client.datasets_handled if dataset.id == self.id), None - ) + @property + def accessed_at(self) -> datetime: + return self._accessed_at - if dataset is not None: - async with dataset.file_operation_lock: - self._memory_storage_client.datasets_handled.remove(dataset) - dataset.item_count = 0 - dataset.dataset_entries.clear() + @override + @property + def modified_at(self) -> datetime: + return self._modified_at - if os.path.exists(dataset.resource_directory): - await asyncio.to_thread(shutil.rmtree, dataset.resource_directory) + @override + @property + def item_count(self) -> int: + return self._item_count @override - async def push_items( - self, - items: JsonSerializable, - ) -> None: - # Check by id - existing_dataset_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=DatasetClient, - memory_storage_client=self._memory_storage_client, - 
id=self.id, - name=self.name, + @classmethod + async def open( + cls, + id: str | None, + name: str | None, + storage_dir: Path, # Ignored in the memory-only implementation. + ) -> MemoryDatasetClient: + name = name or cls._DEFAULT_NAME + dataset_id = id or crypto_random_object_id() + now = datetime.now(timezone.utc) + return cls( + id=dataset_id, + name=name, + created_at=now, + accessed_at=now, + modified_at=now, + item_count=0, ) - if existing_dataset_by_id is None: - raise_on_non_existing_storage(StorageTypes.DATASET, self.id) - - normalized = self._normalize_items(items) - - added_ids: list[str] = [] - for entry in normalized: - existing_dataset_by_id.item_count += 1 - idx = self._generate_local_entry_name(existing_dataset_by_id.item_count) - - existing_dataset_by_id.dataset_entries[idx] = entry - added_ids.append(idx) - - data_entries = [(id, existing_dataset_by_id.dataset_entries[id]) for id in added_ids] - - async with existing_dataset_by_id.file_operation_lock: - await existing_dataset_by_id.update_timestamps(has_been_modified=True) + @override + async def drop(self) -> None: + self._records.clear() + self._item_count = 0 - await self._persist_dataset_items_to_disk( - data=data_entries, - entity_directory=existing_dataset_by_id.resource_directory, - persist_storage=self._memory_storage_client.persist_storage, - ) + @override + async def push_data(self, data: list[Any] | dict[str, Any]) -> None: + if isinstance(data, list): + for item in data: + await self._push_item(item) + else: + await self._push_item(data) + await self._update_metadata(update_accessed_at=True, update_modified_at=True) @override - async def list_items( + async def get_data( self, *, - offset: int | None = 0, - limit: int | None = _LIST_ITEMS_LIMIT, + offset: int = 0, + limit: int | None = 999_999_999_999, clean: bool = False, desc: bool = False, fields: list[str] | None = None, @@ -163,47 +131,31 @@ async def list_items( flatten: list[str] | None = None, view: str | None = None, ) -> DatasetItemsListPage: - # Check by id - existing_dataset_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=DatasetClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_dataset_by_id is None: - raise_on_non_existing_storage(StorageTypes.DATASET, self.id) - - async with existing_dataset_by_id.file_operation_lock: - start, end = existing_dataset_by_id.get_start_and_end_indexes( - max(existing_dataset_by_id.item_count - (offset or 0) - (limit or self._LIST_ITEMS_LIMIT), 0) - if desc - else offset or 0, - limit, + unsupported_args = [clean, fields, omit, unwind, skip_hidden, flatten, view] + invalid = [arg for arg in unsupported_args if arg not in (False, None)] + if invalid: + logger.warning( + f'The arguments {invalid} of iterate_items are not supported by the {self.__class__.__name__} client.' 
) - items = [] - - for idx in range(start, end): - entry_number = self._generate_local_entry_name(idx) - items.append(existing_dataset_by_id.dataset_entries[entry_number]) - - await existing_dataset_by_id.update_timestamps(has_been_modified=False) - - if desc: - items.reverse() - - return DatasetItemsListPage( - count=len(items), - desc=desc or False, - items=items, - limit=limit or self._LIST_ITEMS_LIMIT, - offset=offset or 0, - total=existing_dataset_by_id.item_count, - ) + total = len(self._records) + items = self._records.copy() + if desc: + items = list(reversed(items)) + + sliced_items = items[offset : (offset + limit) if limit is not None else total] + await self._update_metadata(update_accessed_at=True) + return DatasetItemsListPage( + count=len(sliced_items), + offset=offset, + limit=limit or (total - offset), + total=total, + desc=desc, + items=sliced_items, + ) @override - async def iterate_items( + async def iterate( self, *, offset: int = 0, @@ -216,154 +168,44 @@ async def iterate_items( skip_empty: bool = False, skip_hidden: bool = False, ) -> AsyncIterator[dict]: - cache_size = 1000 - first_item = offset - - # If there is no limit, set last_item to None until we get the total from the first API response - last_item = None if limit is None else offset + limit - current_offset = first_item - - while last_item is None or current_offset < last_item: - current_limit = cache_size if last_item is None else min(cache_size, last_item - current_offset) - - current_items_page = await self.list_items( - offset=current_offset, - limit=current_limit, - desc=desc, + unsupported_args = [clean, fields, omit, unwind, skip_hidden] + invalid = [arg for arg in unsupported_args if arg not in (False, None)] + if invalid: + logger.warning( + f'The arguments {invalid} of iterate_items are not supported by the {self.__class__.__name__} client.' 
) - current_offset += current_items_page.count - if last_item is None or current_items_page.total < last_item: - last_item = current_items_page.total + items = self._records.copy() + if desc: + items = list(reversed(items)) - for item in current_items_page.items: - yield item + sliced_items = items[offset : (offset + limit) if limit is not None else len(items)] + for item in sliced_items: + if skip_empty and not item: + continue + yield item - @override - async def get_items_as_bytes( - self, - *, - item_format: str = 'json', - offset: int | None = None, - limit: int | None = None, - desc: bool = False, - clean: bool = False, - bom: bool = False, - delimiter: str | None = None, - fields: list[str] | None = None, - omit: list[str] | None = None, - unwind: str | None = None, - skip_empty: bool = False, - skip_header_row: bool = False, - skip_hidden: bool = False, - xml_root: str | None = None, - xml_row: str | None = None, - flatten: list[str] | None = None, - ) -> bytes: - raise NotImplementedError('This method is not supported in memory storage.') - - @override - async def stream_items( - self, - *, - item_format: str = 'json', - offset: int | None = None, - limit: int | None = None, - desc: bool = False, - clean: bool = False, - bom: bool = False, - delimiter: str | None = None, - fields: list[str] | None = None, - omit: list[str] | None = None, - unwind: str | None = None, - skip_empty: bool = False, - skip_header_row: bool = False, - skip_hidden: bool = False, - xml_root: str | None = None, - xml_row: str | None = None, - ) -> AbstractAsyncContextManager[Response | None]: - raise NotImplementedError('This method is not supported in memory storage.') + await self._update_metadata(update_accessed_at=True) - async def _persist_dataset_items_to_disk( + async def _update_metadata( self, *, - data: list[tuple[str, dict]], - entity_directory: str, - persist_storage: bool, + update_accessed_at: bool = False, + update_modified_at: bool = False, ) -> None: - """Write dataset items to the disk. - - The function iterates over a list of dataset items, each represented as a tuple of an identifier - and a dictionary, and writes them as individual JSON files in a specified directory. The function - will skip writing if `persist_storage` is False. Before writing, it ensures that the target - directory exists, creating it if necessary. + """Update the dataset metadata file with current information. Args: - data: A list of tuples, each containing an identifier (string) and a data dictionary. - entity_directory: The directory path where the dataset items should be stored. - persist_storage: A boolean flag indicating whether the data should be persisted to the disk. + update_accessed_at: If True, update the `accessed_at` timestamp to the current time. + update_modified_at: If True, update the `modified_at` timestamp to the current time. 
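# --- Usage sketch, not applied by this patch, for the in-memory dataset client
# defined above. Everything lives in process memory; `storage_dir` is accepted
# only for interface parity and is ignored. The name below is illustrative.
import asyncio
from pathlib import Path

from crawlee.storage_clients._memory._dataset_client import MemoryDatasetClient


async def main() -> None:
    client = await MemoryDatasetClient.open(id=None, name='example', storage_dir=Path('.'))
    await client.push_data([{'n': i} for i in range(5)])
    page = await client.get_data(offset=1, limit=2)
    print(page.items)  # -> [{'n': 1}, {'n': 2}]
    await client.drop()  # clears the in-memory records


asyncio.run(main())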
""" - # Skip writing files to the disk if the client has the option set to false - if not persist_storage: - return - - # Ensure the directory for the entity exists - await asyncio.to_thread(os.makedirs, entity_directory, exist_ok=True) - - # Save all the new items to the disk - for idx, item in data: - file_path = os.path.join(entity_directory, f'{idx}.json') - f = await asyncio.to_thread(open, file_path, mode='w', encoding='utf-8') - try: - s = await json_dumps(item) - await asyncio.to_thread(f.write, s) - finally: - await asyncio.to_thread(f.close) - - async def update_timestamps(self, *, has_been_modified: bool) -> None: - """Update the timestamps of the dataset.""" - from ._creation_management import persist_metadata_if_enabled - - self._accessed_at = datetime.now(timezone.utc) - - if has_been_modified: - self._modified_at = datetime.now(timezone.utc) - - await persist_metadata_if_enabled( - data=self.resource_info.model_dump(), - entity_directory=self.resource_directory, - write_metadata=self._memory_storage_client.write_metadata, - ) - - def get_start_and_end_indexes(self, offset: int, limit: int | None = None) -> tuple[int, int]: - """Calculate the start and end indexes for listing items.""" - actual_limit = limit or self.item_count - start = offset + 1 - end = min(offset + actual_limit, self.item_count) + 1 - return (start, end) - - def _generate_local_entry_name(self, idx: int) -> str: - return str(idx).zfill(self._LOCAL_ENTRY_NAME_DIGITS) - - def _normalize_items(self, items: JsonSerializable) -> list[dict]: - def normalize_item(item: Any) -> dict | None: - if isinstance(item, str): - item = json.loads(item) - - if isinstance(item, list): - received = ',\n'.join(item) - raise TypeError( - f'Each dataset item can only be a single JSON object, not an array. Received: [{received}]' - ) - - if (not isinstance(item, dict)) and item is not None: - raise TypeError(f'Each dataset item must be a JSON object. Received: {item}') - - return item - - if isinstance(items, str): - items = json.loads(items) - - result = list(map(normalize_item, items)) if isinstance(items, list) else [normalize_item(items)] - # filter(None, ..) 
returns items that are True - return list(filter(None, result)) + now = datetime.now(timezone.utc) + if update_accessed_at: + self._accessed_at = now + if update_modified_at: + self._modified_at = now + + async def _push_item(self, item: dict[str, Any]) -> None: + """Push a single item to the dataset.""" + self._item_count += 1 + self._records.append(item) diff --git a/src/crawlee/storage_clients/_memory/_key_value_store_client.py b/src/crawlee/storage_clients/_memory/_key_value_store_client.py index e7f18fb175..0dd8d9a0a5 100644 --- a/src/crawlee/storage_clients/_memory/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_memory/_key_value_store_client.py @@ -1,384 +1,11 @@ from __future__ import annotations -import asyncio -import io -import os -import shutil -from datetime import datetime, timezone from logging import getLogger -from typing import TYPE_CHECKING, Any -from typing_extensions import override - -from crawlee._types import StorageTypes -from crawlee._utils.crypto import crypto_random_object_id -from crawlee._utils.data_processing import maybe_parse_body, raise_on_non_existing_storage -from crawlee._utils.file import determine_file_extension, force_remove, is_file_or_bytes, json_dumps -from crawlee.storage_clients._base import KeyValueStoreClient as BaseKeyValueStoreClient -from crawlee.storage_clients.models import ( - KeyValueStoreKeyInfo, - KeyValueStoreListKeysPage, - KeyValueStoreMetadata, - KeyValueStoreRecord, - KeyValueStoreRecordMetadata, -) - -from ._creation_management import find_or_create_client_by_id_or_name_inner, persist_metadata_if_enabled - -if TYPE_CHECKING: - from contextlib import AbstractAsyncContextManager - - from httpx import Response - - from crawlee.storage_clients import MemoryStorageClient +from crawlee.storage_clients._base import KeyValueStoreClient logger = getLogger(__name__) -class KeyValueStoreClient(BaseKeyValueStoreClient): - """Subclient for manipulating a single key-value store.""" - - def __init__( - self, - *, - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, - created_at: datetime | None = None, - accessed_at: datetime | None = None, - modified_at: datetime | None = None, - ) -> None: - self.id = id or crypto_random_object_id() - self.name = name - - self._memory_storage_client = memory_storage_client - self._created_at = created_at or datetime.now(timezone.utc) - self._accessed_at = accessed_at or datetime.now(timezone.utc) - self._modified_at = modified_at or datetime.now(timezone.utc) - - self.records: dict[str, KeyValueStoreRecord] = {} - self.file_operation_lock = asyncio.Lock() - - @property - def resource_info(self) -> KeyValueStoreMetadata: - """Get the resource info for the key-value store client.""" - return KeyValueStoreMetadata( - id=self.id, - name=self.name, - accessed_at=self._accessed_at, - created_at=self._created_at, - modified_at=self._modified_at, - user_id='1', - ) - - @property - def resource_directory(self) -> str: - """Get the resource directory for the client.""" - return os.path.join(self._memory_storage_client.key_value_stores_directory, self.name or self.id) - - @override - async def get(self) -> KeyValueStoreMetadata | None: - found = find_or_create_client_by_id_or_name_inner( - resource_client_class=KeyValueStoreClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if found: - async with found.file_operation_lock: - await found.update_timestamps(has_been_modified=False) - return found.resource_info - - 
return None - - @override - async def delete(self) -> None: - store = next( - (store for store in self._memory_storage_client.key_value_stores_handled if store.id == self.id), None - ) - - if store is not None: - async with store.file_operation_lock: - self._memory_storage_client.key_value_stores_handled.remove(store) - store.records.clear() - - if os.path.exists(store.resource_directory): - await asyncio.to_thread(shutil.rmtree, store.resource_directory) - - @override - async def list_keys( - self, - *, - limit: int = 1000, - exclusive_start_key: str | None = None, - ) -> KeyValueStoreListKeysPage: - # Check by id - existing_store_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=KeyValueStoreClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_store_by_id is None: - raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self.id) - - items: list[KeyValueStoreKeyInfo] = [] - - for record in existing_store_by_id.records.values(): - size = len(record.value) - items.append(KeyValueStoreKeyInfo(key=record.key, size=size)) - - if len(items) == 0: - return KeyValueStoreListKeysPage( - count=len(items), - limit=limit, - exclusive_start_key=exclusive_start_key, - is_truncated=False, - next_exclusive_start_key=None, - items=items, - ) - - # Lexically sort to emulate the API - items = sorted(items, key=lambda item: item.key) - - truncated_items = items - if exclusive_start_key is not None: - key_pos = next((idx for idx, item in enumerate(items) if item.key == exclusive_start_key), None) - if key_pos is not None: - truncated_items = items[(key_pos + 1) :] - - limited_items = truncated_items[:limit] - - last_item_in_store = items[-1] - last_selected_item = limited_items[-1] - is_last_selected_item_absolutely_last = last_item_in_store == last_selected_item - next_exclusive_start_key = None if is_last_selected_item_absolutely_last else last_selected_item.key - - async with existing_store_by_id.file_operation_lock: - await existing_store_by_id.update_timestamps(has_been_modified=False) - - return KeyValueStoreListKeysPage( - count=len(items), - limit=limit, - exclusive_start_key=exclusive_start_key, - is_truncated=not is_last_selected_item_absolutely_last, - next_exclusive_start_key=next_exclusive_start_key, - items=limited_items, - ) - - @override - async def get_record(self, key: str) -> KeyValueStoreRecord | None: - return await self._get_record_internal(key) - - @override - async def get_record_as_bytes(self, key: str) -> KeyValueStoreRecord[bytes] | None: - return await self._get_record_internal(key, as_bytes=True) - - @override - async def stream_record(self, key: str) -> AbstractAsyncContextManager[KeyValueStoreRecord[Response] | None]: - raise NotImplementedError('This method is not supported in memory storage.') - - @override - async def set_record(self, key: str, value: Any, content_type: str | None = None) -> None: - # Check by id - existing_store_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=KeyValueStoreClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_store_by_id is None: - raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self.id) - - if isinstance(value, io.IOBase): - raise NotImplementedError('File-like values are not supported in local memory storage') - - if content_type is None: - if is_file_or_bytes(value): - content_type = 'application/octet-stream' - elif isinstance(value, str): - content_type 
= 'text/plain; charset=utf-8' - else: - content_type = 'application/json; charset=utf-8' - - if 'application/json' in content_type and not is_file_or_bytes(value) and not isinstance(value, str): - s = await json_dumps(value) - value = s.encode('utf-8') - - async with existing_store_by_id.file_operation_lock: - await existing_store_by_id.update_timestamps(has_been_modified=True) - record = KeyValueStoreRecord(key=key, value=value, content_type=content_type, filename=None) - - old_record = existing_store_by_id.records.get(key) - existing_store_by_id.records[key] = record - - if self._memory_storage_client.persist_storage: - record_filename = self._filename_from_record(record) - record.filename = record_filename - - if old_record is not None and self._filename_from_record(old_record) != record_filename: - await existing_store_by_id.delete_persisted_record(old_record) - - await existing_store_by_id.persist_record(record) - - @override - async def delete_record(self, key: str) -> None: - # Check by id - existing_store_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=KeyValueStoreClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_store_by_id is None: - raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self.id) - - record = existing_store_by_id.records.get(key) - - if record is not None: - async with existing_store_by_id.file_operation_lock: - del existing_store_by_id.records[key] - await existing_store_by_id.update_timestamps(has_been_modified=True) - if self._memory_storage_client.persist_storage: - await existing_store_by_id.delete_persisted_record(record) - - @override - async def get_public_url(self, key: str) -> str: - existing_store_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=KeyValueStoreClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_store_by_id is None: - raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self.id) - - record = await self._get_record_internal(key) - - if not record: - raise ValueError(f'Record with key "{key}" was not found.') - - resource_dir = existing_store_by_id.resource_directory - record_filename = self._filename_from_record(record) - record_path = os.path.join(resource_dir, record_filename) - return f'file://{record_path}' - - async def persist_record(self, record: KeyValueStoreRecord) -> None: - """Persist the specified record to the key-value store.""" - store_directory = self.resource_directory - record_filename = self._filename_from_record(record) - record.filename = record_filename - record.content_type = record.content_type or 'application/octet-stream' - - # Ensure the directory for the entity exists - await asyncio.to_thread(os.makedirs, store_directory, exist_ok=True) - - # Create files for the record - record_path = os.path.join(store_directory, record_filename) - record_metadata_path = os.path.join(store_directory, f'{record_filename}.__metadata__.json') - - # Convert to bytes if string - if isinstance(record.value, str): - record.value = record.value.encode('utf-8') - - f = await asyncio.to_thread(open, record_path, mode='wb') - try: - await asyncio.to_thread(f.write, record.value) - finally: - await asyncio.to_thread(f.close) - - if self._memory_storage_client.write_metadata: - metadata_f = await asyncio.to_thread(open, record_metadata_path, mode='wb') - - try: - record_metadata = KeyValueStoreRecordMetadata(key=record.key, 
content_type=record.content_type) - await asyncio.to_thread(metadata_f.write, record_metadata.model_dump_json(indent=2).encode('utf-8')) - finally: - await asyncio.to_thread(metadata_f.close) - - async def delete_persisted_record(self, record: KeyValueStoreRecord) -> None: - """Delete the specified record from the key-value store.""" - store_directory = self.resource_directory - record_filename = self._filename_from_record(record) - - # Ensure the directory for the entity exists - await asyncio.to_thread(os.makedirs, store_directory, exist_ok=True) - - # Create files for the record - record_path = os.path.join(store_directory, record_filename) - record_metadata_path = os.path.join(store_directory, record_filename + '.__metadata__.json') - - await force_remove(record_path) - await force_remove(record_metadata_path) - - async def update_timestamps(self, *, has_been_modified: bool) -> None: - """Update the timestamps of the key-value store.""" - self._accessed_at = datetime.now(timezone.utc) - - if has_been_modified: - self._modified_at = datetime.now(timezone.utc) - - await persist_metadata_if_enabled( - data=self.resource_info.model_dump(), - entity_directory=self.resource_directory, - write_metadata=self._memory_storage_client.write_metadata, - ) - - async def _get_record_internal( - self, - key: str, - *, - as_bytes: bool = False, - ) -> KeyValueStoreRecord | None: - # Check by id - existing_store_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=KeyValueStoreClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_store_by_id is None: - raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self.id) - - stored_record = existing_store_by_id.records.get(key) - - if stored_record is None: - return None - - record = KeyValueStoreRecord( - key=stored_record.key, - value=stored_record.value, - content_type=stored_record.content_type, - filename=stored_record.filename, - ) - - if not as_bytes: - try: - record.value = maybe_parse_body(record.value, str(record.content_type)) - except ValueError: - logger.exception('Error parsing key-value store record') - - async with existing_store_by_id.file_operation_lock: - await existing_store_by_id.update_timestamps(has_been_modified=False) - - return record - - def _filename_from_record(self, record: KeyValueStoreRecord) -> str: - if record.filename is not None: - return record.filename - - if not record.content_type or record.content_type == 'application/octet-stream': - return record.key - - extension = determine_file_extension(record.content_type) - - if record.key.endswith(f'.{extension}'): - return record.key - - return f'{record.key}.{extension}' +class MemoryKeyValueStoreClient(KeyValueStoreClient): + pass diff --git a/src/crawlee/storage_clients/_memory/_memory_storage_client.py b/src/crawlee/storage_clients/_memory/_memory_storage_client.py deleted file mode 100644 index ec60f51145..0000000000 --- a/src/crawlee/storage_clients/_memory/_memory_storage_client.py +++ /dev/null @@ -1,343 +0,0 @@ -from __future__ import annotations - -import asyncio -import contextlib -import os -import shutil -from logging import getLogger -from pathlib import Path -from typing import TYPE_CHECKING, TypeVar - -from typing_extensions import override - -from crawlee._utils.docs import docs_group -from crawlee.configuration import Configuration -from crawlee.storage_clients import StorageClient - -from ._dataset_client import DatasetClient -from ._key_value_store_client import 
KeyValueStoreClient -from ._request_queue_client import RequestQueueClient - -if TYPE_CHECKING: - from crawlee.storage_clients._base import ResourceClient - - -TResourceClient = TypeVar('TResourceClient', DatasetClient, KeyValueStoreClient, RequestQueueClient) - -logger = getLogger(__name__) - - -@docs_group('Classes') -class MemoryStorageClient(StorageClient): - """Represents an in-memory storage client for managing datasets, key-value stores, and request queues. - - It emulates in-memory storage similar to the Apify platform, supporting both in-memory and local file system-based - persistence. - - The behavior of the storage, such as data persistence and metadata writing, can be customized via initialization - parameters or environment variables. - """ - - _MIGRATING_KEY_VALUE_STORE_DIR_NAME = '__CRAWLEE_MIGRATING_KEY_VALUE_STORE' - """Name of the directory used to temporarily store files during the migration of the default key-value store.""" - - _TEMPORARY_DIR_NAME = '__CRAWLEE_TEMPORARY' - """Name of the directory used to temporarily store files during purges.""" - - _DATASETS_DIR_NAME = 'datasets' - """Name of the directory containing datasets.""" - - _KEY_VALUE_STORES_DIR_NAME = 'key_value_stores' - """Name of the directory containing key-value stores.""" - - _REQUEST_QUEUES_DIR_NAME = 'request_queues' - """Name of the directory containing request queues.""" - - def __init__( - self, - *, - write_metadata: bool, - persist_storage: bool, - storage_dir: str, - default_request_queue_id: str, - default_key_value_store_id: str, - default_dataset_id: str, - ) -> None: - """Initialize a new instance. - - In most cases, you should use the `from_config` constructor to create a new instance based on - the provided configuration. - - Args: - write_metadata: Whether to write metadata to the storage. - persist_storage: Whether to persist the storage. - storage_dir: Path to the storage directory. - default_request_queue_id: The default request queue ID. - default_key_value_store_id: The default key-value store ID. - default_dataset_id: The default dataset ID. - """ - # Set the internal attributes. - self._write_metadata = write_metadata - self._persist_storage = persist_storage - self._storage_dir = storage_dir - self._default_request_queue_id = default_request_queue_id - self._default_key_value_store_id = default_key_value_store_id - self._default_dataset_id = default_dataset_id - - self.datasets_handled: list[DatasetClient] = [] - self.key_value_stores_handled: list[KeyValueStoreClient] = [] - self.request_queues_handled: list[RequestQueueClient] = [] - - self._purged_on_start = False # Indicates whether a purge was already performed on this instance. - self._purge_lock = asyncio.Lock() - - @classmethod - def from_config(cls, config: Configuration | None = None) -> MemoryStorageClient: - """Initialize a new instance based on the provided `Configuration`. - - Args: - config: The `Configuration` instance. Uses the global (default) one if not provided. 
- """ - config = config or Configuration.get_global_configuration() - - return cls( - write_metadata=config.write_metadata, - persist_storage=config.persist_storage, - storage_dir=config.storage_dir, - default_request_queue_id=config.default_request_queue_id, - default_key_value_store_id=config.default_key_value_store_id, - default_dataset_id=config.default_dataset_id, - ) - - @property - def write_metadata(self) -> bool: - """Whether to write metadata to the storage.""" - return self._write_metadata - - @property - def persist_storage(self) -> bool: - """Whether to persist the storage.""" - return self._persist_storage - - @property - def storage_dir(self) -> str: - """Path to the storage directory.""" - return self._storage_dir - - @property - def datasets_directory(self) -> str: - """Path to the directory containing datasets.""" - return os.path.join(self.storage_dir, self._DATASETS_DIR_NAME) - - @property - def key_value_stores_directory(self) -> str: - """Path to the directory containing key-value stores.""" - return os.path.join(self.storage_dir, self._KEY_VALUE_STORES_DIR_NAME) - - @property - def request_queues_directory(self) -> str: - """Path to the directory containing request queues.""" - return os.path.join(self.storage_dir, self._REQUEST_QUEUES_DIR_NAME) - - @override - def dataset(self, id: str) -> DatasetClient: - return DatasetClient(memory_storage_client=self, id=id) - - @override - def key_value_store(self, id: str) -> KeyValueStoreClient: - return KeyValueStoreClient(memory_storage_client=self, id=id) - - @override - def request_queue(self, id: str) -> RequestQueueClient: - return RequestQueueClient(memory_storage_client=self, id=id) - - @override - async def purge_on_start(self) -> None: - # Optimistic, non-blocking check - if self._purged_on_start is True: - logger.debug('Storage was already purged on start.') - return - - async with self._purge_lock: - # Another check under the lock just to be sure - if self._purged_on_start is True: - # Mypy doesn't understand that the _purged_on_start can change while we're getting the async lock - return # type: ignore[unreachable] - - await self._purge_default_storages() - self._purged_on_start = True - - def get_cached_resource_client( - self, - resource_client_class: type[TResourceClient], - id: str | None, - name: str | None, - ) -> TResourceClient | None: - """Try to return a resource client from the internal cache.""" - if issubclass(resource_client_class, DatasetClient): - cache = self.datasets_handled - elif issubclass(resource_client_class, KeyValueStoreClient): - cache = self.key_value_stores_handled - elif issubclass(resource_client_class, RequestQueueClient): - cache = self.request_queues_handled - else: - return None - - for storage_client in cache: - if storage_client.id == id or ( - storage_client.name and name and storage_client.name.lower() == name.lower() - ): - return storage_client - - return None - - def add_resource_client_to_cache(self, resource_client: ResourceClient) -> None: - """Add a new resource client to the internal cache.""" - if isinstance(resource_client, DatasetClient): - self.datasets_handled.append(resource_client) - if isinstance(resource_client, KeyValueStoreClient): - self.key_value_stores_handled.append(resource_client) - if isinstance(resource_client, RequestQueueClient): - self.request_queues_handled.append(resource_client) - - async def _purge_default_storages(self) -> None: - """Clean up the storage directories, preparing the environment for a new run. 
- - It aims to remove residues from previous executions to avoid data contamination between runs. - - It specifically targets: - - The local directory containing the default dataset. - - All records from the default key-value store in the local directory, except for the 'INPUT' key. - - The local directory containing the default request queue. - """ - # Key-value stores - if await asyncio.to_thread(os.path.exists, self.key_value_stores_directory): - key_value_store_folders = await asyncio.to_thread(os.scandir, self.key_value_stores_directory) - for key_value_store_folder in key_value_store_folders: - if key_value_store_folder.name.startswith( - self._TEMPORARY_DIR_NAME - ) or key_value_store_folder.name.startswith('__OLD'): - await self._batch_remove_files(key_value_store_folder.path) - elif key_value_store_folder.name == self._default_key_value_store_id: - await self._handle_default_key_value_store(key_value_store_folder.path) - - # Datasets - if await asyncio.to_thread(os.path.exists, self.datasets_directory): - dataset_folders = await asyncio.to_thread(os.scandir, self.datasets_directory) - for dataset_folder in dataset_folders: - if dataset_folder.name == self._default_dataset_id or dataset_folder.name.startswith( - self._TEMPORARY_DIR_NAME - ): - await self._batch_remove_files(dataset_folder.path) - - # Request queues - if await asyncio.to_thread(os.path.exists, self.request_queues_directory): - request_queue_folders = await asyncio.to_thread(os.scandir, self.request_queues_directory) - for request_queue_folder in request_queue_folders: - if request_queue_folder.name == self._default_request_queue_id or request_queue_folder.name.startswith( - self._TEMPORARY_DIR_NAME - ): - await self._batch_remove_files(request_queue_folder.path) - - async def _handle_default_key_value_store(self, folder: str) -> None: - """Manage the cleanup of the default key-value store. - - It removes all files to ensure a clean state except for a set of predefined input keys (`possible_input_keys`). - - Args: - folder: Path to the default key-value store directory to clean. 
- """ - folder_exists = await asyncio.to_thread(os.path.exists, folder) - temporary_path = os.path.normpath(os.path.join(folder, '..', self._MIGRATING_KEY_VALUE_STORE_DIR_NAME)) - - # For optimization, we want to only attempt to copy a few files from the default key-value store - possible_input_keys = [ - 'INPUT', - 'INPUT.json', - 'INPUT.bin', - 'INPUT.txt', - ] - - if folder_exists: - # Create a temporary folder to save important files in - Path(temporary_path).mkdir(parents=True, exist_ok=True) - - # Go through each file and save the ones that are important - for entity in possible_input_keys: - original_file_path = os.path.join(folder, entity) - temp_file_path = os.path.join(temporary_path, entity) - with contextlib.suppress(Exception): - await asyncio.to_thread(os.rename, original_file_path, temp_file_path) - - # Remove the original folder and all its content - counter = 0 - temp_path_for_old_folder = os.path.normpath(os.path.join(folder, f'../__OLD_DEFAULT_{counter}__')) - done = False - try: - while not done: - await asyncio.to_thread(os.rename, folder, temp_path_for_old_folder) - done = True - except Exception: - counter += 1 - temp_path_for_old_folder = os.path.normpath(os.path.join(folder, f'../__OLD_DEFAULT_{counter}__')) - - # Replace the temporary folder with the original folder - await asyncio.to_thread(os.rename, temporary_path, folder) - - # Remove the old folder - await self._batch_remove_files(temp_path_for_old_folder) - - async def _batch_remove_files(self, folder: str, counter: int = 0) -> None: - """Remove a folder and its contents in batches to minimize blocking time. - - This method first renames the target folder to a temporary name, then deletes the temporary folder, - allowing the file system operations to proceed without hindering other asynchronous tasks. - - Args: - folder: The directory path to remove. - counter: A counter used for generating temporary directory names in case of conflicts. 
- """ - folder_exists = await asyncio.to_thread(os.path.exists, folder) - - if folder_exists: - temporary_folder = ( - folder - if os.path.basename(folder).startswith(f'{self._TEMPORARY_DIR_NAME}_') - else os.path.normpath(os.path.join(folder, '..', f'{self._TEMPORARY_DIR_NAME}_{counter}')) - ) - - try: - # Rename the old folder to the new one to allow background deletions - await asyncio.to_thread(os.rename, folder, temporary_folder) - except Exception: - # Folder exists already, try again with an incremented counter - return await self._batch_remove_files(folder, counter + 1) - - await asyncio.to_thread(shutil.rmtree, temporary_folder, ignore_errors=True) - return None - - def _get_default_storage_id(self, storage_client_class: type[TResourceClient]) -> str: - """Get the default storage ID based on the storage class.""" - if issubclass(storage_client_class, DatasetClient): - return self._default_dataset_id - - if issubclass(storage_client_class, KeyValueStoreClient): - return self._default_key_value_store_id - - if issubclass(storage_client_class, RequestQueueClient): - return self._default_request_queue_id - - raise ValueError(f'Invalid storage class: {storage_client_class.__name__}') - - def _get_storage_dir(self, storage_client_class: type[TResourceClient]) -> str: - """Get the storage directory based on the storage class.""" - if issubclass(storage_client_class, DatasetClient): - return self.datasets_directory - - if issubclass(storage_client_class, KeyValueStoreClient): - return self.key_value_stores_directory - - if issubclass(storage_client_class, RequestQueueClient): - return self.request_queues_directory - - raise ValueError(f'Invalid storage class: {storage_client_class.__name__}') diff --git a/src/crawlee/storage_clients/_memory/_request_queue_client.py b/src/crawlee/storage_clients/_memory/_request_queue_client.py index 687260d91d..d31c0602a0 100644 --- a/src/crawlee/storage_clients/_memory/_request_queue_client.py +++ b/src/crawlee/storage_clients/_memory/_request_queue_client.py @@ -1,517 +1,11 @@ from __future__ import annotations -import asyncio -import os -import shutil -from datetime import datetime, timezone -from decimal import Decimal from logging import getLogger -from typing import TYPE_CHECKING -from sortedcollections import ValueSortedDict # type: ignore[import-untyped] -from typing_extensions import override - -from crawlee._types import StorageTypes -from crawlee._utils.crypto import crypto_random_object_id -from crawlee._utils.data_processing import raise_on_non_existing_storage -from crawlee._utils.file import force_remove, json_dumps -from crawlee._utils.requests import unique_key_to_request_id -from crawlee.storage_clients._base import RequestQueueClient as BaseRequestQueueClient -from crawlee.storage_clients.models import ( - BatchRequestsOperationResponse, - InternalRequest, - ProcessedRequest, - ProlongRequestLockResponse, - RequestQueueHead, - RequestQueueHeadWithLocks, - RequestQueueMetadata, - UnprocessedRequest, -) - -from ._creation_management import find_or_create_client_by_id_or_name_inner, persist_metadata_if_enabled - -if TYPE_CHECKING: - from collections.abc import Sequence - - from sortedcontainers import SortedDict - - from crawlee import Request - - from ._memory_storage_client import MemoryStorageClient +from crawlee.storage_clients._base import RequestQueueClient logger = getLogger(__name__) -class RequestQueueClient(BaseRequestQueueClient): - """Subclient for manipulating a single request queue.""" - - def __init__( - self, - *, - 
memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, - created_at: datetime | None = None, - accessed_at: datetime | None = None, - modified_at: datetime | None = None, - handled_request_count: int = 0, - pending_request_count: int = 0, - ) -> None: - self._memory_storage_client = memory_storage_client - self.id = id or crypto_random_object_id() - self.name = name - self._created_at = created_at or datetime.now(timezone.utc) - self._accessed_at = accessed_at or datetime.now(timezone.utc) - self._modified_at = modified_at or datetime.now(timezone.utc) - self.handled_request_count = handled_request_count - self.pending_request_count = pending_request_count - - self.requests: SortedDict[str, InternalRequest] = ValueSortedDict( - lambda request: request.order_no or -float('inf') - ) - self.file_operation_lock = asyncio.Lock() - self._last_used_timestamp = Decimal(0) - - self._in_progress = set[str]() - - @property - def resource_info(self) -> RequestQueueMetadata: - """Get the resource info for the request queue client.""" - return RequestQueueMetadata( - id=self.id, - name=self.name, - accessed_at=self._accessed_at, - created_at=self._created_at, - modified_at=self._modified_at, - had_multiple_clients=False, - handled_request_count=self.handled_request_count, - pending_request_count=self.pending_request_count, - stats={}, - total_request_count=len(self.requests), - user_id='1', - resource_directory=self.resource_directory, - ) - - @property - def resource_directory(self) -> str: - """Get the resource directory for the client.""" - return os.path.join(self._memory_storage_client.request_queues_directory, self.name or self.id) - - @override - async def get(self) -> RequestQueueMetadata | None: - found = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if found: - async with found.file_operation_lock: - await found.update_timestamps(has_been_modified=False) - return found.resource_info - - return None - - @override - async def delete(self) -> None: - queue = next( - (queue for queue in self._memory_storage_client.request_queues_handled if queue.id == self.id), - None, - ) - - if queue is not None: - async with queue.file_operation_lock: - self._memory_storage_client.request_queues_handled.remove(queue) - queue.pending_request_count = 0 - queue.handled_request_count = 0 - queue.requests.clear() - - if os.path.exists(queue.resource_directory): - await asyncio.to_thread(shutil.rmtree, queue.resource_directory) - - @override - async def list_head(self, *, limit: int | None = None, skip_in_progress: bool = False) -> RequestQueueHead: - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) - - async with existing_queue_by_id.file_operation_lock: - await existing_queue_by_id.update_timestamps(has_been_modified=False) - - requests: list[Request] = [] - - # Iterate all requests in the queue which have sorted key larger than infinity, which means - # `order_no` is not `None`. This will iterate them in order of `order_no`. 
- for request_key in existing_queue_by_id.requests.irange_key( # type: ignore[attr-defined] # irange_key is a valid SortedDict method but not recognized by mypy - min_key=-float('inf'), inclusive=(False, True) - ): - if len(requests) == limit: - break - - if skip_in_progress and request_key in existing_queue_by_id._in_progress: # noqa: SLF001 - continue - internal_request = existing_queue_by_id.requests.get(request_key) - - # Check that the request still exists and was not handled, - # in case something deleted it or marked it as handled concurrenctly - if internal_request and not internal_request.handled_at: - requests.append(internal_request.to_request()) - - return RequestQueueHead( - limit=limit, - had_multiple_clients=False, - queue_modified_at=existing_queue_by_id._modified_at, # noqa: SLF001 - items=requests, - ) - - @override - async def list_and_lock_head(self, *, lock_secs: int, limit: int | None = None) -> RequestQueueHeadWithLocks: - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) - - result = await self.list_head(limit=limit, skip_in_progress=True) - - for item in result.items: - existing_queue_by_id._in_progress.add(item.id) # noqa: SLF001 - - return RequestQueueHeadWithLocks( - queue_has_locked_requests=len(existing_queue_by_id._in_progress) > 0, # noqa: SLF001 - lock_secs=lock_secs, - limit=result.limit, - had_multiple_clients=result.had_multiple_clients, - queue_modified_at=result.queue_modified_at, - items=result.items, - ) - - @override - async def add_request( - self, - request: Request, - *, - forefront: bool = False, - ) -> ProcessedRequest: - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) - - internal_request = await self._create_internal_request(request, forefront) - - async with existing_queue_by_id.file_operation_lock: - existing_internal_request_with_id = existing_queue_by_id.requests.get(internal_request.id) - - # We already have the request present, so we return information about it - if existing_internal_request_with_id is not None: - await existing_queue_by_id.update_timestamps(has_been_modified=False) - - return ProcessedRequest( - id=internal_request.id, - unique_key=internal_request.unique_key, - was_already_present=True, - was_already_handled=existing_internal_request_with_id.handled_at is not None, - ) - - existing_queue_by_id.requests[internal_request.id] = internal_request - if internal_request.handled_at: - existing_queue_by_id.handled_request_count += 1 - else: - existing_queue_by_id.pending_request_count += 1 - await existing_queue_by_id.update_timestamps(has_been_modified=True) - await self._persist_single_request_to_storage( - request=internal_request, - entity_directory=existing_queue_by_id.resource_directory, - persist_storage=self._memory_storage_client.persist_storage, - ) - - # We return was_already_handled=False even though the request may have been added as handled, - # because that's how API behaves. 
- return ProcessedRequest( - id=internal_request.id, - unique_key=internal_request.unique_key, - was_already_present=False, - was_already_handled=False, - ) - - @override - async def batch_add_requests( - self, - requests: Sequence[Request], - *, - forefront: bool = False, - ) -> BatchRequestsOperationResponse: - processed_requests = list[ProcessedRequest]() - unprocessed_requests = list[UnprocessedRequest]() - - for request in requests: - try: - processed_request = await self.add_request(request, forefront=forefront) - processed_requests.append( - ProcessedRequest( - id=processed_request.id, - unique_key=processed_request.unique_key, - was_already_present=processed_request.was_already_present, - was_already_handled=processed_request.was_already_handled, - ) - ) - except Exception as exc: # noqa: PERF203 - logger.warning(f'Error adding request to the queue: {exc}') - unprocessed_requests.append( - UnprocessedRequest( - unique_key=request.unique_key, - url=request.url, - method=request.method, - ) - ) - - return BatchRequestsOperationResponse( - processed_requests=processed_requests, - unprocessed_requests=unprocessed_requests, - ) - - @override - async def get_request(self, request_id: str) -> Request | None: - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) - - async with existing_queue_by_id.file_operation_lock: - await existing_queue_by_id.update_timestamps(has_been_modified=False) - - internal_request = existing_queue_by_id.requests.get(request_id) - return internal_request.to_request() if internal_request else None - - @override - async def update_request( - self, - request: Request, - *, - forefront: bool = False, - ) -> ProcessedRequest: - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) - - internal_request = await self._create_internal_request(request, forefront) - - # First we need to check the existing request to be able to return information about its handled state. - existing_internal_request = existing_queue_by_id.requests.get(internal_request.id) - - # Undefined means that the request is not present in the queue. - # We need to insert it, to behave the same as API. - if existing_internal_request is None: - return await self.add_request(request, forefront=forefront) - - async with existing_queue_by_id.file_operation_lock: - # When updating the request, we need to make sure that - # the handled counts are updated correctly in all cases. 
- existing_queue_by_id.requests[internal_request.id] = internal_request - - pending_count_adjustment = 0 - is_request_handled_state_changing = existing_internal_request.handled_at != internal_request.handled_at - - request_was_handled_before_update = existing_internal_request.handled_at is not None - - # We add 1 pending request if previous state was handled - if is_request_handled_state_changing: - pending_count_adjustment = 1 if request_was_handled_before_update else -1 - - existing_queue_by_id.pending_request_count += pending_count_adjustment - existing_queue_by_id.handled_request_count -= pending_count_adjustment - await existing_queue_by_id.update_timestamps(has_been_modified=True) - await self._persist_single_request_to_storage( - request=internal_request, - entity_directory=existing_queue_by_id.resource_directory, - persist_storage=self._memory_storage_client.persist_storage, - ) - - if request.handled_at is not None: - existing_queue_by_id._in_progress.discard(request.id) # noqa: SLF001 - - return ProcessedRequest( - id=internal_request.id, - unique_key=internal_request.unique_key, - was_already_present=True, - was_already_handled=request_was_handled_before_update, - ) - - @override - async def delete_request(self, request_id: str) -> None: - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) - - async with existing_queue_by_id.file_operation_lock: - internal_request = existing_queue_by_id.requests.get(request_id) - - if internal_request: - del existing_queue_by_id.requests[request_id] - if internal_request.handled_at: - existing_queue_by_id.handled_request_count -= 1 - else: - existing_queue_by_id.pending_request_count -= 1 - await existing_queue_by_id.update_timestamps(has_been_modified=True) - await self._delete_request_file_from_storage( - entity_directory=existing_queue_by_id.resource_directory, - request_id=request_id, - ) - - @override - async def batch_delete_requests(self, requests: list[Request]) -> BatchRequestsOperationResponse: - raise NotImplementedError('This method is not supported in memory storage.') - - @override - async def prolong_request_lock( - self, - request_id: str, - *, - forefront: bool = False, - lock_secs: int, - ) -> ProlongRequestLockResponse: - return ProlongRequestLockResponse(lock_expires_at=datetime.now(timezone.utc)) - - @override - async def delete_request_lock( - self, - request_id: str, - *, - forefront: bool = False, - ) -> None: - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) - - existing_queue_by_id._in_progress.discard(request_id) # noqa: SLF001 - - async def update_timestamps(self, *, has_been_modified: bool) -> None: - """Update the timestamps of the request queue.""" - self._accessed_at = datetime.now(timezone.utc) - - if has_been_modified: - self._modified_at = datetime.now(timezone.utc) - - await persist_metadata_if_enabled( - data=self.resource_info.model_dump(), - entity_directory=self.resource_directory, - write_metadata=self._memory_storage_client.write_metadata, - ) - - async def _persist_single_request_to_storage( - self, - 
*, - request: InternalRequest, - entity_directory: str, - persist_storage: bool, - ) -> None: - """Update or writes a single request item to the disk. - - This function writes a given request dictionary to a JSON file, named after the request's ID, - within a specified directory. The writing process is skipped if `persist_storage` is False. - Before writing, it ensures that the target directory exists, creating it if necessary. - - Args: - request: The dictionary containing the request data. - entity_directory: The directory path where the request file should be stored. - persist_storage: A boolean flag indicating whether the request should be persisted to the disk. - """ - # Skip writing files to the disk if the client has the option set to false - if not persist_storage: - return - - # Ensure the directory for the entity exists - await asyncio.to_thread(os.makedirs, entity_directory, exist_ok=True) - - # Write the request to the file - file_path = os.path.join(entity_directory, f'{request.id}.json') - f = await asyncio.to_thread(open, file_path, mode='w', encoding='utf-8') - try: - s = await json_dumps(request.model_dump()) - await asyncio.to_thread(f.write, s) - finally: - f.close() - - async def _delete_request_file_from_storage(self, *, request_id: str, entity_directory: str) -> None: - """Delete a specific request item from the disk. - - This function removes a file representing a request, identified by the request's ID, from a - specified directory. Before attempting to remove the file, it ensures that the target directory - exists, creating it if necessary. - - Args: - request_id: The identifier of the request to be deleted. - entity_directory: The directory path where the request file is stored. - """ - # Ensure the directory for the entity exists - await asyncio.to_thread(os.makedirs, entity_directory, exist_ok=True) - - file_path = os.path.join(entity_directory, f'{request_id}.json') - await force_remove(file_path) - - async def _create_internal_request(self, request: Request, forefront: bool | None) -> InternalRequest: - order_no = self._calculate_order_no(request, forefront) - id = unique_key_to_request_id(request.unique_key) - - if request.id is not None and request.id != id: - logger.warning( - f'The request ID does not match the ID from the unique_key (request.id={request.id}, id={id}).' 
- ) - - return InternalRequest.from_request(request=request, id=id, order_no=order_no) - - def _calculate_order_no(self, request: Request, forefront: bool | None) -> Decimal | None: - if request.handled_at is not None: - return None - - # Get the current timestamp in milliseconds - timestamp = Decimal(str(datetime.now(tz=timezone.utc).timestamp())) * Decimal(1000) - timestamp = round(timestamp, 6) - - # Make sure that this timestamp was not used yet, so that we have unique order_nos - if timestamp <= self._last_used_timestamp: - timestamp = self._last_used_timestamp + Decimal('0.000001') - - self._last_used_timestamp = timestamp - - return -timestamp if forefront else timestamp +class MemoryRequestQueueClient(RequestQueueClient): + pass diff --git a/src/crawlee/storage_clients/_memory/_storage_client.py b/src/crawlee/storage_clients/_memory/_storage_client.py new file mode 100644 index 0000000000..e867800f26 --- /dev/null +++ b/src/crawlee/storage_clients/_memory/_storage_client.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from crawlee.storage_clients._base import StorageClient + +from ._dataset_client import MemoryDatasetClient +from ._key_value_store_client import MemoryKeyValueStoreClient +from ._request_queue_client import MemoryRequestQueueClient + +memory_storage_client = StorageClient( + dataset_client_class=MemoryDatasetClient, + key_value_store_client_class=MemoryKeyValueStoreClient, + request_queue_client_class=MemoryRequestQueueClient, +) diff --git a/src/crawlee/storages/_creation_management.py b/src/crawlee/storages/_creation_management.py index 0ba1f0739e..9137e39512 100644 --- a/src/crawlee/storages/_creation_management.py +++ b/src/crawlee/storages/_creation_management.py @@ -162,12 +162,12 @@ async def open_storage( raise RuntimeError(f'{storage_class.__name__} with id "{id}" does not exist!') elif is_default_on_memory: - resource_collection_client = _get_resource_collection_client(storage_class, storage_client) - storage_object = await resource_collection_client.get_or_create(name=name, id=id) + resource_client = _get_resource_client(storage_class, storage_client) + storage_object = await resource_client.get_or_create(name=name, id=id) else: - resource_collection_client = _get_resource_collection_client(storage_class, storage_client) - storage_object = await resource_collection_client.get_or_create(name=name) + resource_client = _get_resource_client(storage_class, storage_client) + storage_object = await resource_client.get_or_create(name=name) storage = storage_class.from_storage_object(storage_client=storage_client, storage_object=storage_object) diff --git a/src/crawlee/storages/_dataset.py b/src/crawlee/storages/_dataset.py index c19c28d58a..0693a88334 100644 --- a/src/crawlee/storages/_dataset.py +++ b/src/crawlee/storages/_dataset.py @@ -4,158 +4,52 @@ import io import json import logging -from datetime import datetime, timezone -from typing import TYPE_CHECKING, Literal, TextIO, TypedDict, cast - -from typing_extensions import NotRequired, Required, Unpack, override +from pathlib import Path +from typing import TYPE_CHECKING, TextIO, cast from crawlee import service_locator -from crawlee._utils.byte_size import ByteSize from crawlee._utils.docs import docs_group -from crawlee._utils.file import json_dumps -from crawlee.storage_clients.models import DatasetMetadata, StorageMetadata +from crawlee.storage_clients.models import DatasetMetadata -from ._base import Storage from ._key_value_store import KeyValueStore if TYPE_CHECKING: - from collections.abc 
import AsyncIterator, Callable + from collections.abc import AsyncIterator + from typing import Any + + from typing_extensions import Unpack - from crawlee._types import JsonSerializable, PushDataKwargs from crawlee.configuration import Configuration from crawlee.storage_clients import StorageClient + from crawlee.storage_clients._base import DatasetClient from crawlee.storage_clients.models import DatasetItemsListPage -logger = logging.getLogger(__name__) - - -class GetDataKwargs(TypedDict): - """Keyword arguments for dataset's `get_data` method.""" - - offset: NotRequired[int] - """Skip the specified number of items at the start.""" - - limit: NotRequired[int] - """The maximum number of items to retrieve. Unlimited if None.""" - - clean: NotRequired[bool] - """Return only non-empty items and excludes hidden fields. Shortcut for skip_hidden and skip_empty.""" - - desc: NotRequired[bool] - """Set to True to sort results in descending order.""" - - fields: NotRequired[list[str]] - """Fields to include in each item. Sorts fields as specified if provided.""" - - omit: NotRequired[list[str]] - """Fields to exclude from each item.""" - - unwind: NotRequired[str] - """Unwind items by a specified array field, turning each element into a separate item.""" - - skip_empty: NotRequired[bool] - """Exclude empty items from the results if True.""" - - skip_hidden: NotRequired[bool] - """Exclude fields starting with '#' if True.""" - - flatten: NotRequired[list[str]] - """Field to be flattened in returned items.""" - - view: NotRequired[str] - """Specify the dataset view to be used.""" - - -class ExportToKwargs(TypedDict): - """Keyword arguments for dataset's `export_to` method.""" - - key: Required[str] - """The key under which to save the data.""" - - content_type: NotRequired[Literal['json', 'csv']] - """The format in which to export the data. Either 'json' or 'csv'.""" - - to_key_value_store_id: NotRequired[str] - """ID of the key-value store to save the exported file.""" - - to_key_value_store_name: NotRequired[str] - """Name of the key-value store to save the exported file.""" - - -class ExportDataJsonKwargs(TypedDict): - """Keyword arguments for dataset's `export_data_json` method.""" - - skipkeys: NotRequired[bool] - """If True (default: False), dict keys that are not of a basic type (str, int, float, bool, None) will be skipped - instead of raising a `TypeError`.""" - - ensure_ascii: NotRequired[bool] - """Determines if non-ASCII characters should be escaped in the output JSON string.""" - - check_circular: NotRequired[bool] - """If False (default: True), skips the circular reference check for container types. A circular reference will - result in a `RecursionError` or worse if unchecked.""" - - allow_nan: NotRequired[bool] - """If False (default: True), raises a ValueError for out-of-range float values (nan, inf, -inf) to strictly comply - with the JSON specification. If True, uses their JavaScript equivalents (NaN, Infinity, -Infinity).""" + from ._types import ExportDataCsvKwargs, ExportDataJsonKwargs, ExportToKwargs - cls: NotRequired[type[json.JSONEncoder]] - """Allows specifying a custom JSON encoder.""" - - indent: NotRequired[int] - """Specifies the number of spaces for indentation in the pretty-printed JSON output.""" - - separators: NotRequired[tuple[str, str]] - """A tuple of (item_separator, key_separator). 
The default is (', ', ': ') if indent is None and (',', ': ') - otherwise.""" - - default: NotRequired[Callable] - """A function called for objects that can't be serialized otherwise. It should return a JSON-encodable version - of the object or raise a `TypeError`.""" - - sort_keys: NotRequired[bool] - """Specifies whether the output JSON object should have keys sorted alphabetically.""" - - -class ExportDataCsvKwargs(TypedDict): - """Keyword arguments for dataset's `export_data_csv` method.""" - - dialect: NotRequired[str] - """Specifies a dialect to be used in CSV parsing and writing.""" - - delimiter: NotRequired[str] - """A one-character string used to separate fields. Defaults to ','.""" - - doublequote: NotRequired[bool] - """Controls how instances of `quotechar` inside a field should be quoted. When True, the character is doubled; - when False, the `escapechar` is used as a prefix. Defaults to True.""" - - escapechar: NotRequired[str] - """A one-character string used to escape the delimiter if `quoting` is set to `QUOTE_NONE` and the `quotechar` - if `doublequote` is False. Defaults to None, disabling escaping.""" - - lineterminator: NotRequired[str] - """The string used to terminate lines produced by the writer. Defaults to '\\r\\n'.""" - - quotechar: NotRequired[str] - """A one-character string used to quote fields containing special characters, like the delimiter or quotechar, - or fields containing new-line characters. Defaults to '\"'.""" - - quoting: NotRequired[int] - """Controls when quotes should be generated by the writer and recognized by the reader. Can take any of - the `QUOTE_*` constants, with a default of `QUOTE_MINIMAL`.""" - - skipinitialspace: NotRequired[bool] - """When True, spaces immediately following the delimiter are ignored. Defaults to False.""" +logger = logging.getLogger(__name__) - strict: NotRequired[bool] - """When True, raises an exception on bad CSV input. Defaults to False.""" +# TODO: +# - inherit from storage class +# - export methods + +# Dataset +# - properties: +# - id +# - name +# - metadata +# - methods: +# - open +# - drop +# - push_data +# - get_data +# - iterate +# - export_to_csv +# - export_to_json @docs_group('Classes') -class Dataset(Storage): - """Represents an append-only structured storage, ideal for tabular data similar to database tables. +class Dataset: + """Dataset is an append-only structured storage, ideal for tabular data similar to database tables. The `Dataset` class is designed to store structured data, where each entry (row) maintains consistent attributes (columns) across the dataset. It operates in an append-only mode, allowing new records to be added, but not @@ -186,89 +80,72 @@ class Dataset(Storage): ``` """ - _MAX_PAYLOAD_SIZE = ByteSize.from_mb(9) - """Maximum size for a single payload.""" - - _SAFETY_BUFFER_PERCENT = 0.01 / 100 # 0.01% - """Percentage buffer to reduce payload limit slightly for safety.""" + def __init__(self, client: DatasetClient) -> None: + """Initialize a new instance. - _EFFECTIVE_LIMIT_SIZE = _MAX_PAYLOAD_SIZE - (_MAX_PAYLOAD_SIZE * _SAFETY_BUFFER_PERCENT) - """Calculated payload limit considering safety buffer.""" - - def __init__(self, id: str, name: str | None, storage_client: StorageClient) -> None: - self._id = id - self._name = name - datetime_now = datetime.now(timezone.utc) - self._storage_object = StorageMetadata( - id=id, name=name, accessed_at=datetime_now, created_at=datetime_now, modified_at=datetime_now - ) - - # Get resource clients from the storage client. 
- self._resource_client = storage_client.dataset(self._id) - - @classmethod - def from_storage_object(cls, storage_client: StorageClient, storage_object: StorageMetadata) -> Dataset: - """Initialize a new instance of Dataset from a storage metadata object.""" - dataset = Dataset( - id=storage_object.id, - name=storage_object.name, - storage_client=storage_client, - ) + Preferably use the `Dataset.open` constructor to create a new instance. - dataset.storage_object = storage_object - return dataset + Args: + client: An instance of a dataset client. + """ + self._client = client @property - @override def id(self) -> str: - return self._id + return self._client.id @property - @override def name(self) -> str | None: - return self._name + return self._client.name @property - @override - def storage_object(self) -> StorageMetadata: - return self._storage_object - - @storage_object.setter - @override - def storage_object(self, storage_object: StorageMetadata) -> None: - self._storage_object = storage_object + def metadata(self) -> DatasetMetadata: + return DatasetMetadata( + id=self._client.id, + name=self._client.id, + accessed_at=self._client.accessed_at, + created_at=self._client.created_at, + modified_at=self._client.modified_at, + item_count=self._client.item_count, + ) - @override @classmethod async def open( cls, *, id: str | None = None, name: str | None = None, + purge_on_start: bool | None = None, configuration: Configuration | None = None, storage_client: StorageClient | None = None, ) -> Dataset: - from crawlee.storages._creation_management import open_storage + if id and name: + raise ValueError('Only one of "id" or "name" can be specified, not both.') configuration = configuration or service_locator.get_configuration() storage_client = storage_client or service_locator.get_storage_client() + purge_on_start = configuration.purge_on_start if purge_on_start is None else purge_on_start - return await open_storage( - storage_class=cls, + dataset_client = await storage_client.dataset_client_class.open( id=id, name=name, - configuration=configuration, - storage_client=storage_client, + storage_dir=Path(configuration.storage_dir), ) - @override - async def drop(self) -> None: - from crawlee.storages._creation_management import remove_storage_from_cache + if purge_on_start: + await dataset_client.drop() + dataset_client = await storage_client.dataset_client_class.open( + id=id, + name=name, + storage_dir=Path(configuration.storage_dir), + ) - await self._resource_client.delete() - remove_storage_from_cache(storage_class=self.__class__, id=self._id, name=self._name) + return cls(dataset_client) - async def push_data(self, data: JsonSerializable, **kwargs: Unpack[PushDataKwargs]) -> None: + async def drop(self) -> None: + await self._client.drop() + + async def push_data(self, data: list[Any] | dict[str, Any]) -> None: """Store an object or an array of objects to the dataset. The size of the data is limited by the receiving API and therefore `push_data()` will only @@ -278,83 +155,109 @@ async def push_data(self, data: JsonSerializable, **kwargs: Unpack[PushDataKwarg Args: data: A JSON serializable data structure to be stored in the dataset. The JSON representation of each item must be smaller than 9MB. - kwargs: Keyword arguments for the storage client method. 
""" - # Handle singular items - if not isinstance(data, list): - items = await self.check_and_serialize(data) - return await self._resource_client.push_items(items, **kwargs) - - # Handle lists - payloads_generator = (await self.check_and_serialize(item, index) for index, item in enumerate(data)) - - # Invoke client in series to preserve the order of data - async for items in self._chunk_by_size(payloads_generator): - await self._resource_client.push_items(items, **kwargs) + await self._client.push_data(data) - return None - - async def get_data(self, **kwargs: Unpack[GetDataKwargs]) -> DatasetItemsListPage: - """Retrieve dataset items based on filtering, sorting, and pagination parameters. + async def get_data( + self, + *, + offset: int = 0, + limit: int | None = 999_999_999_999, + clean: bool = False, + desc: bool = False, + fields: list[str] | None = None, + omit: list[str] | None = None, + unwind: str | None = None, + skip_empty: bool = False, + skip_hidden: bool = False, + flatten: list[str] | None = None, + view: str | None = None, + ) -> DatasetItemsListPage: + """Retrieve a paginated list of items from a dataset based on various filtering parameters. - This method allows customization of the data retrieval process from a dataset, supporting operations such as - field selection, ordering, and skipping specific records based on provided parameters. + This method provides the flexibility to filter, sort, and modify the appearance of dataset items + when listed. Each parameter modifies the result set according to its purpose. The method also + supports pagination through 'offset' and 'limit' parameters. Args: - kwargs: Keyword arguments for the storage client method. + offset: Skips the specified number of items at the start. + limit: The maximum number of items to retrieve. Unlimited if None. + clean: Return only non-empty items and excludes hidden fields. Shortcut for skip_hidden and skip_empty. + desc: Set to True to sort results in descending order. + fields: Fields to include in each item. Sorts fields as specified if provided. + omit: Fields to exclude from each item. + unwind: Unwinds items by a specified array field, turning each element into a separate item. + skip_empty: Excludes empty items from the results if True. + skip_hidden: Excludes fields starting with '#' if True. + flatten: Fields to be flattened in returned items. + view: Specifies the dataset view to be used. Returns: - List page containing filtered and paginated dataset items. - """ - return await self._resource_client.list_items(**kwargs) - - async def write_to_csv(self, destination: TextIO, **kwargs: Unpack[ExportDataCsvKwargs]) -> None: - """Export the entire dataset into an arbitrary stream. - - Args: - destination: The stream into which the dataset contents should be written. - kwargs: Additional keyword arguments for `csv.writer`. + An object with filtered, sorted, and paginated dataset items plus pagination details. 
""" - items: list[dict] = [] - limit = 1000 - offset = 0 - - while True: - list_items = await self._resource_client.list_items(limit=limit, offset=offset) - items.extend(list_items.items) - if list_items.total <= offset + list_items.count: - break - offset += list_items.count + return await self._client.get_data( + offset=offset, + limit=limit, + clean=clean, + desc=desc, + fields=fields, + omit=omit, + unwind=unwind, + skip_empty=skip_empty, + skip_hidden=skip_hidden, + flatten=flatten, + view=view, + ) - if items: - writer = csv.writer(destination, **kwargs) - writer.writerows([items[0].keys(), *[item.values() for item in items]]) - else: - logger.warning('Attempting to export an empty dataset - no file will be created') + async def iterate( + self, + *, + offset: int = 0, + limit: int | None = 999_999_999_999, + clean: bool = False, + desc: bool = False, + fields: list[str] | None = None, + omit: list[str] | None = None, + unwind: str | None = None, + skip_empty: bool = False, + skip_hidden: bool = False, + ) -> AsyncIterator[dict]: + """Iterate over items in the dataset according to specified filters and sorting. - async def write_to_json(self, destination: TextIO, **kwargs: Unpack[ExportDataJsonKwargs]) -> None: - """Export the entire dataset into an arbitrary stream. + This method allows for asynchronously iterating through dataset items while applying various filters such as + skipping empty items, hiding specific fields, and sorting. It supports pagination via `offset` and `limit` + parameters, and can modify the appearance of dataset items using `fields`, `omit`, `unwind`, `skip_empty`, and + `skip_hidden` parameters. Args: - destination: The stream into which the dataset contents should be written. - kwargs: Additional keyword arguments for `json.dump`. - """ - items: list[dict] = [] - limit = 1000 - offset = 0 - - while True: - list_items = await self._resource_client.list_items(limit=limit, offset=offset) - items.extend(list_items.items) - if list_items.total <= offset + list_items.count: - break - offset += list_items.count + offset: Skips the specified number of items at the start. + limit: The maximum number of items to retrieve. Unlimited if None. + clean: Return only non-empty items and excludes hidden fields. Shortcut for skip_hidden and skip_empty. + desc: Set to True to sort results in descending order. + fields: Fields to include in each item. Sorts fields as specified if provided. + omit: Fields to exclude from each item. + unwind: Unwinds items by a specified array field, turning each element into a separate item. + skip_empty: Excludes empty items from the results if True. + skip_hidden: Excludes fields starting with '#' if True. - if items: - json.dump(items, destination, **kwargs) - else: - logger.warning('Attempting to export an empty dataset - no file will be created') + Yields: + An asynchronous iterator of dictionary objects, each representing a dataset item after applying + the specified filters and transformations. + """ + async for item in self._client.iterate( + offset=offset, + limit=limit, + clean=clean, + desc=desc, + fields=fields, + omit=omit, + unwind=unwind, + skip_empty=skip_empty, + skip_hidden=skip_hidden, + ): + yield item + # TODO: update this once KVS is implemented async def export_to(self, **kwargs: Unpack[ExportToKwargs]) -> None: """Export the entire dataset into a specified file stored under a key in a key-value store. 
@@ -387,112 +290,51 @@ async def export_to(self, **kwargs: Unpack[ExportToKwargs]) -> None: if content_type == 'json': await key_value_store.set_value(key, output.getvalue(), 'application/json') - async def get_info(self) -> DatasetMetadata | None: - """Get an object containing general information about the dataset.""" - metadata = await self._resource_client.get() - if isinstance(metadata, DatasetMetadata): - return metadata - return None - - async def iterate_items( - self, - *, - offset: int = 0, - limit: int | None = None, - clean: bool = False, - desc: bool = False, - fields: list[str] | None = None, - omit: list[str] | None = None, - unwind: str | None = None, - skip_empty: bool = False, - skip_hidden: bool = False, - ) -> AsyncIterator[dict]: - """Iterate over dataset items, applying filtering, sorting, and pagination. - - Retrieve dataset items incrementally, allowing fine-grained control over the data fetched. The function - supports various parameters to filter, sort, and limit the data returned, facilitating tailored dataset - queries. - - Args: - offset: Initial number of items to skip. - limit: Max number of items to return. No limit if None. - clean: Filter out empty items and hidden fields if True. - desc: Return items in reverse order if True. - fields: Specific fields to include in each item. - omit: Fields to omit from each item. - unwind: Field name to unwind items by. - skip_empty: Omits empty items if True. - skip_hidden: Excludes fields starting with '#' if True. - - Yields: - Each item from the dataset as a dictionary. - """ - async for item in self._resource_client.iterate_items( - offset=offset, - limit=limit, - clean=clean, - desc=desc, - fields=fields, - omit=omit, - unwind=unwind, - skip_empty=skip_empty, - skip_hidden=skip_hidden, - ): - yield item - - @classmethod - async def check_and_serialize(cls, item: JsonSerializable, index: int | None = None) -> str: - """Serialize a given item to JSON, checks its serializability and size against a limit. + # TODO: update this once KVS is implemented + async def write_to_csv(self, destination: TextIO, **kwargs: Unpack[ExportDataCsvKwargs]) -> None: + """Export the entire dataset into an arbitrary stream. Args: - item: The item to serialize. - index: Index of the item, used for error context. - - Returns: - Serialized JSON string. - - Raises: - ValueError: If item is not JSON serializable or exceeds size limit. + destination: The stream into which the dataset contents should be written. + kwargs: Additional keyword arguments for `csv.writer`. """ - s = ' ' if index is None else f' at index {index} ' - - try: - payload = await json_dumps(item) - except Exception as exc: - raise ValueError(f'Data item{s}is not serializable to JSON.') from exc - - payload_size = ByteSize(len(payload.encode('utf-8'))) - if payload_size > cls._EFFECTIVE_LIMIT_SIZE: - raise ValueError(f'Data item{s}is too large (size: {payload_size}, limit: {cls._EFFECTIVE_LIMIT_SIZE})') + items: list[dict] = [] + limit = 1000 + offset = 0 - return payload + while True: + list_items = await self._client.get_data(limit=limit, offset=offset) + items.extend(list_items.items) + if list_items.total <= offset + list_items.count: + break + offset += list_items.count - async def _chunk_by_size(self, items: AsyncIterator[str]) -> AsyncIterator[str]: - """Yield chunks of JSON arrays composed of input strings, respecting a size limit. 
+ if items: + writer = csv.writer(destination, **kwargs) + writer.writerows([items[0].keys(), *[item.values() for item in items]]) + else: + logger.warning('Attempting to export an empty dataset - no file will be created') - Groups an iterable of JSON string payloads into larger JSON arrays, ensuring the total size - of each array does not exceed `EFFECTIVE_LIMIT_SIZE`. Each output is a JSON array string that - contains as many payloads as possible without breaching the size threshold, maintaining the - order of the original payloads. Assumes individual items are below the size limit. + # TODO: update this once KVS is implemented + async def write_to_json(self, destination: TextIO, **kwargs: Unpack[ExportDataJsonKwargs]) -> None: + """Export the entire dataset into an arbitrary stream. Args: - items: Iterable of JSON string payloads. - - Yields: - Strings representing JSON arrays of payloads, each staying within the size limit. + destination: The stream into which the dataset contents should be written. + kwargs: Additional keyword arguments for `json.dump`. """ - last_chunk_size = ByteSize(2) # Add 2 bytes for [] wrapper. - current_chunk = [] - - async for payload in items: - payload_size = ByteSize(len(payload.encode('utf-8'))) + items: list[dict] = [] + limit = 1000 + offset = 0 - if last_chunk_size + payload_size <= self._EFFECTIVE_LIMIT_SIZE: - current_chunk.append(payload) - last_chunk_size += payload_size + ByteSize(1) # Add 1 byte for ',' separator. - else: - yield f'[{",".join(current_chunk)}]' - current_chunk = [payload] - last_chunk_size = payload_size + ByteSize(2) # Add 2 bytes for [] wrapper. + while True: + list_items = await self._client.get_data(limit=limit, offset=offset) + items.extend(list_items.items) + if list_items.total <= offset + list_items.count: + break + offset += list_items.count - yield f'[{",".join(current_chunk)}]' + if items: + json.dump(items, destination, **kwargs) + else: + logger.warning('Attempting to export an empty dataset - no file will be created') diff --git a/src/crawlee/storages/_request_queue.py b/src/crawlee/storages/_request_queue.py index 6ef58c047a..372e4373e5 100644 --- a/src/crawlee/storages/_request_queue.py +++ b/src/crawlee/storages/_request_queue.py @@ -5,7 +5,7 @@ from contextlib import suppress from datetime import datetime, timedelta, timezone from logging import getLogger -from typing import TYPE_CHECKING, Any, TypedDict, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar from cachetools import LRUCache from typing_extensions import override @@ -27,20 +27,13 @@ from crawlee import Request from crawlee.configuration import Configuration from crawlee.storage_clients import StorageClient + from crawlee.storages._types import CachedRequest logger = getLogger(__name__) T = TypeVar('T') -class CachedRequest(TypedDict): - id: str - was_already_handled: bool - hydrated: Request | None - lock_expires_at: datetime | None - forefront: bool - - @docs_group('Classes') class RequestQueue(Storage, RequestManager): """Represents a queue storage for managing HTTP requests in web crawling operations. 
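The reworked `Dataset` surface in the hunks above (`open`, `push_data`, `iterate`, `write_to_csv`/`write_to_json`, `drop`) is easiest to see end to end. The following is a minimal usage sketch based only on the signatures added in this patch; it assumes the storage client configured through `service_locator`, uses an illustrative dataset name and output path, and may change as the TODOs left in `_dataset.py` are resolved.

import asyncio

from crawlee.storages import Dataset


async def main() -> None:
    # Open (or create) a named dataset via the configured storage client.
    dataset = await Dataset.open(name='products')

    # Append rows; `push_data` accepts a single dict or a list of dicts.
    await dataset.push_data([
        {'title': 'Widget', 'price': 9.99},
        {'title': 'Gadget', 'price': 19.99},
    ])

    # Stream the stored items back with the renamed `iterate` method
    # (formerly `iterate_items`, as updated in the tests below).
    async for item in dataset.iterate():
        print(item)

    # Export everything into a local JSON file using the stream-based helper;
    # extra keyword arguments are passed through to `json.dump`.
    with open('products.json', 'w', encoding='utf-8') as destination:
        await dataset.write_to_json(destination, indent=2)

    # Remove the dataset and its underlying storage.
    await dataset.drop()


asyncio.run(main())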
diff --git a/src/crawlee/storages/_types.py b/src/crawlee/storages/_types.py new file mode 100644 index 0000000000..bc99ce72fd --- /dev/null +++ b/src/crawlee/storages/_types.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal, TypedDict + +if TYPE_CHECKING: + import json + from collections.abc import Callable + from datetime import datetime + + from typing_extensions import NotRequired, Required + + from crawlee import Request + + +class CachedRequest(TypedDict): + """Represent a cached request in the `RequestQueue`.""" + + id: str + """The ID of the request.""" + + was_already_handled: bool + """Indicates whether the request was already handled.""" + + hydrated: Request | None + """The hydrated request object.""" + + lock_expires_at: datetime | None + """The time at which the lock on the request expires.""" + + forefront: bool + """Indicates whether the request is at the forefront of the queue.""" + + +class IterateKwargs(TypedDict): + """Keyword arguments for dataset's `iterate` method.""" + + offset: NotRequired[int] + """Skips the specified number of items at the start.""" + + limit: NotRequired[int | None] + """The maximum number of items to retrieve. Unlimited if None.""" + + clean: NotRequired[bool] + """Return only non-empty items and excludes hidden fields. Shortcut for skip_hidden and skip_empty.""" + + desc: NotRequired[bool] + """Set to True to sort results in descending order.""" + + fields: NotRequired[list[str]] + """Fields to include in each item. Sorts fields as specified if provided.""" + + omit: NotRequired[list[str]] + """Fields to exclude from each item.""" + + unwind: NotRequired[str] + """Unwinds items by a specified array field, turning each element into a separate item.""" + + skip_empty: NotRequired[bool] + """Excludes empty items from the results if True.""" + + skip_hidden: NotRequired[bool] + """Excludes fields starting with '#' if True.""" + + +class GetDataKwargs(IterateKwargs): + """Keyword arguments for dataset's `get_data` method.""" + + flatten: NotRequired[list[str]] + """Fields to be flattened in returned items.""" + + view: NotRequired[str] + """Specifies the dataset view to be used.""" + + +class ExportToKwargs(TypedDict): + """Keyword arguments for dataset's `export_to` method.""" + + key: Required[str] + """The key under which to save the data.""" + + content_type: NotRequired[Literal['json', 'csv']] + """The format in which to export the data. Either 'json' or 'csv'.""" + + to_key_value_store_id: NotRequired[str] + """ID of the key-value store to save the exported file.""" + + to_key_value_store_name: NotRequired[str] + """Name of the key-value store to save the exported file.""" + + +class ExportDataJsonKwargs(TypedDict): + """Keyword arguments for dataset's `export_data_json` method.""" + + skipkeys: NotRequired[bool] + """If True (default: False), dict keys that are not of a basic type (str, int, float, bool, None) will be skipped + instead of raising a `TypeError`.""" + + ensure_ascii: NotRequired[bool] + """Determines if non-ASCII characters should be escaped in the output JSON string.""" + + check_circular: NotRequired[bool] + """If False (default: True), skips the circular reference check for container types. A circular reference will + result in a `RecursionError` or worse if unchecked.""" + + allow_nan: NotRequired[bool] + """If False (default: True), raises a ValueError for out-of-range float values (nan, inf, -inf) to strictly comply + with the JSON specification. 
If True, uses their JavaScript equivalents (NaN, Infinity, -Infinity).""" + + cls: NotRequired[type[json.JSONEncoder]] + """Allows specifying a custom JSON encoder.""" + + indent: NotRequired[int] + """Specifies the number of spaces for indentation in the pretty-printed JSON output.""" + + separators: NotRequired[tuple[str, str]] + """A tuple of (item_separator, key_separator). The default is (', ', ': ') if indent is None and (',', ': ') + otherwise.""" + + default: NotRequired[Callable] + """A function called for objects that can't be serialized otherwise. It should return a JSON-encodable version + of the object or raise a `TypeError`.""" + + sort_keys: NotRequired[bool] + """Specifies whether the output JSON object should have keys sorted alphabetically.""" + + +class ExportDataCsvKwargs(TypedDict): + """Keyword arguments for dataset's `export_data_csv` method.""" + + dialect: NotRequired[str] + """Specifies a dialect to be used in CSV parsing and writing.""" + + delimiter: NotRequired[str] + """A one-character string used to separate fields. Defaults to ','.""" + + doublequote: NotRequired[bool] + """Controls how instances of `quotechar` inside a field should be quoted. When True, the character is doubled; + when False, the `escapechar` is used as a prefix. Defaults to True.""" + + escapechar: NotRequired[str] + """A one-character string used to escape the delimiter if `quoting` is set to `QUOTE_NONE` and the `quotechar` + if `doublequote` is False. Defaults to None, disabling escaping.""" + + lineterminator: NotRequired[str] + """The string used to terminate lines produced by the writer. Defaults to '\\r\\n'.""" + + quotechar: NotRequired[str] + """A one-character string used to quote fields containing special characters, like the delimiter or quotechar, + or fields containing new-line characters. Defaults to '\"'.""" + + quoting: NotRequired[int] + """Controls when quotes should be generated by the writer and recognized by the reader. Can take any of + the `QUOTE_*` constants, with a default of `QUOTE_MINIMAL`.""" + + skipinitialspace: NotRequired[bool] + """When True, spaces immediately following the delimiter are ignored. Defaults to False.""" + + strict: NotRequired[bool] + """When True, raises an exception on bad CSV input. 
Defaults to False.""" diff --git a/tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py b/tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py index 89b028ed81..8af2509db9 100644 --- a/tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py +++ b/tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py @@ -456,7 +456,7 @@ async def request_handler(context: AdaptivePlaywrightCrawlingContext) -> None: await crawler.run(test_urls[:1]) dataset = await crawler.get_dataset() - stored_results = [item async for item in dataset.iterate_items()] + stored_results = [item async for item in dataset.iterate()] if error_in_pw_crawler: assert stored_results == [] diff --git a/tests/unit/storage_clients/_memory/test_memory_storage_client.py b/tests/unit/storage_clients/_memory/test_memory_storage_client.py index 0d043322ae..66345fb023 100644 --- a/tests/unit/storage_clients/_memory/test_memory_storage_client.py +++ b/tests/unit/storage_clients/_memory/test_memory_storage_client.py @@ -74,7 +74,7 @@ async def test_persist_storage(persist_storage: bool, tmp_path: Path) -> None: ds_client = ms.datasets() ds_info = await ds_client.get_or_create(name='ds') - await ms.dataset(ds_info.id).push_items([{'foo': 'bar'}]) + await ms.dataset(ds_info.id).push_data([{'foo': 'bar'}]) def test_persist_storage_set_to_false_via_string_env_var(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: diff --git a/tests/unit/storages/test_dataset.py b/tests/unit/storages/test_dataset.py index f299aee08d..2d21eac3b8 100644 --- a/tests/unit/storages/test_dataset.py +++ b/tests/unit/storages/test_dataset.py @@ -129,7 +129,7 @@ async def test_iterate_items(dataset: Dataset) -> None: idx = 0 await dataset.push_data([{'id': i} for i in range(desired_item_count)]) - async for item in dataset.iterate_items(): + async for item in dataset.iterate(): assert item['id'] == idx idx += 1 From 0ff7e67a8584ff06296ab6336c72b8ee6ca29c34 Mon Sep 17 00:00:00 2001 From: Vlada Dusek Date: Thu, 20 Mar 2025 14:21:14 +0100 Subject: [PATCH 03/22] Update storage clients --- src/crawlee/_service_locator.py | 10 +-- src/crawlee/storage_clients/__init__.py | 8 +-- .../storage_clients/_base/_storage_client.py | 44 +++++++++++-- .../storage_clients/_file_system/__init__.py | 4 +- .../_file_system/_storage_client.py | 62 +++++++++++++++++-- .../storage_clients/_memory/__init__.py | 4 +- .../_memory/_storage_client.py | 54 ++++++++++++++-- src/crawlee/storages/_dataset.py | 20 +++--- website/generate_module_shortcuts.py | 3 +- 9 files changed, 165 insertions(+), 44 deletions(-) diff --git a/src/crawlee/_service_locator.py b/src/crawlee/_service_locator.py index 08b25c35cd..2cb8f8302a 100644 --- a/src/crawlee/_service_locator.py +++ b/src/crawlee/_service_locator.py @@ -3,8 +3,8 @@ from crawlee._utils.docs import docs_group from crawlee.configuration import Configuration from crawlee.errors import ServiceConflictError -from crawlee.events import EventManager -from crawlee.storage_clients import StorageClient +from crawlee.events import EventManager, LocalEventManager +from crawlee.storage_clients import FileSystemStorageClient, StorageClient @docs_group('Classes') @@ -49,8 +49,6 @@ def set_configuration(self, configuration: Configuration) -> None: def get_event_manager(self) -> EventManager: """Get the event manager.""" if self._event_manager is None: - from crawlee.events import LocalEventManager - self._event_manager = ( 
LocalEventManager().from_config(config=self._configuration) if self._configuration @@ -77,9 +75,7 @@ def set_event_manager(self, event_manager: EventManager) -> None: def get_storage_client(self) -> StorageClient: """Get the storage client.""" if self._storage_client is None: - from crawlee.storage_clients import file_system_storage_client - - self._storage_client = file_system_storage_client + self._storage_client = FileSystemStorageClient() self._storage_client_was_retrieved = True return self._storage_client diff --git a/src/crawlee/storage_clients/__init__.py b/src/crawlee/storage_clients/__init__.py index 848c160c37..ce8c713ca9 100644 --- a/src/crawlee/storage_clients/__init__.py +++ b/src/crawlee/storage_clients/__init__.py @@ -1,9 +1,9 @@ from ._base import StorageClient -from ._file_system import file_system_storage_client -from ._memory import memory_storage_client +from ._file_system import FileSystemStorageClient +from ._memory import MemoryStorageClient __all__ = [ + 'FileSystemStorageClient', + 'MemoryStorageClient', 'StorageClient', - 'file_system_storage_client', - 'memory_storage_client' ] diff --git a/src/crawlee/storage_clients/_base/_storage_client.py b/src/crawlee/storage_clients/_base/_storage_client.py index fc7e2e4d97..b12792f202 100644 --- a/src/crawlee/storage_clients/_base/_storage_client.py +++ b/src/crawlee/storage_clients/_base/_storage_client.py @@ -1,16 +1,48 @@ from __future__ import annotations -from dataclasses import dataclass +from abc import ABC, abstractmethod from typing import TYPE_CHECKING if TYPE_CHECKING: + from pathlib import Path + from ._dataset_client import DatasetClient from ._key_value_store_client import KeyValueStoreClient from ._request_queue_client import RequestQueueClient -@dataclass -class StorageClient: - dataset_client_class: type[DatasetClient] - key_value_store_client_class: type[KeyValueStoreClient] - request_queue_client_class: type[RequestQueueClient] +class StorageClient(ABC): + """Base class for storage clients.""" + + @abstractmethod + async def open_dataset_client( + self, + *, + id: str | None, + name: str | None, + purge_on_start: bool, + storage_dir: Path, + ) -> DatasetClient: + """Open the dataset client.""" + + @abstractmethod + async def open_key_value_store_client( + self, + *, + id: str | None, + name: str | None, + purge_on_start: bool, + storage_dir: Path, + ) -> KeyValueStoreClient: + """Open the key-value store client.""" + + @abstractmethod + async def open_request_queue_client( + self, + *, + id: str | None, + name: str | None, + purge_on_start: bool, + storage_dir: Path, + ) -> RequestQueueClient: + """Open the request queue client.""" diff --git a/src/crawlee/storage_clients/_file_system/__init__.py b/src/crawlee/storage_clients/_file_system/__init__.py index 3aa67ad6dc..bac1291176 100644 --- a/src/crawlee/storage_clients/_file_system/__init__.py +++ b/src/crawlee/storage_clients/_file_system/__init__.py @@ -1,3 +1,3 @@ -from ._storage_client import file_system_storage_client +from ._storage_client import FileSystemStorageClient -__all__ = ['file_system_storage_client'] +__all__ = ['FileSystemStorageClient'] diff --git a/src/crawlee/storage_clients/_file_system/_storage_client.py b/src/crawlee/storage_clients/_file_system/_storage_client.py index 248d07b6f6..e0420fe653 100644 --- a/src/crawlee/storage_clients/_file_system/_storage_client.py +++ b/src/crawlee/storage_clients/_file_system/_storage_client.py @@ -1,13 +1,65 @@ from __future__ import annotations +from typing import TYPE_CHECKING + +from 
typing_extensions import override + from crawlee.storage_clients._base import StorageClient from ._dataset_client import FileSystemDatasetClient from ._key_value_store import FileSystemKeyValueStoreClient from ._request_queue import FileSystemRequestQueueClient -file_system_storage_client = StorageClient( - dataset_client_class=FileSystemDatasetClient, - key_value_store_client_class=FileSystemKeyValueStoreClient, - request_queue_client_class=FileSystemRequestQueueClient, -) +if TYPE_CHECKING: + from pathlib import Path + + +class FileSystemStorageClient(StorageClient): + """File system storage client.""" + + @override + async def open_dataset_client( + self, + *, + id: str | None, + name: str | None, + purge_on_start: bool, + storage_dir: Path, + ) -> FileSystemDatasetClient: + dataset_client = await FileSystemDatasetClient.open( + id=id, + name=name, + storage_dir=storage_dir, + ) + + if purge_on_start: + await dataset_client.drop() + dataset_client = await FileSystemDatasetClient.open( + id=id, + name=name, + storage_dir=storage_dir, + ) + + return dataset_client + + @override + async def open_key_value_store_client( + self, + *, + id: str | None, + name: str | None, + purge_on_start: bool, + storage_dir: Path, + ) -> FileSystemKeyValueStoreClient: + return FileSystemKeyValueStoreClient() + + @override + async def open_request_queue_client( + self, + *, + id: str | None, + name: str | None, + purge_on_start: bool, + storage_dir: Path, + ) -> FileSystemRequestQueueClient: + return FileSystemRequestQueueClient() diff --git a/src/crawlee/storage_clients/_memory/__init__.py b/src/crawlee/storage_clients/_memory/__init__.py index 2463d516c2..0d117a8a6c 100644 --- a/src/crawlee/storage_clients/_memory/__init__.py +++ b/src/crawlee/storage_clients/_memory/__init__.py @@ -1,3 +1,3 @@ -from ._storage_client import memory_storage_client +from ._storage_client import MemoryStorageClient -__all__ = ['memory_storage_client'] +__all__ = ['MemoryStorageClient'] diff --git a/src/crawlee/storage_clients/_memory/_storage_client.py b/src/crawlee/storage_clients/_memory/_storage_client.py index e867800f26..4d9d090b38 100644 --- a/src/crawlee/storage_clients/_memory/_storage_client.py +++ b/src/crawlee/storage_clients/_memory/_storage_client.py @@ -1,13 +1,57 @@ from __future__ import annotations +from typing import TYPE_CHECKING + +from typing_extensions import override + from crawlee.storage_clients._base import StorageClient from ._dataset_client import MemoryDatasetClient from ._key_value_store_client import MemoryKeyValueStoreClient from ._request_queue_client import MemoryRequestQueueClient -memory_storage_client = StorageClient( - dataset_client_class=MemoryDatasetClient, - key_value_store_client_class=MemoryKeyValueStoreClient, - request_queue_client_class=MemoryRequestQueueClient, -) +if TYPE_CHECKING: + from pathlib import Path + + +class MemoryStorageClient(StorageClient): + """Memory storage client.""" + + @override + async def open_dataset_client( + self, + *, + id: str | None, + name: str | None, + purge_on_start: bool, + storage_dir: Path, + ) -> MemoryDatasetClient: + dataset_client = await MemoryDatasetClient.open(id=id, name=name, storage_dir=storage_dir) + + if purge_on_start: + await dataset_client.drop() + dataset_client = await MemoryDatasetClient.open(id=id, name=name, storage_dir=storage_dir) + + return dataset_client + + @override + async def open_key_value_store_client( + self, + *, + id: str | None, + name: str | None, + purge_on_start: bool, + storage_dir: Path, + ) -> 
MemoryKeyValueStoreClient: + return MemoryKeyValueStoreClient() + + @override + async def open_request_queue_client( + self, + *, + id: str | None, + name: str | None, + purge_on_start: bool, + storage_dir: Path, + ) -> MemoryRequestQueueClient: + return MemoryRequestQueueClient() diff --git a/src/crawlee/storages/_dataset.py b/src/crawlee/storages/_dataset.py index 0693a88334..b6a324434b 100644 --- a/src/crawlee/storages/_dataset.py +++ b/src/crawlee/storages/_dataset.py @@ -31,6 +31,7 @@ # TODO: # - inherit from storage class # - export methods +# - caching / memoization of both datasets & dataset clients # Dataset # - properties: @@ -116,30 +117,25 @@ async def open( id: str | None = None, name: str | None = None, purge_on_start: bool | None = None, + storage_dir: Path | None = None, configuration: Configuration | None = None, storage_client: StorageClient | None = None, ) -> Dataset: if id and name: raise ValueError('Only one of "id" or "name" can be specified, not both.') - configuration = configuration or service_locator.get_configuration() - storage_client = storage_client or service_locator.get_storage_client() + configuration = service_locator.get_configuration() if configuration is None else configuration + storage_client = service_locator.get_storage_client() if storage_client is None else storage_client purge_on_start = configuration.purge_on_start if purge_on_start is None else purge_on_start + storage_dir = Path(configuration.storage_dir) if storage_dir is None else storage_dir - dataset_client = await storage_client.dataset_client_class.open( + dataset_client = await storage_client.open_dataset_client( id=id, name=name, - storage_dir=Path(configuration.storage_dir), + purge_on_start=purge_on_start, + storage_dir=storage_dir, ) - if purge_on_start: - await dataset_client.drop() - dataset_client = await storage_client.dataset_client_class.open( - id=id, - name=name, - storage_dir=Path(configuration.storage_dir), - ) - return cls(dataset_client) async def drop(self) -> None: diff --git a/website/generate_module_shortcuts.py b/website/generate_module_shortcuts.py index 5a18e8d3f3..61acc68ade 100755 --- a/website/generate_module_shortcuts.py +++ b/website/generate_module_shortcuts.py @@ -5,6 +5,7 @@ import importlib import inspect import json +from pathlib import Path from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -55,5 +56,5 @@ def resolve_shortcuts(shortcuts: dict) -> None: resolve_shortcuts(shortcuts) -with open('module_shortcuts.json', 'w', encoding='utf-8') as shortcuts_file: +with Path('module_shortcuts.json').open('w', encoding='utf-8') as shortcuts_file: json.dump(shortcuts, shortcuts_file, indent=4, sort_keys=True) From a805d9a4839fb08f24ff1f3beee35284bb710bad Mon Sep 17 00:00:00 2001 From: Vlada Dusek Date: Tue, 8 Apr 2025 16:49:54 +0200 Subject: [PATCH 04/22] Update KVS and its clients --- src/crawlee/_utils/file.py | 16 - .../storage_clients/_base/_dataset_client.py | 3 +- .../_base/_key_value_store_client.py | 149 ++++--- .../_file_system/_dataset_client.py | 8 +- .../_file_system/_key_value_store.py | 11 - .../_file_system/_key_value_store_client.py | 411 ++++++++++++++++++ ...uest_queue.py => _request_queue_client.py} | 0 .../_file_system/_storage_client.py | 36 +- .../storage_clients/_file_system/_utils.py | 4 +- src/crawlee/storage_clients/models.py | 35 +- src/crawlee/storages/_dataset.py | 6 +- src/crawlee/storages/_key_value_store.py | 262 +++++------ .../_memory/test_key_value_store_client.py | 1 - 13 files changed, 659 insertions(+), 283 deletions(-) 
delete mode 100644 src/crawlee/storage_clients/_file_system/_key_value_store.py create mode 100644 src/crawlee/storage_clients/_file_system/_key_value_store_client.py rename src/crawlee/storage_clients/_file_system/{_request_queue.py => _request_queue_client.py} (100%) diff --git a/src/crawlee/_utils/file.py b/src/crawlee/_utils/file.py index 022d0604ef..d50f6ecd41 100644 --- a/src/crawlee/_utils/file.py +++ b/src/crawlee/_utils/file.py @@ -2,7 +2,6 @@ import asyncio import contextlib -import io import json import mimetypes import os @@ -83,21 +82,6 @@ def determine_file_extension(content_type: str) -> str | None: return ext[1:] if ext is not None else ext -def is_file_or_bytes(value: Any) -> bool: - """Determine if the input value is a file-like object or bytes. - - This function checks whether the provided value is an instance of bytes, bytearray, or io.IOBase (file-like). - The method is simplified for common use cases and may not cover all edge cases. - - Args: - value: The value to be checked. - - Returns: - True if the value is either a file-like object or bytes, False otherwise. - """ - return isinstance(value, (bytes, bytearray, io.IOBase)) - - async def json_dumps(obj: Any) -> str: """Serialize an object to a JSON-formatted string with specific settings. diff --git a/src/crawlee/storage_clients/_base/_dataset_client.py b/src/crawlee/storage_clients/_base/_dataset_client.py index d68bb35c84..265856b3ff 100644 --- a/src/crawlee/storage_clients/_base/_dataset_client.py +++ b/src/crawlee/storage_clients/_base/_dataset_client.py @@ -56,6 +56,7 @@ def item_count(self) -> int: @abstractmethod async def open( cls, + *, id: str | None, name: str | None, storage_dir: Path, @@ -82,7 +83,7 @@ async def drop(self) -> None: """ @abstractmethod - async def push_data(self, data: list[Any] | dict[str, Any]) -> None: + async def push_data(self, *, data: list[Any] | dict[str, Any]) -> None: """Push data to the dataset. The backend method for the `Dataset.push_data` call. diff --git a/src/crawlee/storage_clients/_base/_key_value_store_client.py b/src/crawlee/storage_clients/_base/_key_value_store_client.py index 91f73993b0..097b5fbf8f 100644 --- a/src/crawlee/storage_clients/_base/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_base/_key_value_store_client.py @@ -6,111 +6,136 @@ from crawlee._utils.docs import docs_group if TYPE_CHECKING: - from contextlib import AbstractAsyncContextManager - - from httpx import Response - - from crawlee.storage_clients.models import KeyValueStoreListKeysPage, KeyValueStoreMetadata, KeyValueStoreRecord + from collections.abc import AsyncIterator + from datetime import datetime + from pathlib import Path + + from crawlee.storage_clients.models import ( + KeyValueStoreRecord, + KeyValueStoreRecordMetadata, + ) + +# Properties: +# - id +# - name +# - created_at +# - accessed_at +# - modified_at + +# Methods: +# - open +# - drop +# - get_value +# - set_value +# - delete_value +# - iterate_keys +# - get_public_url @docs_group('Abstract classes') class KeyValueStoreClient(ABC): - """An abstract class for key-value store resource clients. + """An abstract class for key-value store (KVS) resource clients. These clients are specific to the type of resource they manage and operate under a designated storage client, like a memory storage client. """ + @property @abstractmethod - async def get(self) -> KeyValueStoreMetadata | None: - """Get metadata about the key-value store being managed by this client. 
+ def id(self) -> str: + """The ID of the key-value store.""" - Returns: - An object containing the key-value store's details, or None if the key-value store does not exist. - """ + @property + @abstractmethod + def name(self) -> str | None: + """The name of the key-value store.""" + @property @abstractmethod - async def delete(self) -> None: - """Permanently delete the key-value store managed by this client.""" + def created_at(self) -> datetime: + """The time at which the key-value store was created.""" + @property @abstractmethod - async def list_keys( - self, + def accessed_at(self) -> datetime: + """The time at which the key-value store was last accessed.""" + + @property + @abstractmethod + def modified_at(self) -> datetime: + """The time at which the key-value store was last modified.""" + + @classmethod + @abstractmethod + async def open( + cls, *, - limit: int = 1000, - exclusive_start_key: str | None = None, - ) -> KeyValueStoreListKeysPage: - """List the keys in the key-value store. + id: str | None, + name: str | None, + storage_dir: Path, + ) -> KeyValueStoreClient: + """Open existing or create a new key-value store client. + + If a key-value store with the given name already exists, the appropriate key-value store client is returned. + Otherwise, a new key-value store is created and client for it is returned. Args: - limit: Number of keys to be returned. Maximum value is 1000. - exclusive_start_key: All keys up to this one (including) are skipped from the result. + id: The ID of the key-value store. + name: The name of the key-value store. + storage_dir: The path to the storage directory. If the client persists data, it should use this directory. Returns: - The list of keys in the key-value store matching the given arguments. + A key-value store client. """ @abstractmethod - async def get_record(self, key: str) -> KeyValueStoreRecord | None: - """Retrieve the given record from the key-value store. - - Args: - key: Key of the record to retrieve. + async def drop(self) -> None: + """Drop the whole key-value store and remove all its values. - Returns: - The requested record, or None, if the record does not exist + The backend method for the `KeyValueStore.drop` call. """ @abstractmethod - async def get_record_as_bytes(self, key: str) -> KeyValueStoreRecord[bytes] | None: - """Retrieve the given record from the key-value store, without parsing it. - - Args: - key: Key of the record to retrieve. + async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: + """Retrieve the given record from the key-value store. - Returns: - The requested record, or None, if the record does not exist + The backend method for the `KeyValueStore.get_value` call. """ @abstractmethod - async def stream_record(self, key: str) -> AbstractAsyncContextManager[KeyValueStoreRecord[Response] | None]: - """Retrieve the given record from the key-value store, as a stream. + async def set_value(self, *, key: str, value: Any, content_type: str | None = None) -> None: + """Set a value in the key-value store by its key. - Args: - key: Key of the record to retrieve. - - Returns: - The requested record as a context-managed streaming Response, or None, if the record does not exist + The backend method for the `KeyValueStore.set_value` call. """ @abstractmethod - async def set_record(self, key: str, value: Any, content_type: str | None = None) -> None: - """Set a value to the given record in the key-value store. 
+ async def delete_value(self, *, key: str) -> None: + """Delete a value from the key-value store by its key. - Args: - key: The key of the record to save the value to. - value: The value to save into the record. - content_type: The content type of the saved value. + The backend method for the `KeyValueStore.delete_value` call. """ @abstractmethod - async def delete_record(self, key: str) -> None: - """Delete the specified record from the key-value store. + async def iterate_keys( + self, + *, + exclusive_start_key: str | None = None, + limit: int | None = None, + ) -> AsyncIterator[KeyValueStoreRecordMetadata]: + """Iterate over all the existing keys in the key-value store. - Args: - key: The key of the record which to delete. + The backend method for the `KeyValueStore.iterate_keys` call. """ + # This syntax is to make mypy properly work with abstract AsyncIterator. + # https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators + raise NotImplementedError + if False: # type: ignore[unreachable] + yield 0 @abstractmethod - async def get_public_url(self, key: str) -> str: + async def get_public_url(self, *, key: str) -> str: """Get the public URL for the given key. - Args: - key: Key of the record for which URL is required. - - Returns: - The public URL for the given key. - - Raises: - ValueError: If the key does not exist. + The backend method for the `KeyValueStore.get_public_url` call. """ diff --git a/src/crawlee/storage_clients/_file_system/_dataset_client.py b/src/crawlee/storage_clients/_file_system/_dataset_client.py index b05340cae2..63103ff310 100644 --- a/src/crawlee/storage_clients/_file_system/_dataset_client.py +++ b/src/crawlee/storage_clients/_file_system/_dataset_client.py @@ -10,12 +10,11 @@ from pydantic import ValidationError from typing_extensions import override -from crawlee._consts import METADATA_FILENAME from crawlee._utils.crypto import crypto_random_object_id from crawlee.storage_clients._base import DatasetClient from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata -from ._utils import json_dumps +from ._utils import METADATA_FILENAME, json_dumps if TYPE_CHECKING: from collections.abc import AsyncIterator @@ -129,6 +128,11 @@ async def open( Returns: A new instance of the file system dataset client. """ + if id: + raise ValueError( + 'Opening a dataset by "id" is not supported for file system storage client, use "name" instead.' 
+ ) + name = name or cls._DEFAULT_NAME dataset_path = storage_dir / cls._STORAGE_SUBDIR / name metadata_path = dataset_path / METADATA_FILENAME diff --git a/src/crawlee/storage_clients/_file_system/_key_value_store.py b/src/crawlee/storage_clients/_file_system/_key_value_store.py deleted file mode 100644 index 8bf0815b3a..0000000000 --- a/src/crawlee/storage_clients/_file_system/_key_value_store.py +++ /dev/null @@ -1,11 +0,0 @@ -from __future__ import annotations - -from logging import getLogger - -from crawlee.storage_clients._base import KeyValueStoreClient - -logger = getLogger(__name__) - - -class FileSystemKeyValueStoreClient(KeyValueStoreClient): - pass diff --git a/src/crawlee/storage_clients/_file_system/_key_value_store_client.py b/src/crawlee/storage_clients/_file_system/_key_value_store_client.py new file mode 100644 index 0000000000..921838a73f --- /dev/null +++ b/src/crawlee/storage_clients/_file_system/_key_value_store_client.py @@ -0,0 +1,411 @@ +from __future__ import annotations + +import asyncio +import json +import shutil +from datetime import datetime, timezone +from logging import getLogger +from typing import TYPE_CHECKING, Any + +from pydantic import ValidationError +from typing_extensions import override + +from crawlee._utils.crypto import crypto_random_object_id +from crawlee.storage_clients._base import KeyValueStoreClient +from crawlee.storage_clients.models import ( + KeyValueStoreMetadata, + KeyValueStoreRecord, + KeyValueStoreRecordMetadata, +) + +from ._utils import METADATA_FILENAME, json_dumps + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + from pathlib import Path + +logger = getLogger(__name__) + + +class FileSystemKeyValueStoreClient(KeyValueStoreClient): + """A file system key-value store (KVS) implementation.""" + + _DEFAULT_NAME = 'default' + """The name of the unnamed KVS.""" + + _STORAGE_SUBDIR = 'key_value_stores' + """The name of the subdirectory where KVSs are stored.""" + + def __init__( + self, + *, + id: str, + name: str, + created_at: datetime, + accessed_at: datetime, + modified_at: datetime, + storage_dir: Path, + ) -> None: + """Initialize a new instance. + + Preferably use the `FileSystemKeyValueStoreClient.open` class method to create a new instance. + """ + self._id = id + self._name = name + self._created_at = created_at + self._accessed_at = accessed_at + self._modified_at = modified_at + self._storage_dir = storage_dir + + # Internal attributes. + self._lock = asyncio.Lock() + """A lock to ensure that only one file operation is performed at a time.""" + + @override + @property + def id(self) -> str: + return self._id + + @override + @property + def name(self) -> str | None: + return self._name + + @override + @property + def created_at(self) -> datetime: + return self._created_at + + @override + @property + def accessed_at(self) -> datetime: + return self._accessed_at + + @override + @property + def modified_at(self) -> datetime: + return self._modified_at + + @property + def _path_to_kvs(self) -> Path: + """The full path to the key-value store directory.""" + return self._storage_dir / self._STORAGE_SUBDIR / self._name + + @property + def _path_to_metadata(self) -> Path: + """The full path to the key-value store metadata file.""" + return self._path_to_kvs / METADATA_FILENAME + + @override + @classmethod + async def open( + cls, + id: str | None, + name: str | None, + storage_dir: Path, + ) -> FileSystemKeyValueStoreClient: + """Open an existing key-value store client or create a new one if it does not exist. 
+ + If the key-value store directory exists, this method reconstructs the client from the metadata file. + Otherwise, a new key-value store client is created with a new unique ID. + + Args: + id: The key-value store ID. + name: The key-value store name; if not provided, defaults to the default name. + storage_dir: The base directory for storage. + + Returns: + A new instance of the file system key-value store client. + """ + if id: + raise ValueError( + 'Opening a key-value store by "id" is not supported for file system storage client, use "name" instead.' + ) + + name = name or cls._DEFAULT_NAME + kvs_path = storage_dir / cls._STORAGE_SUBDIR / name + metadata_path = kvs_path / METADATA_FILENAME + + # If the key-value store directory exists, reconstruct the client from the metadata file. + if kvs_path.exists(): + # If metadata file is missing, raise an error. + if not metadata_path.exists(): + raise ValueError(f'Metadata file not found for key-value store "{name}"') + + file = await asyncio.to_thread(open, metadata_path) + try: + file_content = json.load(file) + finally: + await asyncio.to_thread(file.close) + try: + metadata = KeyValueStoreMetadata(**file_content) + except ValidationError as exc: + raise ValueError(f'Invalid metadata file for key-value store "{name}"') from exc + + client = cls( + id=metadata.id, + name=name, + created_at=metadata.created_at, + accessed_at=metadata.accessed_at, + modified_at=metadata.modified_at, + storage_dir=storage_dir, + ) + + await client._update_metadata(update_accessed_at=True) + + # Otherwise, create a new key-value store client. + else: + client = cls( + id=crypto_random_object_id(), + name=name, + created_at=datetime.now(timezone.utc), + accessed_at=datetime.now(timezone.utc), + modified_at=datetime.now(timezone.utc), + storage_dir=storage_dir, + ) + await client._update_metadata() + + return client + + @override + async def drop(self) -> None: + # If the key-value store directory exists, remove it recursively. 
+ if self._path_to_kvs.exists(): + async with self._lock: + await asyncio.to_thread(shutil.rmtree, self._path_to_kvs) + + @override + async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: + record_path = self._path_to_kvs / key + + if not record_path.exists(): + return None + + # Found a file for this key, now look for its metadata + record_metadata_filepath = record_path.with_name(f'{record_path.name}.{METADATA_FILENAME}') + if not record_metadata_filepath.exists(): + logger.warning(f'Found value file for key "{key}" but no metadata file.') + return None + + # Read the metadata file + async with self._lock: + file = await asyncio.to_thread(open, record_metadata_filepath) + try: + metadata_content = json.load(file) + except json.JSONDecodeError: + logger.warning(f'Invalid metadata file for key "{key}"') + return None + finally: + await asyncio.to_thread(file.close) + + try: + metadata = KeyValueStoreRecordMetadata(**metadata_content) + except ValidationError: + logger.warning(f'Invalid metadata schema for key "{key}"') + return None + + # Read the actual value + value_bytes = await asyncio.to_thread(record_path.read_bytes) + + # Handle JSON values + if 'application/json' in metadata.content_type: + try: + value = json.loads(value_bytes.decode('utf-8')) + except (json.JSONDecodeError, UnicodeDecodeError): + logger.warning(f'Failed to decode JSON value for key "{key}"') + return None + + # Handle text values + elif metadata.content_type.startswith('text/'): + try: + value = value_bytes.decode('utf-8') + except UnicodeDecodeError: + logger.warning(f'Failed to decode text value for key "{key}"') + return None + + # Handle binary values + else: + value = value_bytes + + # Update the metadata to record access + await self._update_metadata(update_accessed_at=True) + + # Calculate the size of the value in bytes + size = len(value_bytes) + + return KeyValueStoreRecord( + key=metadata.key, + value=value, + content_type=metadata.content_type, + size=size, + ) + + @override + async def set_value(self, *, key: str, value: Any, content_type: str | None = None) -> None: + content_type = content_type or self._infer_mime_type(value) + + # Serialize the value to bytes. + if 'application/json' in content_type: + value_bytes = (await json_dumps(value)).encode('utf-8') + elif isinstance(value, str): + value_bytes = value.encode('utf-8') + elif isinstance(value, (bytes, bytearray)): + value_bytes = value + else: + # Fallback: attempt to convert to string and encode. + value_bytes = str(value).encode('utf-8') + + record_path = self._path_to_kvs / key + + # Get the metadata. + # Calculate the size of the value in bytes + size = len(value_bytes) + record_metadata = KeyValueStoreRecordMetadata(key=key, content_type=content_type, size=size) + record_metadata_filepath = record_path.with_name(f'{record_path.name}.{METADATA_FILENAME}') + record_metadata_content = await json_dumps(record_metadata.model_dump()) + + async with self._lock: + # Ensure the key-value store directory exists. + await asyncio.to_thread(self._path_to_kvs.mkdir, parents=True, exist_ok=True) + + # Dump the value to the file. + await asyncio.to_thread(record_path.write_bytes, value_bytes) + + # Dump the record metadata to the file. + await asyncio.to_thread( + record_metadata_filepath.write_text, + record_metadata_content, + encoding='utf-8', + ) + + # Update the KVS metadata to record the access and modification. 
+ await self._update_metadata(update_accessed_at=True, update_modified_at=True) + + @override + async def delete_value(self, *, key: str) -> None: + record_path = self._path_to_kvs / key + metadata_path = record_path.with_name(f'{record_path.name}.{METADATA_FILENAME}') + deleted = False + + async with self._lock: + # Delete the value file and its metadata if found + if record_path.exists(): + await asyncio.to_thread(record_path.unlink) + + # Delete the metadata file if it exists + if metadata_path.exists(): + await asyncio.to_thread(metadata_path.unlink) + else: + logger.warning(f'Found value file for key "{key}" but no metadata file when trying to delete it.') + + deleted = True + + # If we deleted something, update the KVS metadata + if deleted: + await self._update_metadata(update_accessed_at=True, update_modified_at=True) + + @override + async def iterate_keys( + self, + *, + exclusive_start_key: str | None = None, + limit: int | None = None, + ) -> AsyncIterator[KeyValueStoreRecordMetadata]: + # Check if the KVS directory exists + if not self._path_to_kvs.exists(): + return + + count = 0 + async with self._lock: + # Get all files in the KVS directory + files = sorted(await asyncio.to_thread(list, self._path_to_kvs.glob('*'))) + + for file_path in files: + # Skip the main metadata file + if file_path.name == METADATA_FILENAME: + continue + + # Only process metadata files for records + if not file_path.name.endswith(f'.{METADATA_FILENAME}'): + continue + + # Extract the base key name from the metadata filename + key_name = file_path.name[: -len(f'.{METADATA_FILENAME}')] + + # Apply exclusive_start_key filter if provided + if exclusive_start_key is not None and key_name <= exclusive_start_key: + continue + + # Try to read and parse the metadata file + try: + metadata_content = await asyncio.to_thread(file_path.read_text, encoding='utf-8') + metadata_dict = json.loads(metadata_content) + record_metadata = KeyValueStoreRecordMetadata(**metadata_dict) + + yield record_metadata + + count += 1 + if limit and count >= limit: + break + + except (json.JSONDecodeError, ValidationError) as e: + logger.warning(f'Failed to parse metadata file {file_path}: {e}') + + # Update accessed_at timestamp + await self._update_metadata(update_accessed_at=True) + + @override + async def get_public_url(self, *, key: str) -> str: + raise NotImplementedError('Public URLs are not supported for file system key-value stores.') + + async def _update_metadata( + self, + *, + update_accessed_at: bool = False, + update_modified_at: bool = False, + ) -> None: + """Update the KVS metadata file with current information. + + Args: + update_accessed_at: If True, update the `accessed_at` timestamp to the current time. + update_modified_at: If True, update the `modified_at` timestamp to the current time. + """ + now = datetime.now(timezone.utc) + metadata = KeyValueStoreMetadata( + id=self._id, + name=self._name, + created_at=self._created_at, + accessed_at=now if update_accessed_at else self._accessed_at, + modified_at=now if update_modified_at else self._modified_at, + ) + + # Ensure the parent directory for the metadata file exists. + await asyncio.to_thread(self._path_to_metadata.parent.mkdir, parents=True, exist_ok=True) + + # Dump the serialized metadata to the file. + data = await json_dumps(metadata.model_dump()) + await asyncio.to_thread(self._path_to_metadata.write_text, data, encoding='utf-8') + + def _infer_mime_type(self, value: Any) -> str: + """Infer the MIME content type from the value. 
+ + Args: + value: The value to infer the content type from. + + Returns: + The inferred MIME content type. + """ + # If the value is bytes (or bytearray), return binary content type. + if isinstance(value, (bytes, bytearray)): + return 'application/octet-stream' + + # If the value is a dict or list, assume JSON. + if isinstance(value, (dict, list)): + return 'application/json; charset=utf-8' + + # If the value is a string, assume plain text. + if isinstance(value, str): + return 'text/plain; charset=utf-8' + + # Default fallback. + return 'application/octet-stream' diff --git a/src/crawlee/storage_clients/_file_system/_request_queue.py b/src/crawlee/storage_clients/_file_system/_request_queue_client.py similarity index 100% rename from src/crawlee/storage_clients/_file_system/_request_queue.py rename to src/crawlee/storage_clients/_file_system/_request_queue_client.py diff --git a/src/crawlee/storage_clients/_file_system/_storage_client.py b/src/crawlee/storage_clients/_file_system/_storage_client.py index e0420fe653..9d3adefb76 100644 --- a/src/crawlee/storage_clients/_file_system/_storage_client.py +++ b/src/crawlee/storage_clients/_file_system/_storage_client.py @@ -7,8 +7,8 @@ from crawlee.storage_clients._base import StorageClient from ._dataset_client import FileSystemDatasetClient -from ._key_value_store import FileSystemKeyValueStoreClient -from ._request_queue import FileSystemRequestQueueClient +from ._key_value_store_client import FileSystemKeyValueStoreClient +from ._request_queue_client import FileSystemRequestQueueClient if TYPE_CHECKING: from pathlib import Path @@ -26,21 +26,13 @@ async def open_dataset_client( purge_on_start: bool, storage_dir: Path, ) -> FileSystemDatasetClient: - dataset_client = await FileSystemDatasetClient.open( - id=id, - name=name, - storage_dir=storage_dir, - ) + client = await FileSystemDatasetClient.open(id=id, name=name, storage_dir=storage_dir) if purge_on_start: - await dataset_client.drop() - dataset_client = await FileSystemDatasetClient.open( - id=id, - name=name, - storage_dir=storage_dir, - ) + await client.drop() + client = await FileSystemDatasetClient.open(id=id, name=name, storage_dir=storage_dir) - return dataset_client + return client @override async def open_key_value_store_client( @@ -51,7 +43,13 @@ async def open_key_value_store_client( purge_on_start: bool, storage_dir: Path, ) -> FileSystemKeyValueStoreClient: - return FileSystemKeyValueStoreClient() + client = await FileSystemKeyValueStoreClient.open(id=id, name=name, storage_dir=storage_dir) + + if purge_on_start: + await client.drop() + client = await FileSystemKeyValueStoreClient.open(id=id, name=name, storage_dir=storage_dir) + + return client @override async def open_request_queue_client( @@ -62,4 +60,10 @@ async def open_request_queue_client( purge_on_start: bool, storage_dir: Path, ) -> FileSystemRequestQueueClient: - return FileSystemRequestQueueClient() + client = await FileSystemRequestQueueClient.open(id=id, name=name, storage_dir=storage_dir) + + if purge_on_start: + await client.drop() + client = await FileSystemRequestQueueClient.open(id=id, name=name, storage_dir=storage_dir) + + return client diff --git a/src/crawlee/storage_clients/_file_system/_utils.py b/src/crawlee/storage_clients/_file_system/_utils.py index 5ad9121448..c172df50cc 100644 --- a/src/crawlee/storage_clients/_file_system/_utils.py +++ b/src/crawlee/storage_clients/_file_system/_utils.py @@ -2,13 +2,13 @@ import asyncio import json -from logging import getLogger from typing import 
TYPE_CHECKING if TYPE_CHECKING: from typing import Any -logger = getLogger(__name__) +METADATA_FILENAME = '__metadata__.json' +"""The name of the metadata file for storage clients.""" async def json_dumps(obj: Any) -> str: diff --git a/src/crawlee/storage_clients/models.py b/src/crawlee/storage_clients/models.py index f016e24730..8299220475 100644 --- a/src/crawlee/storage_clients/models.py +++ b/src/crawlee/storage_clients/models.py @@ -47,8 +47,6 @@ class KeyValueStoreMetadata(StorageMetadata): model_config = ConfigDict(populate_by_name=True) - user_id: Annotated[str, Field(alias='userId')] - @docs_group('Data structures') class RequestQueueMetadata(StorageMetadata): @@ -61,40 +59,39 @@ class RequestQueueMetadata(StorageMetadata): pending_request_count: Annotated[int, Field(alias='pendingRequestCount')] stats: Annotated[dict, Field(alias='stats')] total_request_count: Annotated[int, Field(alias='totalRequestCount')] - user_id: Annotated[str, Field(alias='userId')] resource_directory: Annotated[str, Field(alias='resourceDirectory')] @docs_group('Data structures') -class KeyValueStoreRecord(BaseModel, Generic[KvsValueType]): - """Model for a key-value store record.""" +class KeyValueStoreRecordMetadata(BaseModel): + """Model for a key-value store record metadata.""" model_config = ConfigDict(populate_by_name=True) key: Annotated[str, Field(alias='key')] - value: Annotated[KvsValueType, Field(alias='value')] - content_type: Annotated[str | None, Field(alias='contentType', default=None)] - filename: Annotated[str | None, Field(alias='filename', default=None)] + """The key of the record. + A unique identifier for the record in the key-value store. + """ -@docs_group('Data structures') -class KeyValueStoreRecordMetadata(BaseModel): - """Model for a key-value store record metadata.""" + content_type: Annotated[str, Field(alias='contentType')] + """The MIME type of the record. - model_config = ConfigDict(populate_by_name=True) + Describe the format and type of data stored in the record, following the MIME specification. 
+ """ - key: Annotated[str, Field(alias='key')] - content_type: Annotated[str, Field(alias='contentType')] + size: Annotated[int, Field(alias='size')] + """The size of the record in bytes.""" @docs_group('Data structures') -class KeyValueStoreKeyInfo(BaseModel): - """Model for a key-value store key info.""" +class KeyValueStoreRecord(KeyValueStoreRecordMetadata, Generic[KvsValueType]): + """Model for a key-value store record.""" model_config = ConfigDict(populate_by_name=True) - key: Annotated[str, Field(alias='key')] - size: Annotated[int, Field(alias='size')] + value: Annotated[KvsValueType, Field(alias='value')] + """The value of the record.""" @docs_group('Data structures') @@ -106,9 +103,9 @@ class KeyValueStoreListKeysPage(BaseModel): count: Annotated[int, Field(alias='count')] limit: Annotated[int, Field(alias='limit')] is_truncated: Annotated[bool, Field(alias='isTruncated')] - items: Annotated[list[KeyValueStoreKeyInfo], Field(alias='items', default_factory=list)] exclusive_start_key: Annotated[str | None, Field(alias='exclusiveStartKey', default=None)] next_exclusive_start_key: Annotated[str | None, Field(alias='nextExclusiveStartKey', default=None)] + items: Annotated[list[KeyValueStoreRecordMetadata], Field(alias='items', default_factory=list)] @docs_group('Data structures') diff --git a/src/crawlee/storages/_dataset.py b/src/crawlee/storages/_dataset.py index b6a324434b..6ef3c6e4cb 100644 --- a/src/crawlee/storages/_dataset.py +++ b/src/crawlee/storages/_dataset.py @@ -129,14 +129,14 @@ async def open( purge_on_start = configuration.purge_on_start if purge_on_start is None else purge_on_start storage_dir = Path(configuration.storage_dir) if storage_dir is None else storage_dir - dataset_client = await storage_client.open_dataset_client( + client = await storage_client.open_dataset_client( id=id, name=name, purge_on_start=purge_on_start, storage_dir=storage_dir, ) - return cls(dataset_client) + return cls(client) async def drop(self) -> None: await self._client.drop() @@ -152,7 +152,7 @@ async def push_data(self, data: list[Any] | dict[str, Any]) -> None: data: A JSON serializable data structure to be stored in the dataset. The JSON representation of each item must be smaller than 9MB. 
""" - await self._client.push_data(data) + await self._client.push_data(data=data) async def get_data( self, diff --git a/src/crawlee/storages/_key_value_store.py b/src/crawlee/storages/_key_value_store.py index b7d3a4b582..d19430f997 100644 --- a/src/crawlee/storages/_key_value_store.py +++ b/src/crawlee/storages/_key_value_store.py @@ -1,18 +1,11 @@ from __future__ import annotations -import asyncio -from collections.abc import AsyncIterator -from datetime import datetime, timezone +from pathlib import Path from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, overload -from typing_extensions import override - from crawlee import service_locator from crawlee._utils.docs import docs_group -from crawlee.events._types import Event, EventPersistStateData -from crawlee.storage_clients.models import KeyValueStoreKeyInfo, KeyValueStoreMetadata, StorageMetadata - -from ._base import Storage +from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecordMetadata if TYPE_CHECKING: from collections.abc import AsyncIterator @@ -20,12 +13,40 @@ from crawlee._types import JsonSerializable from crawlee.configuration import Configuration from crawlee.storage_clients import StorageClient + from crawlee.storage_clients._base import KeyValueStoreClient T = TypeVar('T') +# TODO: +# - inherit from storage class +# - caching / memoization of both KVS & KVS clients + +# Suggested KVS breaking changes: +# - from_storage_object method has been removed - Use the open method with name and/or id instead. +# - get_info -> metadata property +# - storage_object -> metadata property +# - set_metadata method has been removed - Do we want to support it (e.g. for renaming)? +# - get_auto_saved_value method has been removed -> It should be managed by the underlying client. +# - persist_autosaved_values method has been removed -> It should be managed by the underlying client. + +# Properties: +# - id +# - name +# - metadata + +# Methods: +# - open +# - drop +# - get_value +# - set_value +# - delete_value (new method) +# - iterate_keys +# - list_keys (new method) +# - get_public_url + @docs_group('Classes') -class KeyValueStore(Storage): +class KeyValueStore: """Represents a key-value based storage for reading and writing data records or files. Each data record is identified by a unique key and associated with a specific MIME content type. This class is @@ -63,84 +84,64 @@ class KeyValueStore(Storage): _general_cache: ClassVar[dict[str, dict[str, dict[str, JsonSerializable]]]] = {} _persist_state_event_started = False - def __init__(self, id: str, name: str | None, storage_client: StorageClient) -> None: - self._id = id - self._name = name - datetime_now = datetime.now(timezone.utc) - self._storage_object = StorageMetadata( - id=id, name=name, accessed_at=datetime_now, created_at=datetime_now, modified_at=datetime_now - ) + def __init__(self, client: KeyValueStoreClient) -> None: + """Initialize a new instance. - # Get resource clients from storage client - self._resource_client = storage_client.key_value_store(self._id) - self._autosave_lock = asyncio.Lock() - - @classmethod - def from_storage_object(cls, storage_client: StorageClient, storage_object: StorageMetadata) -> KeyValueStore: - """Initialize a new instance of KeyValueStore from a storage metadata object.""" - key_value_store = KeyValueStore( - id=storage_object.id, - name=storage_object.name, - storage_client=storage_client, - ) + Preferably use the `KeyValueStore.open` constructor to create a new instance. 
- key_value_store.storage_object = storage_object - return key_value_store + Args: + client: An instance of a key-value store client. + """ + self._client = client @property - @override def id(self) -> str: - return self._id + return self._client.id @property - @override def name(self) -> str | None: - return self._name + return self._client.name @property - @override - def storage_object(self) -> StorageMetadata: - return self._storage_object - - @storage_object.setter - @override - def storage_object(self, storage_object: StorageMetadata) -> None: - self._storage_object = storage_object - - async def get_info(self) -> KeyValueStoreMetadata | None: - """Get an object containing general information about the key value store.""" - return await self._resource_client.get() + def metadata(self) -> KeyValueStoreMetadata: + return KeyValueStoreMetadata( + id=self._client.id, + name=self._client.id, + accessed_at=self._client.accessed_at, + created_at=self._client.created_at, + modified_at=self._client.modified_at, + ) - @override @classmethod async def open( cls, *, id: str | None = None, name: str | None = None, + purge_on_start: bool | None = None, + storage_dir: Path | None = None, configuration: Configuration | None = None, storage_client: StorageClient | None = None, ) -> KeyValueStore: - from crawlee.storages._creation_management import open_storage + if id and name: + raise ValueError('Only one of "id" or "name" can be specified, not both.') - configuration = configuration or service_locator.get_configuration() - storage_client = storage_client or service_locator.get_storage_client() + configuration = service_locator.get_configuration() if configuration is None else configuration + storage_client = service_locator.get_storage_client() if storage_client is None else storage_client + purge_on_start = configuration.purge_on_start if purge_on_start is None else purge_on_start + storage_dir = Path(configuration.storage_dir) if storage_dir is None else storage_dir - return await open_storage( - storage_class=cls, + client = await storage_client.open_key_value_store_client( id=id, name=name, - configuration=configuration, - storage_client=storage_client, + purge_on_start=purge_on_start, + storage_dir=storage_dir, ) - @override - async def drop(self) -> None: - from crawlee.storages._creation_management import remove_storage_from_cache + return cls(client) - await self._resource_client.delete() - self._clear_cache() - remove_storage_from_cache(storage_class=self.__class__, id=self._id, name=self._name) + async def drop(self) -> None: + await self._client.drop() @overload async def get_value(self, key: str) -> Any: ... @@ -161,27 +162,9 @@ async def get_value(self, key: str, default_value: T | None = None) -> T | None: Returns: The value associated with the given key. `default_value` is used in case the record does not exist. """ - record = await self._resource_client.get_record(key) + record = await self._client.get_value(key=key) return record.value if record else default_value - async def iterate_keys(self, exclusive_start_key: str | None = None) -> AsyncIterator[KeyValueStoreKeyInfo]: - """Iterate over the existing keys in the KVS. - - Args: - exclusive_start_key: Key to start the iteration from. - - Yields: - Information about the key. 
- """ - while True: - list_keys = await self._resource_client.list_keys(exclusive_start_key=exclusive_start_key) - for item in list_keys.items: - yield KeyValueStoreKeyInfo(key=item.key, size=item.size) - - if not list_keys.is_truncated: - break - exclusive_start_key = list_keys.next_exclusive_start_key - async def set_value( self, key: str, @@ -192,91 +175,70 @@ async def set_value( Args: key: Key of the record to set. - value: Value to set. If `None`, the record is deleted. - content_type: Content type of the record. + value: Value to set. + content_type: The MIME content type string. """ - if value is None: - return await self._resource_client.delete_record(key) - - return await self._resource_client.set_record(key, value, content_type) + await self._client.set_value(key=key, value=value, content_type=content_type) - async def get_public_url(self, key: str) -> str: - """Get the public URL for the given key. + async def delete_value(self, key: str) -> None: + """Delete a value from the KVS. Args: - key: Key of the record for which URL is required. - - Returns: - The public URL for the given key. + key: Key of the record to delete. """ - return await self._resource_client.get_public_url(key) + await self._client.delete_value(key=key) - async def get_auto_saved_value( + async def iterate_keys( self, - key: str, - default_value: dict[str, JsonSerializable] | None = None, - ) -> dict[str, JsonSerializable]: - """Get a value from KVS that will be automatically saved on changes. + exclusive_start_key: str | None = None, + limit: int | None = None, + ) -> AsyncIterator[KeyValueStoreRecordMetadata]: + """Iterate over the existing keys in the KVS. Args: - key: Key of the record, to store the value. - default_value: Value to be used if the record does not exist yet. Should be a dictionary. + exclusive_start_key: Key to start the iteration from. + limit: Maximum number of keys to return. None means no limit. - Returns: - Return the value of the key. + Yields: + Information about the key. """ - default_value = {} if default_value is None else default_value + async for item in self._client.iterate_keys( + exclusive_start_key=exclusive_start_key, + limit=limit, + ): + yield item - async with self._autosave_lock: - if key in self._cache: - return self._cache[key] + async def list_keys( + self, + exclusive_start_key: str | None = None, + limit: int = 1000, + ) -> list[KeyValueStoreRecordMetadata]: + """List all the existing keys in the KVS. - value = await self.get_value(key, default_value) + It uses client's `iterate_keys` method to get the keys. - if not isinstance(value, dict): - raise TypeError( - f'Expected dictionary for persist state value at key "{key}, but got {type(value).__name__}' - ) + Args: + exclusive_start_key: Key to start the iteration from. + limit: Maximum number of keys to return. - self._cache[key] = value + Returns: + A list of keys in the KVS. + """ + return [ + key + async for key in self._client.iterate_keys( + exclusive_start_key=exclusive_start_key, + limit=limit, + ) + ] - self._ensure_persist_event() + async def get_public_url(self, key: str) -> str: + """Get the public URL for the given key. - return value + Args: + key: Key of the record for which URL is required. 
- @property - def _cache(self) -> dict[str, dict[str, JsonSerializable]]: - """Cache dictionary for storing auto-saved values indexed by store ID.""" - if self._id not in self._general_cache: - self._general_cache[self._id] = {} - return self._general_cache[self._id] - - async def _persist_save(self, _event_data: EventPersistStateData | None = None) -> None: - """Save cache with persistent values. Can be used in Event Manager.""" - for key, value in self._cache.items(): - await self.set_value(key, value) - - def _ensure_persist_event(self) -> None: - """Ensure persist state event handling if not already done.""" - if self._persist_state_event_started: - return - - event_manager = service_locator.get_event_manager() - event_manager.on(event=Event.PERSIST_STATE, listener=self._persist_save) - self._persist_state_event_started = True - - def _clear_cache(self) -> None: - """Clear cache with persistent values.""" - self._cache.clear() - - def _drop_persist_state_event(self) -> None: - """Off event_manager listener and drop event status.""" - if self._persist_state_event_started: - event_manager = service_locator.get_event_manager() - event_manager.off(event=Event.PERSIST_STATE, listener=self._persist_save) - self._persist_state_event_started = False - - async def persist_autosaved_values(self) -> None: - """Force persistent values to be saved without waiting for an event in Event Manager.""" - if self._persist_state_event_started: - await self._persist_save() + Returns: + The public URL for the given key. + """ + return await self._client.get_public_url(key=key) diff --git a/tests/unit/storage_clients/_memory/test_key_value_store_client.py b/tests/unit/storage_clients/_memory/test_key_value_store_client.py index 26d1f8f974..c7813b5b84 100644 --- a/tests/unit/storage_clients/_memory/test_key_value_store_client.py +++ b/tests/unit/storage_clients/_memory/test_key_value_store_client.py @@ -399,7 +399,6 @@ async def test_reads_correct_metadata( accessed_at=datetime.now(timezone.utc), created_at=datetime.now(timezone.utc), modified_at=datetime.now(timezone.utc), - user_id='1', ) # Write the store metadata to disk From 32dfdf6fdecd44f939e1c12cae689f66c144c421 Mon Sep 17 00:00:00 2001 From: Vlada Dusek Date: Fri, 11 Apr 2025 09:06:40 +0200 Subject: [PATCH 05/22] Dataset export methods --- src/crawlee/_cli.py | 2 +- src/crawlee/_utils/file.py | 44 ++++- src/crawlee/crawlers/_basic/_basic_crawler.py | 75 ++------- .../storage_clients/_base/_dataset_client.py | 16 ++ .../_file_system/_dataset_client.py | 2 +- .../_memory/_dataset_client.py | 4 +- src/crawlee/storages/_dataset.py | 158 ++++++++---------- src/crawlee/storages/_key_value_store.py | 15 +- 8 files changed, 158 insertions(+), 158 deletions(-) diff --git a/src/crawlee/_cli.py b/src/crawlee/_cli.py index 689cb7b182..d8b295b5ed 100644 --- a/src/crawlee/_cli.py +++ b/src/crawlee/_cli.py @@ -22,7 +22,7 @@ cli = typer.Typer(no_args_is_help=True) template_directory = importlib.resources.files('crawlee') / 'project_template' -with open(str(template_directory / 'cookiecutter.json')) as f: +with (template_directory / 'cookiecutter.json').open() as f: cookiecutter_json = json.load(f) crawler_choices = cookiecutter_json['crawler_type'] diff --git a/src/crawlee/_utils/file.py b/src/crawlee/_utils/file.py index d50f6ecd41..c74bdbb771 100644 --- a/src/crawlee/_utils/file.py +++ b/src/crawlee/_utils/file.py @@ -2,17 +2,26 @@ import asyncio import contextlib +import csv import json import mimetypes import os import re import shutil from enum import Enum +from 
logging import getLogger from typing import TYPE_CHECKING if TYPE_CHECKING: + from collections.abc import AsyncIterator from pathlib import Path - from typing import Any + from typing import Any, TextIO + + from typing_extensions import Unpack + + from crawlee.storages._types import ExportDataCsvKwargs, ExportDataJsonKwargs + +logger = getLogger(__name__) class ContentType(Enum): @@ -92,3 +101,36 @@ async def json_dumps(obj: Any) -> str: A string containing the JSON representation of the input object. """ return await asyncio.to_thread(json.dumps, obj, ensure_ascii=False, indent=2, default=str) + + +async def export_json_to_stream( + iterator: AsyncIterator[dict], + dst: TextIO, + **kwargs: Unpack[ExportDataJsonKwargs], +) -> None: + items = [item async for item in iterator] + + if items: + json.dump(items, dst, **kwargs) + else: + logger.warning('Attempting to export an empty dataset - no file will be created') + + +async def export_csv_to_stream( + iterator: AsyncIterator[dict], + dst: TextIO, + **kwargs: Unpack[ExportDataCsvKwargs], +) -> None: + writer = csv.writer(dst, **kwargs) + write_header = True + + # Iterate over the dataset and write to CSV. + async for item in iterator: + if not item: + continue + + if write_header: + writer.writerow(item.keys()) + write_header = False + + writer.writerow(item.values()) diff --git a/src/crawlee/crawlers/_basic/_basic_crawler.py b/src/crawlee/crawlers/_basic/_basic_crawler.py index c69db280a6..bb79761c92 100644 --- a/src/crawlee/crawlers/_basic/_basic_crawler.py +++ b/src/crawlee/crawlers/_basic/_basic_crawler.py @@ -32,6 +32,7 @@ SendRequestFunction, ) from crawlee._utils.docs import docs_group +from crawlee._utils.file import export_csv_to_stream, export_json_to_stream from crawlee._utils.urls import convert_to_absolute_url, is_url_absolute from crawlee._utils.wait import wait_for from crawlee._utils.web import is_status_code_client_error, is_status_code_server_error @@ -57,7 +58,7 @@ import re from contextlib import AbstractAsyncContextManager - from crawlee._types import ConcurrencySettings, HttpMethod, JsonSerializable + from crawlee._types import ConcurrencySettings, HttpMethod, JsonSerializable, PushDataKwargs from crawlee.configuration import Configuration from crawlee.events import EventManager from crawlee.http_clients import HttpClient, HttpResponse @@ -67,7 +68,7 @@ from crawlee.statistics import FinalStatistics from crawlee.storage_clients import StorageClient from crawlee.storage_clients.models import DatasetItemsListPage - from crawlee.storages._dataset import ExportDataCsvKwargs, ExportDataJsonKwargs, GetDataKwargs, PushDataKwargs + from crawlee.storages._types import GetDataKwargs TCrawlingContext = TypeVar('TCrawlingContext', bound=BasicCrawlingContext, default=BasicCrawlingContext) TStatisticsState = TypeVar('TStatisticsState', bound=StatisticsState, default=StatisticsState) @@ -655,13 +656,18 @@ async def add_requests( wait_for_all_requests_to_be_added_timeout=wait_for_all_requests_to_be_added_timeout, ) - async def _use_state(self, default_value: dict[str, JsonSerializable] | None = None) -> dict[str, JsonSerializable]: - store = await self.get_key_value_store() - return await store.get_auto_saved_value(self._CRAWLEE_STATE_KEY, default_value) + async def _use_state( + self, + default_value: dict[str, JsonSerializable] | None = None, + ) -> dict[str, JsonSerializable]: + kvs = await self.get_key_value_store() + # TODO: + # return some kvs value async def _save_crawler_state(self) -> None: - store = await 
self.get_key_value_store() - await store.persist_autosaved_values() + kvs = await self.get_key_value_store() + # TODO: + # some kvs call async def get_data( self, @@ -705,64 +711,15 @@ async def export_data( dataset = await self.get_dataset(id=dataset_id, name=dataset_name) path = path if isinstance(path, Path) else Path(path) - destination = path.open('w', newline='') + dst = path.open('w', newline='') if path.suffix == '.csv': - await dataset.write_to_csv(destination) + await export_csv_to_stream(dataset.iterate(), dst) elif path.suffix == '.json': - await dataset.write_to_json(destination) + await export_json_to_stream(dataset.iterate(), dst) else: raise ValueError(f'Unsupported file extension: {path.suffix}') - async def export_data_csv( - self, - path: str | Path, - *, - dataset_id: str | None = None, - dataset_name: str | None = None, - **kwargs: Unpack[ExportDataCsvKwargs], - ) -> None: - """Export data from a `Dataset` to a CSV file. - - This helper method simplifies the process of exporting data from a `Dataset` in csv format. It opens - the specified one and then exports the data based on the provided parameters. - - Args: - path: The destination path. - content_type: The output format. - dataset_id: The ID of the `Dataset`. - dataset_name: The name of the `Dataset`. - kwargs: Extra configurations for dumping/writing in csv format. - """ - dataset = await self.get_dataset(id=dataset_id, name=dataset_name) - path = path if isinstance(path, Path) else Path(path) - - return await dataset.write_to_csv(path.open('w', newline=''), **kwargs) - - async def export_data_json( - self, - path: str | Path, - *, - dataset_id: str | None = None, - dataset_name: str | None = None, - **kwargs: Unpack[ExportDataJsonKwargs], - ) -> None: - """Export data from a `Dataset` to a JSON file. - - This helper method simplifies the process of exporting data from a `Dataset` in json format. It opens the - specified one and then exports the data based on the provided parameters. - - Args: - path: The destination path - dataset_id: The ID of the `Dataset`. - dataset_name: The name of the `Dataset`. - kwargs: Extra configurations for dumping/writing in json format. - """ - dataset = await self.get_dataset(id=dataset_id, name=dataset_name) - path = path if isinstance(path, Path) else Path(path) - - return await dataset.write_to_json(path.open('w', newline=''), **kwargs) - async def _push_data( self, data: JsonSerializable, diff --git a/src/crawlee/storage_clients/_base/_dataset_client.py b/src/crawlee/storage_clients/_base/_dataset_client.py index 265856b3ff..b9b6767310 100644 --- a/src/crawlee/storage_clients/_base/_dataset_client.py +++ b/src/crawlee/storage_clients/_base/_dataset_client.py @@ -14,6 +14,22 @@ from crawlee.storage_clients.models import DatasetItemsListPage +# Properties: +# - id +# - name +# - created_at +# - accessed_at +# - modified_at +# - item_count + +# Methods: +# - open +# - drop +# - push_data +# - get_data +# - iterate + + @docs_group('Abstract classes') class DatasetClient(ABC): """An abstract class for dataset resource clients. 
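
For illustration, a minimal sketch of how the export helpers introduced in this patch are meant to be used (assuming `export_json_to_stream` and `export_csv_to_stream` keep the signatures added to `src/crawlee/_utils/file.py` above; the stub iterator and file names below are illustrative only, while `BasicCrawler.export_data` feeds the same helpers from `dataset.iterate()` instead):

    import asyncio
    from collections.abc import AsyncIterator
    from pathlib import Path

    from crawlee._utils.file import export_csv_to_stream, export_json_to_stream


    async def main() -> None:
        # Stand-in for Dataset.iterate(): any async iterator of dicts works here.
        async def items() -> AsyncIterator[dict]:
            yield {'url': 'https://example.com', 'status': 200}
            yield {'url': 'https://example.org', 'status': 404}

        # JSON export: all items are collected and dumped as a single JSON array.
        with Path('items.json').open('w', newline='') as dst:
            await export_json_to_stream(items(), dst, indent=2)

        # CSV export: the keys of the first item become the header row.
        with Path('items.csv').open('w', newline='') as dst:
            await export_csv_to_stream(items(), dst)


    asyncio.run(main())
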
diff --git a/src/crawlee/storage_clients/_file_system/_dataset_client.py b/src/crawlee/storage_clients/_file_system/_dataset_client.py index 63103ff310..c5693c46c8 100644 --- a/src/crawlee/storage_clients/_file_system/_dataset_client.py +++ b/src/crawlee/storage_clients/_file_system/_dataset_client.py @@ -219,7 +219,7 @@ async def get_data( invalid = [arg for arg in unsupported_args if arg not in (False, None)] if invalid: logger.warning( - f'The arguments {invalid} of iterate_items are not supported by the {self.__class__.__name__} client.' + f'The arguments {invalid} of get_data are not supported by the {self.__class__.__name__} client.' ) # If the dataset directory does not exist, log a warning and return an empty page. diff --git a/src/crawlee/storage_clients/_memory/_dataset_client.py b/src/crawlee/storage_clients/_memory/_dataset_client.py index 279be563c9..6ffa22f028 100644 --- a/src/crawlee/storage_clients/_memory/_dataset_client.py +++ b/src/crawlee/storage_clients/_memory/_dataset_client.py @@ -135,7 +135,7 @@ async def get_data( invalid = [arg for arg in unsupported_args if arg not in (False, None)] if invalid: logger.warning( - f'The arguments {invalid} of iterate_items are not supported by the {self.__class__.__name__} client.' + f'The arguments {invalid} of get_data are not supported by the {self.__class__.__name__} client.' ) total = len(self._records) @@ -172,7 +172,7 @@ async def iterate( invalid = [arg for arg in unsupported_args if arg not in (False, None)] if invalid: logger.warning( - f'The arguments {invalid} of iterate_items are not supported by the {self.__class__.__name__} client.' + f'The arguments {invalid} of iterate are not supported by the {self.__class__.__name__} client.' ) items = self._records.copy() diff --git a/src/crawlee/storages/_dataset.py b/src/crawlee/storages/_dataset.py index 6ef3c6e4cb..12f74c2062 100644 --- a/src/crawlee/storages/_dataset.py +++ b/src/crawlee/storages/_dataset.py @@ -1,14 +1,13 @@ from __future__ import annotations -import csv -import io -import json import logging +from io import StringIO from pathlib import Path -from typing import TYPE_CHECKING, TextIO, cast +from typing import TYPE_CHECKING, Literal from crawlee import service_locator from crawlee._utils.docs import docs_group +from crawlee._utils.file import export_csv_to_stream, export_json_to_stream from crawlee.storage_clients.models import DatasetMetadata from ._key_value_store import KeyValueStore @@ -24,28 +23,36 @@ from crawlee.storage_clients._base import DatasetClient from crawlee.storage_clients.models import DatasetItemsListPage - from ._types import ExportDataCsvKwargs, ExportDataJsonKwargs, ExportToKwargs + from ._types import ExportDataCsvKwargs, ExportDataJsonKwargs logger = logging.getLogger(__name__) # TODO: # - inherit from storage class -# - export methods # - caching / memoization of both datasets & dataset clients -# Dataset -# - properties: -# - id -# - name -# - metadata -# - methods: -# - open -# - drop -# - push_data -# - get_data -# - iterate -# - export_to_csv -# - export_to_json +# Properties: +# - id +# - name +# - metadata + +# Methods: +# - open +# - drop +# - push_data +# - get_data +# - iterate +# - export_to +# - export_to_json +# - export_to_csv + +# Breaking changes: +# - from_storage_object method has been removed - Use the open method with name and/or id instead. +# - get_info -> metadata property +# - storage_object -> metadata property +# - set_metadata method has been removed - Do we want to support it (e.g. for renaming)? 
+# - write_to_json -> export_to_json +# - write_to_csv -> export_to_csv @docs_group('Classes') @@ -253,8 +260,13 @@ async def iterate( ): yield item - # TODO: update this once KVS is implemented - async def export_to(self, **kwargs: Unpack[ExportToKwargs]) -> None: + async def export_to( + self, + key: str, + content_type: Literal['json', 'csv'] = 'json', + to_key_value_store_id: str | None = None, + to_key_value_store_name: str | None = None, + ) -> None: """Export the entire dataset into a specified file stored under a key in a key-value store. This method consolidates all entries from a specified dataset into one file, which is then saved under a @@ -263,74 +275,48 @@ async def export_to(self, **kwargs: Unpack[ExportToKwargs]) -> None: name should be used. Args: - kwargs: Keyword arguments for the storage client method. + key: The key under which to save the data in the key-value store. + content_type: The format in which to export the data. + to_key_value_store_id: ID of the key-value store to save the exported file. + Specify only one of ID or name. + to_key_value_store_name: Name of the key-value store to save the exported file. + Specify only one of ID or name. """ - key = cast('str', kwargs.get('key')) - content_type = kwargs.get('content_type', 'json') - to_key_value_store_id = kwargs.get('to_key_value_store_id') - to_key_value_store_name = kwargs.get('to_key_value_store_name') - - key_value_store = await KeyValueStore.open(id=to_key_value_store_id, name=to_key_value_store_name) - - output = io.StringIO() if content_type == 'csv': - await self.write_to_csv(output) + await self.export_to_csv( + key, + to_key_value_store_id, + to_key_value_store_name, + ) elif content_type == 'json': - await self.write_to_json(output) + await self.export_to_json( + key, + to_key_value_store_id, + to_key_value_store_name, + ) else: raise ValueError('Unsupported content type, expecting CSV or JSON') - if content_type == 'csv': - await key_value_store.set_value(key, output.getvalue(), 'text/csv') - - if content_type == 'json': - await key_value_store.set_value(key, output.getvalue(), 'application/json') - - # TODO: update this once KVS is implemented - async def write_to_csv(self, destination: TextIO, **kwargs: Unpack[ExportDataCsvKwargs]) -> None: - """Export the entire dataset into an arbitrary stream. - - Args: - destination: The stream into which the dataset contents should be written. - kwargs: Additional keyword arguments for `csv.writer`. - """ - items: list[dict] = [] - limit = 1000 - offset = 0 - - while True: - list_items = await self._client.get_data(limit=limit, offset=offset) - items.extend(list_items.items) - if list_items.total <= offset + list_items.count: - break - offset += list_items.count - - if items: - writer = csv.writer(destination, **kwargs) - writer.writerows([items[0].keys(), *[item.values() for item in items]]) - else: - logger.warning('Attempting to export an empty dataset - no file will be created') - - # TODO: update this once KVS is implemented - async def write_to_json(self, destination: TextIO, **kwargs: Unpack[ExportDataJsonKwargs]) -> None: - """Export the entire dataset into an arbitrary stream. - - Args: - destination: The stream into which the dataset contents should be written. - kwargs: Additional keyword arguments for `json.dump`. 
- """ - items: list[dict] = [] - limit = 1000 - offset = 0 - - while True: - list_items = await self._client.get_data(limit=limit, offset=offset) - items.extend(list_items.items) - if list_items.total <= offset + list_items.count: - break - offset += list_items.count - - if items: - json.dump(items, destination, **kwargs) - else: - logger.warning('Attempting to export an empty dataset - no file will be created') + async def export_to_json( + self, + key: str, + to_key_value_store_id: str | None = None, + to_key_value_store_name: str | None = None, + **kwargs: Unpack[ExportDataJsonKwargs], + ) -> None: + kvs = await KeyValueStore.open(id=to_key_value_store_id, name=to_key_value_store_name) + dst = StringIO() + await export_json_to_stream(self.iterate(), dst, **kwargs) + await kvs.set_value(key, dst.getvalue(), 'application/json') + + async def export_to_csv( + self, + key: str, + to_key_value_store_id: str | None = None, + to_key_value_store_name: str | None = None, + **kwargs: Unpack[ExportDataCsvKwargs], + ) -> None: + kvs = await KeyValueStore.open(id=to_key_value_store_id, name=to_key_value_store_name) + dst = StringIO() + await export_csv_to_stream(self.iterate(), dst, **kwargs) + await kvs.set_value(key, dst.getvalue(), 'text/csv') diff --git a/src/crawlee/storages/_key_value_store.py b/src/crawlee/storages/_key_value_store.py index d19430f997..3d06ebba40 100644 --- a/src/crawlee/storages/_key_value_store.py +++ b/src/crawlee/storages/_key_value_store.py @@ -21,14 +21,6 @@ # - inherit from storage class # - caching / memoization of both KVS & KVS clients -# Suggested KVS breaking changes: -# - from_storage_object method has been removed - Use the open method with name and/or id instead. -# - get_info -> metadata property -# - storage_object -> metadata property -# - set_metadata method has been removed - Do we want to support it (e.g. for renaming)? -# - get_auto_saved_value method has been removed -> It should be managed by the underlying client. -# - persist_autosaved_values method has been removed -> It should be managed by the underlying client. - # Properties: # - id # - name @@ -44,6 +36,13 @@ # - list_keys (new method) # - get_public_url +# Breaking changes: +# - from_storage_object method has been removed - Use the open method with name and/or id instead. +# - get_info -> metadata property +# - storage_object -> metadata property +# - set_metadata method has been removed - Do we want to support it (e.g. for renaming)? +# - get_auto_saved_value method has been removed -> It should be managed by the underlying client. +# - persist_autosaved_values method has been removed -> It should be managed by the underlying client. 
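+#
+# Example migration for the renames above (a sketch; the store name and stored values are illustrative only):
+#
+#   kvs = await KeyValueStore.open(name='my-store')
+#   metadata = kvs.metadata                 # previously: await kvs.get_info()
+#   await kvs.set_value('state', {'n': 1})  # unchanged
+#   value = await kvs.get_value('state')    # unchanged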
@docs_group('Classes') class KeyValueStore: From eed4c5a01972affa74242632a8847d78b9029964 Mon Sep 17 00:00:00 2001 From: Vlada Dusek Date: Fri, 11 Apr 2025 09:42:47 +0200 Subject: [PATCH 06/22] inherit from Storage class and RQ init --- .../_base/_request_queue_client.py | 67 ++- src/crawlee/storages/_base.py | 11 +- src/crawlee/storages/_creation_management.py | 210 ------- src/crawlee/storages/_dataset.py | 11 +- src/crawlee/storages/_key_value_store.py | 12 +- src/crawlee/storages/_request_queue.py | 558 +++--------------- 6 files changed, 165 insertions(+), 704 deletions(-) delete mode 100644 src/crawlee/storages/_creation_management.py diff --git a/src/crawlee/storage_clients/_base/_request_queue_client.py b/src/crawlee/storage_clients/_base/_request_queue_client.py index f43766461c..0d7c2ddb45 100644 --- a/src/crawlee/storage_clients/_base/_request_queue_client.py +++ b/src/crawlee/storage_clients/_base/_request_queue_client.py @@ -1,12 +1,14 @@ from __future__ import annotations from abc import ABC, abstractmethod +from datetime import datetime from typing import TYPE_CHECKING from crawlee._utils.docs import docs_group if TYPE_CHECKING: from collections.abc import Sequence + from datetime import datetime from crawlee.storage_clients.models import ( BatchRequestsOperationResponse, @@ -15,7 +17,6 @@ Request, RequestQueueHead, RequestQueueHeadWithLocks, - RequestQueueMetadata, ) @@ -27,17 +28,67 @@ class RequestQueueClient(ABC): client, like a memory storage client. """ + @property @abstractmethod - async def get(self) -> RequestQueueMetadata | None: - """Get metadata about the request queue being managed by this client. + def id(self) -> str: + """The ID of the dataset.""" - Returns: - An object containing the request queue's details, or None if the request queue does not exist. - """ + @property + @abstractmethod + def name(self) -> str | None: + """The name of the dataset.""" + + @property + @abstractmethod + def created_at(self) -> datetime: + """The time at which the dataset was created.""" + + @property + @abstractmethod + def accessed_at(self) -> datetime: + """The time at which the dataset was last accessed.""" + + @property + @abstractmethod + def modified_at(self) -> datetime: + """The time at which the dataset was last modified.""" + @property @abstractmethod - async def delete(self) -> None: - """Permanently delete the request queue managed by this client.""" + def had_multiple_clients(self) -> bool: + """TODO.""" + + @property + @abstractmethod + def handled_request_count(self) -> int: + """TODO.""" + + @property + @abstractmethod + def pending_request_count(self) -> int: + """TODO.""" + + @property + @abstractmethod + def stats(self) -> dict: + """TODO.""" + + @property + @abstractmethod + def total_request_count(self) -> int: + """TODO.""" + + @property + @abstractmethod + def resource_directory(self) -> str: + """TODO.""" + + @abstractmethod + async def drop(self) -> None: + """Drop the whole request queue and remove all its values. + + The backend method for the `RequestQueue.drop` call. 
+ """ @abstractmethod async def list_head(self, *, limit: int | None = None) -> RequestQueueHead: diff --git a/src/crawlee/storages/_base.py b/src/crawlee/storages/_base.py index 08d2cbd7be..8e73326041 100644 --- a/src/crawlee/storages/_base.py +++ b/src/crawlee/storages/_base.py @@ -6,7 +6,7 @@ if TYPE_CHECKING: from crawlee.configuration import Configuration from crawlee.storage_clients._base import StorageClient - from crawlee.storage_clients.models import StorageMetadata + from crawlee.storage_clients.models import DatasetMetadata, KeyValueStoreMetadata, RequestQueueMetadata class Storage(ABC): @@ -24,13 +24,8 @@ def name(self) -> str | None: @property @abstractmethod - def storage_object(self) -> StorageMetadata: - """Get the full storage object.""" - - @storage_object.setter - @abstractmethod - def storage_object(self, storage_object: StorageMetadata) -> None: - """Set the full storage object.""" + def metadata(self) -> DatasetMetadata | KeyValueStoreMetadata | RequestQueueMetadata: + """Get the storage metadata.""" @classmethod @abstractmethod diff --git a/src/crawlee/storages/_creation_management.py b/src/crawlee/storages/_creation_management.py deleted file mode 100644 index 9137e39512..0000000000 --- a/src/crawlee/storages/_creation_management.py +++ /dev/null @@ -1,210 +0,0 @@ -from __future__ import annotations - -import asyncio -from typing import TYPE_CHECKING, TypeVar - -from crawlee.storage_clients import MemoryStorageClient - -from ._dataset import Dataset -from ._key_value_store import KeyValueStore -from ._request_queue import RequestQueue - -if TYPE_CHECKING: - from crawlee.configuration import Configuration - from crawlee.storage_clients._base import ResourceClient, StorageClient - -TResource = TypeVar('TResource', Dataset, KeyValueStore, RequestQueue) - - -_creation_lock = asyncio.Lock() -"""Lock for storage creation.""" - -_cache_dataset_by_id: dict[str, Dataset] = {} -_cache_dataset_by_name: dict[str, Dataset] = {} -_cache_kvs_by_id: dict[str, KeyValueStore] = {} -_cache_kvs_by_name: dict[str, KeyValueStore] = {} -_cache_rq_by_id: dict[str, RequestQueue] = {} -_cache_rq_by_name: dict[str, RequestQueue] = {} - - -def _get_from_cache_by_name( - storage_class: type[TResource], - name: str, -) -> TResource | None: - """Try to restore storage from cache by name.""" - if issubclass(storage_class, Dataset): - return _cache_dataset_by_name.get(name) - if issubclass(storage_class, KeyValueStore): - return _cache_kvs_by_name.get(name) - if issubclass(storage_class, RequestQueue): - return _cache_rq_by_name.get(name) - raise ValueError(f'Unknown storage class: {storage_class.__name__}') - - -def _get_from_cache_by_id( - storage_class: type[TResource], - id: str, -) -> TResource | None: - """Try to restore storage from cache by ID.""" - if issubclass(storage_class, Dataset): - return _cache_dataset_by_id.get(id) - if issubclass(storage_class, KeyValueStore): - return _cache_kvs_by_id.get(id) - if issubclass(storage_class, RequestQueue): - return _cache_rq_by_id.get(id) - raise ValueError(f'Unknown storage: {storage_class.__name__}') - - -def _add_to_cache_by_name(name: str, storage: TResource) -> None: - """Add storage to cache by name.""" - if isinstance(storage, Dataset): - _cache_dataset_by_name[name] = storage - elif isinstance(storage, KeyValueStore): - _cache_kvs_by_name[name] = storage - elif isinstance(storage, RequestQueue): - _cache_rq_by_name[name] = storage - else: - raise TypeError(f'Unknown storage: {storage}') - - -def _add_to_cache_by_id(id: str, storage: 
TResource) -> None: - """Add storage to cache by ID.""" - if isinstance(storage, Dataset): - _cache_dataset_by_id[id] = storage - elif isinstance(storage, KeyValueStore): - _cache_kvs_by_id[id] = storage - elif isinstance(storage, RequestQueue): - _cache_rq_by_id[id] = storage - else: - raise TypeError(f'Unknown storage: {storage}') - - -def _rm_from_cache_by_id(storage_class: type, id: str) -> None: - """Remove a storage from cache by ID.""" - try: - if issubclass(storage_class, Dataset): - del _cache_dataset_by_id[id] - elif issubclass(storage_class, KeyValueStore): - del _cache_kvs_by_id[id] - elif issubclass(storage_class, RequestQueue): - del _cache_rq_by_id[id] - else: - raise TypeError(f'Unknown storage class: {storage_class.__name__}') - except KeyError as exc: - raise RuntimeError(f'Storage with provided ID was not found ({id}).') from exc - - -def _rm_from_cache_by_name(storage_class: type, name: str) -> None: - """Remove a storage from cache by name.""" - try: - if issubclass(storage_class, Dataset): - del _cache_dataset_by_name[name] - elif issubclass(storage_class, KeyValueStore): - del _cache_kvs_by_name[name] - elif issubclass(storage_class, RequestQueue): - del _cache_rq_by_name[name] - else: - raise TypeError(f'Unknown storage class: {storage_class.__name__}') - except KeyError as exc: - raise RuntimeError(f'Storage with provided name was not found ({name}).') from exc - - -def _get_default_storage_id(configuration: Configuration, storage_class: type[TResource]) -> str: - if issubclass(storage_class, Dataset): - return configuration.default_dataset_id - if issubclass(storage_class, KeyValueStore): - return configuration.default_key_value_store_id - if issubclass(storage_class, RequestQueue): - return configuration.default_request_queue_id - - raise TypeError(f'Unknown storage class: {storage_class.__name__}') - - -async def open_storage( - *, - storage_class: type[TResource], - id: str | None, - name: str | None, - configuration: Configuration, - storage_client: StorageClient, -) -> TResource: - """Open either a new storage or restore an existing one and return it.""" - # Try to restore the storage from cache by name - if name: - cached_storage = _get_from_cache_by_name(storage_class=storage_class, name=name) - if cached_storage: - return cached_storage - - default_id = _get_default_storage_id(configuration, storage_class) - - if not id and not name: - id = default_id - - # Find out if the storage is a default on memory storage - is_default_on_memory = id == default_id and isinstance(storage_client, MemoryStorageClient) - - # Try to restore storage from cache by ID - if id: - cached_storage = _get_from_cache_by_id(storage_class=storage_class, id=id) - if cached_storage: - return cached_storage - - # Purge on start if configured - if configuration.purge_on_start: - await storage_client.purge_on_start() - - # Lock and create new storage - async with _creation_lock: - if id and not is_default_on_memory: - resource_client = _get_resource_client(storage_class, storage_client, id) - storage_object = await resource_client.get() - if not storage_object: - raise RuntimeError(f'{storage_class.__name__} with id "{id}" does not exist!') - - elif is_default_on_memory: - resource_client = _get_resource_client(storage_class, storage_client) - storage_object = await resource_client.get_or_create(name=name, id=id) - - else: - resource_client = _get_resource_client(storage_class, storage_client) - storage_object = await resource_client.get_or_create(name=name) - - storage = 
storage_class.from_storage_object(storage_client=storage_client, storage_object=storage_object) - - # Cache the storage by ID and name - _add_to_cache_by_id(storage.id, storage) - if storage.name is not None: - _add_to_cache_by_name(storage.name, storage) - - return storage - - -def remove_storage_from_cache( - *, - storage_class: type, - id: str | None = None, - name: str | None = None, -) -> None: - """Remove a storage from cache by ID or name.""" - if id: - _rm_from_cache_by_id(storage_class=storage_class, id=id) - - if name: - _rm_from_cache_by_name(storage_class=storage_class, name=name) - - -def _get_resource_client( - storage_class: type[TResource], - storage_client: StorageClient, - id: str, -) -> ResourceClient: - if issubclass(storage_class, Dataset): - return storage_client.dataset(id) - - if issubclass(storage_class, KeyValueStore): - return storage_client.key_value_store(id) - - if issubclass(storage_class, RequestQueue): - return storage_client.request_queue(id) - - raise ValueError(f'Unknown storage class label: {storage_class.__name__}') diff --git a/src/crawlee/storages/_dataset.py b/src/crawlee/storages/_dataset.py index 12f74c2062..6b6160f471 100644 --- a/src/crawlee/storages/_dataset.py +++ b/src/crawlee/storages/_dataset.py @@ -5,11 +5,14 @@ from pathlib import Path from typing import TYPE_CHECKING, Literal +from typing_extensions import override + from crawlee import service_locator from crawlee._utils.docs import docs_group from crawlee._utils.file import export_csv_to_stream, export_json_to_stream from crawlee.storage_clients.models import DatasetMetadata +from ._base import Storage from ._key_value_store import KeyValueStore if TYPE_CHECKING: @@ -28,7 +31,6 @@ logger = logging.getLogger(__name__) # TODO: -# - inherit from storage class # - caching / memoization of both datasets & dataset clients # Properties: @@ -56,7 +58,7 @@ @docs_group('Classes') -class Dataset: +class Dataset(Storage): """Dataset is an append-only structured storage, ideal for tabular data similar to database tables. 
The `Dataset` class is designed to store structured data, where each entry (row) maintains consistent attributes @@ -98,14 +100,17 @@ def __init__(self, client: DatasetClient) -> None: """ self._client = client + @override @property def id(self) -> str: return self._client.id + @override @property def name(self) -> str | None: return self._client.name + @override @property def metadata(self) -> DatasetMetadata: return DatasetMetadata( @@ -117,6 +122,7 @@ def metadata(self) -> DatasetMetadata: item_count=self._client.item_count, ) + @override @classmethod async def open( cls, @@ -145,6 +151,7 @@ async def open( return cls(client) + @override async def drop(self) -> None: await self._client.drop() diff --git a/src/crawlee/storages/_key_value_store.py b/src/crawlee/storages/_key_value_store.py index 3d06ebba40..b99957bb19 100644 --- a/src/crawlee/storages/_key_value_store.py +++ b/src/crawlee/storages/_key_value_store.py @@ -3,10 +3,14 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, overload +from typing_extensions import override + from crawlee import service_locator from crawlee._utils.docs import docs_group from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecordMetadata +from ._base import Storage + if TYPE_CHECKING: from collections.abc import AsyncIterator @@ -18,7 +22,6 @@ T = TypeVar('T') # TODO: -# - inherit from storage class # - caching / memoization of both KVS & KVS clients # Properties: @@ -45,7 +48,7 @@ # - persist_autosaved_values method has been removed -> It should be managed by the underlying client. @docs_group('Classes') -class KeyValueStore: +class KeyValueStore(Storage): """Represents a key-value based storage for reading and writing data records or files. Each data record is identified by a unique key and associated with a specific MIME content type. 
This class is @@ -93,14 +96,17 @@ def __init__(self, client: KeyValueStoreClient) -> None: """ self._client = client + @override @property def id(self) -> str: return self._client.id + @override @property def name(self) -> str | None: return self._client.name + @override @property def metadata(self) -> KeyValueStoreMetadata: return KeyValueStoreMetadata( @@ -111,6 +117,7 @@ def metadata(self) -> KeyValueStoreMetadata: modified_at=self._client.modified_at, ) + @override @classmethod async def open( cls, @@ -139,6 +146,7 @@ async def open( return cls(client) + @override async def drop(self) -> None: await self._client.drop() diff --git a/src/crawlee/storages/_request_queue.py b/src/crawlee/storages/_request_queue.py index 372e4373e5..fd5dd017b2 100644 --- a/src/crawlee/storages/_request_queue.py +++ b/src/crawlee/storages/_request_queue.py @@ -1,23 +1,16 @@ from __future__ import annotations -import asyncio -from collections import deque -from contextlib import suppress -from datetime import datetime, timedelta, timezone +from datetime import timedelta from logging import getLogger -from typing import TYPE_CHECKING, Any, TypeVar +from pathlib import Path +from typing import TYPE_CHECKING, TypeVar -from cachetools import LRUCache from typing_extensions import override from crawlee import service_locator -from crawlee._utils.crypto import crypto_random_object_id from crawlee._utils.docs import docs_group -from crawlee._utils.requests import unique_key_to_request_id -from crawlee._utils.wait import wait_for_all_tasks_for_finish -from crawlee.events import Event from crawlee.request_loaders import RequestManager -from crawlee.storage_clients.models import ProcessedRequest, RequestQueueMetadata, StorageMetadata +from crawlee.storage_clients.models import Request, RequestQueueMetadata from ._base import Storage @@ -27,12 +20,40 @@ from crawlee import Request from crawlee.configuration import Configuration from crawlee.storage_clients import StorageClient - from crawlee.storages._types import CachedRequest + from crawlee.storage_clients._base import RequestQueueClient + from crawlee.storage_clients.models import ProcessedRequest logger = getLogger(__name__) T = TypeVar('T') +# TODO: implement: +# - caching / memoization of both KVS & KVS clients + +# Properties: +# - id +# - name +# - metadata + +# Methods +# - open +# - drop +# - add_request +# - add_requests_batched +# - get_handled_count +# - get_total_count +# - get_request +# - fetch_next_request +# - mark_request_as_handled +# - reclaim_request +# - is_empty +# - is_finished + +# Breaking changes: +# - from_storage_object method has been removed - Use the open method with name and/or id instead. 
+# - get_info -> metadata property +# - storage_object -> metadata property + @docs_group('Classes') class RequestQueue(Storage, RequestManager): @@ -70,80 +91,42 @@ class RequestQueue(Storage, RequestManager): _MAX_CACHED_REQUESTS = 1_000_000 """Maximum number of requests that can be cached.""" - def __init__( - self, - id: str, - name: str | None, - storage_client: StorageClient, - ) -> None: - config = service_locator.get_configuration() - event_manager = service_locator.get_event_manager() - - self._id = id - self._name = name - - datetime_now = datetime.now(timezone.utc) - self._storage_object = StorageMetadata( - id=id, name=name, accessed_at=datetime_now, created_at=datetime_now, modified_at=datetime_now - ) - - # Get resource clients from storage client - self._resource_client = storage_client.request_queue(self._id) + def __init__(self, client: RequestQueueClient) -> None: + """Initialize a new instance. - self._request_lock_time = timedelta(minutes=3) - self._queue_paused_for_migration = False - self._queue_has_locked_requests: bool | None = None - self._should_check_for_forefront_requests = False - - self._is_finished_log_throttle_counter = 0 - self._dequeued_request_count = 0 - - event_manager.on(event=Event.MIGRATING, listener=lambda _: setattr(self, '_queue_paused_for_migration', True)) - event_manager.on(event=Event.MIGRATING, listener=self._clear_possible_locks) - event_manager.on(event=Event.ABORTING, listener=self._clear_possible_locks) - - # Other internal attributes - self._tasks = list[asyncio.Task]() - self._client_key = crypto_random_object_id() - self._internal_timeout = config.internal_timeout or timedelta(minutes=5) - self._assumed_total_count = 0 - self._assumed_handled_count = 0 - self._queue_head = deque[str]() - self._list_head_and_lock_task: asyncio.Task | None = None - self._last_activity = datetime.now(timezone.utc) - self._requests_cache: LRUCache[str, CachedRequest] = LRUCache(maxsize=self._MAX_CACHED_REQUESTS) - - @classmethod - def from_storage_object(cls, storage_client: StorageClient, storage_object: StorageMetadata) -> RequestQueue: - """Initialize a new instance of RequestQueue from a storage metadata object.""" - request_queue = RequestQueue( - id=storage_object.id, - name=storage_object.name, - storage_client=storage_client, - ) + Preferably use the `RequestQueue.open` constructor to create a new instance. - request_queue.storage_object = storage_object - return request_queue + Args: + client: An instance of a key-value store client. 
+ """ + self._client = client - @property @override + @property def id(self) -> str: - return self._id + return self._client.id - @property @override - def name(self) -> str | None: - return self._name - @property - @override - def storage_object(self) -> StorageMetadata: - return self._storage_object + def name(self) -> str | None: + return self._client.name - @storage_object.setter @override - def storage_object(self, storage_object: StorageMetadata) -> None: - self._storage_object = storage_object + @property + def metadata(self) -> RequestQueueMetadata: + return RequestQueueMetadata( + id=self._client.id, + name=self._client.id, + accessed_at=self._client.accessed_at, + created_at=self._client.created_at, + modified_at=self._client.modified_at, + had_multiple_clients=self._client.had_multiple_clients, + handled_request_count=self._client.handled_request_count, + pending_request_count=self._client.pending_request_count, + stats=self._client.stats, + total_request_count=self._client.total_request_count, + resource_directory=self._client.resource_directory, + ) @override @classmethod @@ -152,32 +135,31 @@ async def open( *, id: str | None = None, name: str | None = None, + purge_on_start: bool | None = None, + storage_dir: Path | None = None, configuration: Configuration | None = None, storage_client: StorageClient | None = None, ) -> RequestQueue: - from crawlee.storages._creation_management import open_storage + if id and name: + raise ValueError('Only one of "id" or "name" can be specified, not both.') - configuration = configuration or service_locator.get_configuration() - storage_client = storage_client or service_locator.get_storage_client() + configuration = service_locator.get_configuration() if configuration is None else configuration + storage_client = service_locator.get_storage_client() if storage_client is None else storage_client + purge_on_start = configuration.purge_on_start if purge_on_start is None else purge_on_start + storage_dir = Path(configuration.storage_dir) if storage_dir is None else storage_dir - return await open_storage( - storage_class=cls, + client = await storage_client.open_request_queue_client( id=id, name=name, - configuration=configuration, - storage_client=storage_client, + purge_on_start=purge_on_start, + storage_dir=storage_dir, ) + return cls(client) + @override async def drop(self, *, timeout: timedelta | None = None) -> None: - from crawlee.storages._creation_management import remove_storage_from_cache - - # Wait for all tasks to finish - await wait_for_all_tasks_for_finish(self._tasks, logger=logger, timeout=timeout) - - # Delete the storage from the underlying client and remove it from the cache - await self._resource_client.delete() - remove_storage_from_cache(storage_class=self.__class__, id=self._id, name=self._name) + await self._client.drop() @override async def add_request( @@ -186,35 +168,7 @@ async def add_request( *, forefront: bool = False, ) -> ProcessedRequest: - request = self._transform_request(request) - self._last_activity = datetime.now(timezone.utc) - - cache_key = unique_key_to_request_id(request.unique_key) - cached_info = self._requests_cache.get(cache_key) - - if cached_info: - request.id = cached_info['id'] - # We may assume that if request is in local cache then also the information if the request was already - # handled is there because just one client should be using one queue. 
- return ProcessedRequest( - id=request.id, - unique_key=request.unique_key, - was_already_present=True, - was_already_handled=cached_info['was_already_handled'], - ) - - processed_request = await self._resource_client.add_request(request, forefront=forefront) - processed_request.unique_key = request.unique_key - - self._cache_request(cache_key, processed_request, forefront=forefront) - - if not processed_request.was_already_present and forefront: - self._should_check_for_forefront_requests = True - - if request.handled_at is None and not processed_request.was_already_present: - self._assumed_total_count += 1 - - return processed_request + return await self._client.add_request(request, forefront=forefront) @override async def add_requests_batched( @@ -226,8 +180,8 @@ async def add_requests_batched( wait_for_all_requests_to_be_added: bool = False, wait_for_all_requests_to_be_added_timeout: timedelta | None = None, ) -> None: - transformed_requests = self._transform_requests(requests) - wait_time_secs = wait_time_between_batches.total_seconds() + # TODO: implement + pass # Wait for the first batch to be added first_batch = transformed_requests[:batch_size] @@ -290,7 +244,7 @@ async def get_request(self, request_id: str) -> Request | None: Returns: The retrieved request, or `None`, if it does not exist. """ - return await self._resource_client.get_request(request_id) + # TODO: implement async def fetch_next_request(self) -> Request | None: """Return the next request in the queue to be processed. @@ -307,47 +261,7 @@ async def fetch_next_request(self) -> Request | None: Returns: The request or `None` if there are no more pending requests. """ - self._last_activity = datetime.now(timezone.utc) - - await self._ensure_head_is_non_empty() - - # We are likely done at this point. - if len(self._queue_head) == 0: - return None - - next_request_id = self._queue_head.popleft() - request = await self._get_or_hydrate_request(next_request_id) - - # NOTE: It can happen that the queue head index is inconsistent with the main queue table. - # This can occur in two situations: - - # 1) - # Queue head index is ahead of the main table and the request is not present in the main table yet - # (i.e. get_request() returned null). In this case, keep the request marked as in progress for a short while, - # so that is_finished() doesn't return true and _ensure_head_is_non_empty() doesn't not load the request into - # the queueHeadDict straight again. After the interval expires, fetch_next_request() will try to fetch this - # request again, until it eventually appears in the main table. - if request is None: - logger.debug( - 'Cannot find a request from the beginning of queue, will be retried later', - extra={'nextRequestId': next_request_id}, - ) - return None - - # 2) - # Queue head index is behind the main table and the underlying request was already handled (by some other - # client, since we keep the track of handled requests in recently_handled dictionary). We just add the request - # to the recently_handled dictionary so that next call to _ensure_head_is_non_empty() will not put the request - # again to queue_head_dict. 
- if request.handled_at is not None: - logger.debug( - 'Request fetched from the beginning of queue was already handled', - extra={'nextRequestId': next_request_id}, - ) - return None - - self._dequeued_request_count += 1 - return request + # TODO: implement async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: """Mark a request as handled after successful processing. @@ -360,20 +274,7 @@ async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | Returns: Information about the queue operation. `None` if the given request was not in progress. """ - self._last_activity = datetime.now(timezone.utc) - - if request.handled_at is None: - request.handled_at = datetime.now(timezone.utc) - - processed_request = await self._resource_client.update_request(request) - processed_request.unique_key = request.unique_key - self._dequeued_request_count -= 1 - - if not processed_request.was_already_handled: - self._assumed_handled_count += 1 - - self._cache_request(unique_key_to_request_id(request.unique_key), processed_request, forefront=False) - return processed_request + # TODO: implement async def reclaim_request( self, @@ -392,32 +293,15 @@ async def reclaim_request( Returns: Information about the queue operation. `None` if the given request was not in progress. """ - self._last_activity = datetime.now(timezone.utc) - - processed_request = await self._resource_client.update_request(request, forefront=forefront) - processed_request.unique_key = request.unique_key - self._cache_request(unique_key_to_request_id(request.unique_key), processed_request, forefront=forefront) - - if forefront: - self._should_check_for_forefront_requests = True - - if processed_request: - # Try to delete the request lock if possible - try: - await self._resource_client.delete_request_lock(request.id, forefront=forefront) - except Exception as err: - logger.debug(f'Failed to delete request lock for request {request.id}', exc_info=err) - - return processed_request + # TODO: implement async def is_empty(self) -> bool: """Check whether the queue is empty. Returns: - bool: `True` if the next call to `RequestQueue.fetch_next_request` would return `None`, otherwise `False`. + `True` if the next call to `RequestQueue.fetch_next_request` would return `None`, otherwise `False`. """ - await self._ensure_head_is_non_empty() - return len(self._queue_head) == 0 + # TODO: implement async def is_finished(self) -> bool: """Check whether the queue is finished. @@ -426,280 +310,6 @@ async def is_finished(self) -> bool: negative, but it will never return a false positive. Returns: - bool: `True` if all requests were already handled and there are no more left. `False` otherwise. + `True` if all requests were already handled and there are no more left. `False` otherwise. 
""" - if self._tasks: - logger.debug('Background tasks are still in progress') - return False - - if self._queue_head: - logger.debug( - 'There are still ids in the queue head that are pending processing', - extra={ - 'queue_head_ids_pending': len(self._queue_head), - }, - ) - - return False - - await self._ensure_head_is_non_empty() - - if self._queue_head: - logger.debug('Queue head still returned requests that need to be processed') - - return False - - # Could not lock any new requests - decide based on whether the queue contains requests locked by another client - if self._queue_has_locked_requests is not None: - if self._queue_has_locked_requests and self._dequeued_request_count == 0: - # The `% 25` was absolutely arbitrarily picked. It's just to not spam the logs too much. - if self._is_finished_log_throttle_counter % 25 == 0: - logger.info('The queue still contains requests locked by another client') - - self._is_finished_log_throttle_counter += 1 - - logger.debug( - f'Deciding if we are finished based on `queue_has_locked_requests` = {self._queue_has_locked_requests}' - ) - return not self._queue_has_locked_requests - - metadata = await self._resource_client.get() - if metadata is not None and not metadata.had_multiple_clients and not self._queue_head: - logger.debug('Queue head is empty and there are no other clients - we are finished') - - return True - - # The following is a legacy algorithm for checking if the queue is finished. - # It is used only for request queue clients that do not provide the `queue_has_locked_requests` flag. - current_head = await self._resource_client.list_head(limit=2) - - if current_head.items: - logger.debug('The queue still contains unfinished requests or requests locked by another client') - - return len(current_head.items) == 0 - - async def get_info(self) -> RequestQueueMetadata | None: - """Get an object containing general information about the request queue.""" - return await self._resource_client.get() - - @override - async def get_handled_count(self) -> int: - return self._assumed_handled_count - - @override - async def get_total_count(self) -> int: - return self._assumed_total_count - - async def _ensure_head_is_non_empty(self) -> None: - # Stop fetching if we are paused for migration - if self._queue_paused_for_migration: - return - - # We want to fetch ahead of time to minimize dead time - if len(self._queue_head) > 1 and not self._should_check_for_forefront_requests: - return - - if self._list_head_and_lock_task is None: - task = asyncio.create_task(self._list_head_and_lock(), name='request_queue_list_head_and_lock_task') - - def callback(_: Any) -> None: - self._list_head_and_lock_task = None - - task.add_done_callback(callback) - self._list_head_and_lock_task = task - - await self._list_head_and_lock_task - - async def _list_head_and_lock(self) -> None: - # Make a copy so that we can clear the flag only if the whole method executes after the flag was set - # (i.e, it was not set in the middle of the execution of the method) - should_check_for_forefront_requests = self._should_check_for_forefront_requests - - limit = 25 - - response = await self._resource_client.list_and_lock_head( - limit=limit, lock_secs=int(self._request_lock_time.total_seconds()) - ) - - self._queue_has_locked_requests = response.queue_has_locked_requests - - head_id_buffer = list[str]() - forefront_head_id_buffer = list[str]() - - for request in response.items: - # Queue head index might be behind the main table, so ensure we don't recycle requests - if not 
request.id or not request.unique_key: - logger.debug( - 'Skipping request from queue head, already in progress or recently handled', - extra={ - 'id': request.id, - 'unique_key': request.unique_key, - }, - ) - - # Remove the lock from the request for now, so that it can be picked up later - # This may/may not succeed, but that's fine - with suppress(Exception): - await self._resource_client.delete_request_lock(request.id) - - continue - - # If we remember that we added the request ourselves and we added it to the forefront, - # we will put it to the beginning of the local queue head to preserve the expected order. - # If we do not remember that, we will enqueue it normally. - cached_request = self._requests_cache.get(unique_key_to_request_id(request.unique_key)) - forefront = cached_request['forefront'] if cached_request else False - - if forefront: - forefront_head_id_buffer.insert(0, request.id) - else: - head_id_buffer.append(request.id) - - self._cache_request( - unique_key_to_request_id(request.unique_key), - ProcessedRequest( - id=request.id, - unique_key=request.unique_key, - was_already_present=True, - was_already_handled=False, - ), - forefront=forefront, - ) - - for request_id in head_id_buffer: - self._queue_head.append(request_id) - - for request_id in forefront_head_id_buffer: - self._queue_head.appendleft(request_id) - - # If the queue head became too big, unlock the excess requests - to_unlock = list[str]() - while len(self._queue_head) > limit: - to_unlock.append(self._queue_head.pop()) - - if to_unlock: - await asyncio.gather( - *[self._resource_client.delete_request_lock(request_id) for request_id in to_unlock], - return_exceptions=True, # Just ignore the exceptions - ) - - # Unset the should_check_for_forefront_requests flag - the check is finished - if should_check_for_forefront_requests: - self._should_check_for_forefront_requests = False - - def _reset(self) -> None: - self._queue_head.clear() - self._list_head_and_lock_task = None - self._assumed_total_count = 0 - self._assumed_handled_count = 0 - self._requests_cache.clear() - self._last_activity = datetime.now(timezone.utc) - - def _cache_request(self, cache_key: str, processed_request: ProcessedRequest, *, forefront: bool) -> None: - self._requests_cache[cache_key] = { - 'id': processed_request.id, - 'was_already_handled': processed_request.was_already_handled, - 'hydrated': None, - 'lock_expires_at': None, - 'forefront': forefront, - } - - async def _get_or_hydrate_request(self, request_id: str) -> Request | None: - cached_entry = self._requests_cache.get(request_id) - - if not cached_entry: - # 2.1. Attempt to prolong the request lock to see if we still own the request - prolong_result = await self._prolong_request_lock(request_id) - - if not prolong_result: - return None - - # 2.1.1. If successful, hydrate the request and return it - hydrated_request = await self.get_request(request_id) - - # Queue head index is ahead of the main table and the request is not present in the main table yet - # (i.e. get_request() returned null). 
- if not hydrated_request: - # Remove the lock from the request for now, so that it can be picked up later - # This may/may not succeed, but that's fine - with suppress(Exception): - await self._resource_client.delete_request_lock(request_id) - - return None - - self._requests_cache[request_id] = { - 'id': request_id, - 'hydrated': hydrated_request, - 'was_already_handled': hydrated_request.handled_at is not None, - 'lock_expires_at': prolong_result, - 'forefront': False, - } - - return hydrated_request - - # 1.1. If hydrated, prolong the lock more and return it - if cached_entry['hydrated']: - # 1.1.1. If the lock expired on the hydrated requests, try to prolong. If we fail, we lost the request - # (or it was handled already) - if cached_entry['lock_expires_at'] and cached_entry['lock_expires_at'] < datetime.now(timezone.utc): - prolonged = await self._prolong_request_lock(cached_entry['id']) - - if not prolonged: - return None - - cached_entry['lock_expires_at'] = prolonged - - return cached_entry['hydrated'] - - # 1.2. If not hydrated, try to prolong the lock first (to ensure we keep it in our queue), hydrate and return it - prolonged = await self._prolong_request_lock(cached_entry['id']) - - if not prolonged: - return None - - # This might still return null if the queue head is inconsistent with the main queue table. - hydrated_request = await self.get_request(cached_entry['id']) - - cached_entry['hydrated'] = hydrated_request - - # Queue head index is ahead of the main table and the request is not present in the main table yet - # (i.e. get_request() returned null). - if not hydrated_request: - # Remove the lock from the request for now, so that it can be picked up later - # This may/may not succeed, but that's fine - with suppress(Exception): - await self._resource_client.delete_request_lock(cached_entry['id']) - - return None - - return hydrated_request - - async def _prolong_request_lock(self, request_id: str) -> datetime | None: - try: - res = await self._resource_client.prolong_request_lock( - request_id, lock_secs=int(self._request_lock_time.total_seconds()) - ) - except Exception as err: - # Most likely we do not own the lock anymore - logger.warning( - f'Failed to prolong lock for cached request {request_id}, either lost the lock ' - 'or the request was already handled\n', - exc_info=err, - ) - return None - else: - return res.lock_expires_at - - async def _clear_possible_locks(self) -> None: - self._queue_paused_for_migration = True - request_id: str | None = None - - while True: - try: - request_id = self._queue_head.pop() - except LookupError: - break - - with suppress(Exception): - await self._resource_client.delete_request_lock(request_id) - # If this fails, we don't have the lock, or the request was never locked. 
Either way it's fine + # TODO: implement From e74ae8e495a25acc788e4c62158ce93b74adf44d Mon Sep 17 00:00:00 2001 From: Vlada Dusek Date: Fri, 11 Apr 2025 18:55:42 +0200 Subject: [PATCH 07/22] Add Dataset & KVS file system cliets tests --- pyproject.toml | 1 + .../_file_system/_dataset_client.py | 120 +++-- .../_file_system/_key_value_store_client.py | 99 ++-- src/crawlee/storage_clients/models.py | 2 +- src/crawlee/storages/_dataset.py | 2 +- src/crawlee/storages/_key_value_store.py | 2 +- tests/unit/_utils/test_file.py | 17 +- tests/unit/conftest.py | 26 +- .../_file_system/test_dataset_client.py | 280 +++++++++++ .../test_key_value_store_client.py | 338 ++++++++++++++ .../_memory/test_creation_management.py | 59 --- .../_memory/test_dataset_client.py | 148 ------ .../_memory/test_dataset_collection_client.py | 45 -- .../_memory/test_key_value_store_client.py | 442 ------------------ .../test_key_value_store_collection_client.py | 42 -- .../_memory/test_memory_storage_e2e.py | 130 ------ .../_memory/test_request_queue_client.py | 249 ---------- .../test_request_queue_collection_client.py | 42 -- tests/unit/storages/test_key_value_store.py | 2 +- 19 files changed, 756 insertions(+), 1290 deletions(-) create mode 100644 tests/unit/storage_clients/_file_system/test_dataset_client.py create mode 100644 tests/unit/storage_clients/_file_system/test_key_value_store_client.py delete mode 100644 tests/unit/storage_clients/_memory/test_creation_management.py delete mode 100644 tests/unit/storage_clients/_memory/test_dataset_client.py delete mode 100644 tests/unit/storage_clients/_memory/test_dataset_collection_client.py delete mode 100644 tests/unit/storage_clients/_memory/test_key_value_store_client.py delete mode 100644 tests/unit/storage_clients/_memory/test_key_value_store_collection_client.py delete mode 100644 tests/unit/storage_clients/_memory/test_memory_storage_e2e.py delete mode 100644 tests/unit/storage_clients/_memory/test_request_queue_client.py delete mode 100644 tests/unit/storage_clients/_memory/test_request_queue_collection_client.py diff --git a/pyproject.toml b/pyproject.toml index ece3fe7956..fdc89eed71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -164,6 +164,7 @@ indent-style = "space" "F401", # Unused imports ] "**/{tests}/*" = [ + "ASYNC230", # Async functions should not open files with blocking methods like `open` "D", # Everything from the pydocstyle "INP001", # File {filename} is part of an implicit namespace package, add an __init__.py "PLR2004", # Magic value used in comparison, consider replacing {value} with a constant variable diff --git a/src/crawlee/storage_clients/_file_system/_dataset_client.py b/src/crawlee/storage_clients/_file_system/_dataset_client.py index c5693c46c8..3c0b46d5c2 100644 --- a/src/crawlee/storage_clients/_file_system/_dataset_client.py +++ b/src/crawlee/storage_clients/_file_system/_dataset_client.py @@ -23,6 +23,9 @@ logger = getLogger(__name__) +_cache_by_name = dict[str, 'FileSystemDatasetClient']() +"""A dictionary to cache clients by their names.""" + class FileSystemDatasetClient(DatasetClient): """A file system storage implementation of the dataset client. @@ -55,12 +58,15 @@ def __init__( Preferably use the `FileSystemDatasetClient.open` class method to create a new instance. 
""" - self._id = id - self._name = name - self._created_at = created_at - self._accessed_at = accessed_at - self._modified_at = modified_at - self._item_count = item_count + self._metadata = DatasetMetadata( + id=id, + name=name, + created_at=created_at, + accessed_at=accessed_at, + modified_at=modified_at, + item_count=item_count, + ) + self._storage_dir = storage_dir # Internal attributes. @@ -70,49 +76,50 @@ def __init__( @override @property def id(self) -> str: - return self._id + return self._metadata.id @override @property - def name(self) -> str | None: - return self._name + def name(self) -> str: + return self._metadata.name @override @property def created_at(self) -> datetime: - return self._created_at + return self._metadata.created_at @override @property def accessed_at(self) -> datetime: - return self._accessed_at + return self._metadata.accessed_at @override @property def modified_at(self) -> datetime: - return self._modified_at + return self._metadata.modified_at @override @property def item_count(self) -> int: - return self._item_count + return self._metadata.item_count @property - def _path_to_dataset(self) -> Path: + def path_to_dataset(self) -> Path: """The full path to the dataset directory.""" - return self._storage_dir / self._STORAGE_SUBDIR / self._name + return self._storage_dir / self._STORAGE_SUBDIR / self.name @property - def _path_to_metadata(self) -> Path: + def path_to_metadata(self) -> Path: """The full path to the dataset metadata file.""" - return self._path_to_dataset / METADATA_FILENAME + return self.path_to_dataset / METADATA_FILENAME @override @classmethod async def open( cls, - id: str | None, - name: str | None, + *, + id: str | None = None, + name: str | None = None, storage_dir: Path, ) -> FileSystemDatasetClient: """Open an existing dataset client or create a new one if it does not exist. @@ -134,6 +141,11 @@ async def open( ) name = name or cls._DEFAULT_NAME + + # Check if the client is already cached by name. + if name in _cache_by_name: + return _cache_by_name[name] + dataset_path = storage_dir / cls._STORAGE_SUBDIR / name metadata_path = dataset_path / METADATA_FILENAME @@ -178,25 +190,40 @@ async def open( ) await client._update_metadata() + # Cache the client by name. + _cache_by_name[name] = client + return client @override async def drop(self) -> None: - # If the dataset directory exists, remove it recursively. - if self._path_to_dataset.exists(): + # If the client directory exists, remove it recursively. + if self.path_to_dataset.exists(): async with self._lock: - await asyncio.to_thread(shutil.rmtree, self._path_to_dataset) + await asyncio.to_thread(shutil.rmtree, self.path_to_dataset) + + # Remove the client from the cache. + if self.name in _cache_by_name: + del _cache_by_name[self.name] @override async def push_data(self, data: list[Any] | dict[str, Any]) -> None: + new_item_count = self.item_count + # If data is a list, push each item individually. 
if isinstance(data, list): for item in data: - await self._push_item(item) + new_item_count += 1 + await self._push_item(item, new_item_count) else: - await self._push_item(data) + new_item_count += 1 + await self._push_item(data, new_item_count) - await self._update_metadata(update_accessed_at=True, update_modified_at=True) + await self._update_metadata( + update_accessed_at=True, + update_modified_at=True, + new_item_count=new_item_count, + ) @override async def get_data( @@ -223,8 +250,8 @@ async def get_data( ) # If the dataset directory does not exist, log a warning and return an empty page. - if not self._path_to_dataset.exists(): - logger.warning(f'Dataset directory not found: {self._path_to_dataset}') + if not self.path_to_dataset.exists(): + logger.warning(f'Dataset directory not found: {self.path_to_dataset}') return DatasetItemsListPage( count=0, offset=offset, @@ -298,8 +325,8 @@ async def iterate( ) # If the dataset directory does not exist, log a warning and return immediately. - if not self._path_to_dataset.exists(): - logger.warning(f'Dataset directory not found: {self._path_to_dataset}') + if not self.path_to_dataset.exists(): + logger.warning(f'Dataset directory not found: {self.path_to_dataset}') return # Get the list of sorted data files. @@ -334,33 +361,31 @@ async def iterate( async def _update_metadata( self, *, + new_item_count: int | None = None, update_accessed_at: bool = False, update_modified_at: bool = False, ) -> None: """Update the dataset metadata file with current information. Args: + new_item_count: If provided, update the item count to this value. update_accessed_at: If True, update the `accessed_at` timestamp to the current time. update_modified_at: If True, update the `modified_at` timestamp to the current time. """ now = datetime.now(timezone.utc) - metadata = DatasetMetadata( - id=self._id, - name=self._name, - created_at=self._created_at, - accessed_at=now if update_accessed_at else self._accessed_at, - modified_at=now if update_modified_at else self._modified_at, - item_count=self._item_count, - ) + + self._metadata.accessed_at = now if update_accessed_at else self.accessed_at + self._metadata.modified_at = now if update_modified_at else self.modified_at + self._metadata.item_count = new_item_count if new_item_count else self.item_count # Ensure the parent directory for the metadata file exists. - await asyncio.to_thread(self._path_to_metadata.parent.mkdir, parents=True, exist_ok=True) + await asyncio.to_thread(self.path_to_metadata.parent.mkdir, parents=True, exist_ok=True) # Dump the serialized metadata to the file. - data = await json_dumps(metadata.model_dump()) - await asyncio.to_thread(self._path_to_metadata.write_text, data, encoding='utf-8') + data = await json_dumps(self._metadata.model_dump()) + await asyncio.to_thread(self.path_to_metadata.write_text, data, encoding='utf-8') - async def _push_item(self, item: dict[str, Any]) -> None: + async def _push_item(self, item: dict[str, Any], item_id: int) -> None: """Push a single item to the dataset. This method increments the item count, writes the item as a JSON file with a zero-padded filename, @@ -368,13 +393,12 @@ async def _push_item(self, item: dict[str, Any]) -> None: """ # Acquire the lock to perform file operations safely. async with self._lock: - self._item_count += 1 # Generate the filename for the new item using zero-padded numbering. 
- filename = f'{str(self._item_count).zfill(self._LOCAL_ENTRY_NAME_DIGITS)}.json' - file_path = self._path_to_dataset / filename + filename = f'{str(item_id).zfill(self._LOCAL_ENTRY_NAME_DIGITS)}.json' + file_path = self.path_to_dataset / filename # Ensure the dataset directory exists. - await asyncio.to_thread(self._path_to_dataset.mkdir, parents=True, exist_ok=True) + await asyncio.to_thread(self.path_to_dataset.mkdir, parents=True, exist_ok=True) # Dump the serialized item to the file. data = await json_dumps(item) @@ -389,12 +413,12 @@ async def _get_sorted_data_files(self) -> list[Path]: # Retrieve and sort all JSON files in the dataset directory numerically. files = await asyncio.to_thread( sorted, - self._path_to_dataset.glob('*.json'), + self.path_to_dataset.glob('*.json'), key=lambda f: int(f.stem) if f.stem.isdigit() else 0, ) # Remove the metadata file from the list if present. - if self._path_to_metadata in files: - files.remove(self._path_to_metadata) + if self.path_to_metadata in files: + files.remove(self.path_to_metadata) return files diff --git a/src/crawlee/storage_clients/_file_system/_key_value_store_client.py b/src/crawlee/storage_clients/_file_system/_key_value_store_client.py index 921838a73f..129c8aeb49 100644 --- a/src/crawlee/storage_clients/_file_system/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_file_system/_key_value_store_client.py @@ -12,11 +12,7 @@ from crawlee._utils.crypto import crypto_random_object_id from crawlee.storage_clients._base import KeyValueStoreClient -from crawlee.storage_clients.models import ( - KeyValueStoreMetadata, - KeyValueStoreRecord, - KeyValueStoreRecordMetadata, -) +from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecord, KeyValueStoreRecordMetadata from ._utils import METADATA_FILENAME, json_dumps @@ -26,6 +22,9 @@ logger = getLogger(__name__) +_cache_by_name = dict[str, 'FileSystemKeyValueStoreClient']() +"""A dictionary to cache clients by their names.""" + class FileSystemKeyValueStoreClient(KeyValueStoreClient): """A file system key-value store (KVS) implementation.""" @@ -50,11 +49,14 @@ def __init__( Preferably use the `FileSystemKeyValueStoreClient.open` class method to create a new instance. """ - self._id = id - self._name = name - self._created_at = created_at - self._accessed_at = accessed_at - self._modified_at = modified_at + self._metadata = KeyValueStoreMetadata( + id=id, + name=name, + created_at=created_at, + accessed_at=accessed_at, + modified_at=modified_at, + ) + self._storage_dir = storage_dir # Internal attributes. 
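
For illustration, a minimal usage sketch of the `FileSystemDatasetClient` added above (not part of the patch itself). It assumes the private import path exercised by the new tests and an asyncio event loop; the `./storage` directory and the dataset name are made up. Because clients are cached in the module-level `_cache_by_name` dict, a second `open()` with the same name is expected to return the very same instance.

import asyncio
from pathlib import Path

from crawlee.storage_clients._file_system._dataset_client import FileSystemDatasetClient


async def main() -> None:
    storage_dir = Path('./storage')

    # First open() creates the dataset directory and its metadata file on disk.
    client = await FileSystemDatasetClient.open(name='example', storage_dir=storage_dir)

    # push_data() writes one zero-padded JSON file per item and bumps item_count.
    await client.push_data([{'url': 'https://example.com', 'rank': 1}])

    # A second open() with the same name should be served from the module-level cache.
    same_client = await FileSystemDatasetClient.open(name='example', storage_dir=storage_dir)
    assert same_client is client

    # get_data() returns a DatasetItemsListPage with count/total/items.
    page = await client.get_data()
    print(page.count, [item['url'] for item in page.items])


asyncio.run(main())
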
@@ -64,44 +66,45 @@ def __init__( @override @property def id(self) -> str: - return self._id + return self._metadata.id @override @property - def name(self) -> str | None: - return self._name + def name(self) -> str: + return self._metadata.name @override @property def created_at(self) -> datetime: - return self._created_at + return self._metadata.created_at @override @property def accessed_at(self) -> datetime: - return self._accessed_at + return self._metadata.accessed_at @override @property def modified_at(self) -> datetime: - return self._modified_at + return self._metadata.modified_at @property - def _path_to_kvs(self) -> Path: + def path_to_kvs(self) -> Path: """The full path to the key-value store directory.""" - return self._storage_dir / self._STORAGE_SUBDIR / self._name + return self._storage_dir / self._STORAGE_SUBDIR / self.name @property - def _path_to_metadata(self) -> Path: + def path_to_metadata(self) -> Path: """The full path to the key-value store metadata file.""" - return self._path_to_kvs / METADATA_FILENAME + return self.path_to_kvs / METADATA_FILENAME @override @classmethod async def open( cls, - id: str | None, - name: str | None, + *, + id: str | None = None, + name: str | None = None, storage_dir: Path, ) -> FileSystemKeyValueStoreClient: """Open an existing key-value store client or create a new one if it does not exist. @@ -123,6 +126,11 @@ async def open( ) name = name or cls._DEFAULT_NAME + + # Check if the client is already cached by name. + if name in _cache_by_name: + return _cache_by_name[name] + kvs_path = storage_dir / cls._STORAGE_SUBDIR / name metadata_path = kvs_path / METADATA_FILENAME @@ -165,18 +173,28 @@ async def open( ) await client._update_metadata() + # Cache the client by name. + _cache_by_name[name] = client + return client @override async def drop(self) -> None: - # If the key-value store directory exists, remove it recursively. - if self._path_to_kvs.exists(): + # If the client directory exists, remove it recursively. + if self.path_to_kvs.exists(): async with self._lock: - await asyncio.to_thread(shutil.rmtree, self._path_to_kvs) + await asyncio.to_thread(shutil.rmtree, self.path_to_kvs) + + # Remove the client from the cache. + if self.name in _cache_by_name: + del _cache_by_name[self.name] @override async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: - record_path = self._path_to_kvs / key + # Update the metadata to record access + await self._update_metadata(update_accessed_at=True) + + record_path = self.path_to_kvs / key if not record_path.exists(): return None @@ -227,9 +245,6 @@ async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: else: value = value_bytes - # Update the metadata to record access - await self._update_metadata(update_accessed_at=True) - # Calculate the size of the value in bytes size = len(value_bytes) @@ -255,7 +270,7 @@ async def set_value(self, *, key: str, value: Any, content_type: str | None = No # Fallback: attempt to convert to string and encode. value_bytes = str(value).encode('utf-8') - record_path = self._path_to_kvs / key + record_path = self.path_to_kvs / key # Get the metadata. # Calculate the size of the value in bytes @@ -266,7 +281,7 @@ async def set_value(self, *, key: str, value: Any, content_type: str | None = No async with self._lock: # Ensure the key-value store directory exists. 
- await asyncio.to_thread(self._path_to_kvs.mkdir, parents=True, exist_ok=True) + await asyncio.to_thread(self.path_to_kvs.mkdir, parents=True, exist_ok=True) # Dump the value to the file. await asyncio.to_thread(record_path.write_bytes, value_bytes) @@ -283,7 +298,7 @@ async def set_value(self, *, key: str, value: Any, content_type: str | None = No @override async def delete_value(self, *, key: str) -> None: - record_path = self._path_to_kvs / key + record_path = self.path_to_kvs / key metadata_path = record_path.with_name(f'{record_path.name}.{METADATA_FILENAME}') deleted = False @@ -312,13 +327,13 @@ async def iterate_keys( limit: int | None = None, ) -> AsyncIterator[KeyValueStoreRecordMetadata]: # Check if the KVS directory exists - if not self._path_to_kvs.exists(): + if not self.path_to_kvs.exists(): return count = 0 async with self._lock: # Get all files in the KVS directory - files = sorted(await asyncio.to_thread(list, self._path_to_kvs.glob('*'))) + files = sorted(await asyncio.to_thread(list, self.path_to_kvs.glob('*'))) for file_path in files: # Skip the main metadata file @@ -371,20 +386,16 @@ async def _update_metadata( update_modified_at: If True, update the `modified_at` timestamp to the current time. """ now = datetime.now(timezone.utc) - metadata = KeyValueStoreMetadata( - id=self._id, - name=self._name, - created_at=self._created_at, - accessed_at=now if update_accessed_at else self._accessed_at, - modified_at=now if update_modified_at else self._modified_at, - ) + + self._metadata.accessed_at = now if update_accessed_at else self._metadata.accessed_at + self._metadata.modified_at = now if update_modified_at else self._metadata.modified_at # Ensure the parent directory for the metadata file exists. - await asyncio.to_thread(self._path_to_metadata.parent.mkdir, parents=True, exist_ok=True) + await asyncio.to_thread(self.path_to_metadata.parent.mkdir, parents=True, exist_ok=True) # Dump the serialized metadata to the file. - data = await json_dumps(metadata.model_dump()) - await asyncio.to_thread(self._path_to_metadata.write_text, data, encoding='utf-8') + data = await json_dumps(self._metadata.model_dump()) + await asyncio.to_thread(self.path_to_metadata.write_text, data, encoding='utf-8') def _infer_mime_type(self, value: Any) -> str: """Infer the MIME content type from the value. 
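
Similarly, a minimal sketch (not part of the patch) of the `FileSystemKeyValueStoreClient` above, again assuming the private import path used by the new tests; the store name and the stored key/value are made up. Each `set_value()` writes the raw bytes of the value alongside a `<key>.<metadata filename>` sidecar describing its key, content type and size, and `get_value()` reads both back, deserializing JSON values automatically.

import asyncio
from pathlib import Path

from crawlee.storage_clients._file_system._key_value_store_client import FileSystemKeyValueStoreClient


async def main() -> None:
    kvs = await FileSystemKeyValueStoreClient.open(name='example', storage_dir=Path('./storage'))

    # A dict is serialized as JSON; the content type is inferred by _infer_mime_type().
    await kvs.set_value(key='state', value={'page': 1})

    # get_value() returns a KeyValueStoreRecord, or None when the key is missing.
    record = await kvs.get_value(key='state')
    assert record is not None
    assert record.value == {'page': 1}

    # Keys can be listed lazily; delete_value() removes both the value and its sidecar file.
    async for key_info in kvs.iterate_keys():
        print(key_info.key, key_info.content_type)

    await kvs.delete_value(key='state')


asyncio.run(main())
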
diff --git a/src/crawlee/storage_clients/models.py b/src/crawlee/storage_clients/models.py index 8299220475..2887492885 100644 --- a/src/crawlee/storage_clients/models.py +++ b/src/crawlee/storage_clients/models.py @@ -26,7 +26,7 @@ class StorageMetadata(BaseModel): model_config = ConfigDict(populate_by_name=True, extra='allow') id: Annotated[str, Field(alias='id')] - name: Annotated[str | None, Field(alias='name', default='')] + name: Annotated[str, Field(alias='name', default='default')] accessed_at: Annotated[datetime, Field(alias='accessedAt')] created_at: Annotated[datetime, Field(alias='createdAt')] modified_at: Annotated[datetime, Field(alias='modifiedAt')] diff --git a/src/crawlee/storages/_dataset.py b/src/crawlee/storages/_dataset.py index 6b6160f471..112addfcf1 100644 --- a/src/crawlee/storages/_dataset.py +++ b/src/crawlee/storages/_dataset.py @@ -31,7 +31,7 @@ logger = logging.getLogger(__name__) # TODO: -# - caching / memoization of both datasets & dataset clients +# - caching / memoization of Dataset # Properties: # - id diff --git a/src/crawlee/storages/_key_value_store.py b/src/crawlee/storages/_key_value_store.py index b99957bb19..6567adeb8f 100644 --- a/src/crawlee/storages/_key_value_store.py +++ b/src/crawlee/storages/_key_value_store.py @@ -22,7 +22,7 @@ T = TypeVar('T') # TODO: -# - caching / memoization of both KVS & KVS clients +# - caching / memoization of KVS # Properties: # - id diff --git a/tests/unit/_utils/test_file.py b/tests/unit/_utils/test_file.py index a86291b43f..b05d44723e 100644 --- a/tests/unit/_utils/test_file.py +++ b/tests/unit/_utils/test_file.py @@ -1,6 +1,5 @@ from __future__ import annotations -import io from datetime import datetime, timezone from pathlib import Path @@ -12,7 +11,6 @@ force_remove, force_rename, is_content_type, - is_file_or_bytes, json_dumps, ) @@ -25,15 +23,6 @@ async def test_json_dumps() -> None: assert await json_dumps(datetime(2022, 1, 1, tzinfo=timezone.utc)) == '"2022-01-01 00:00:00+00:00"' -def test_is_file_or_bytes() -> None: - assert is_file_or_bytes(b'bytes') is True - assert is_file_or_bytes(bytearray(b'bytearray')) is True - assert is_file_or_bytes(io.BytesIO(b'some bytes')) is True - assert is_file_or_bytes(io.StringIO('string')) is True - assert is_file_or_bytes('just a regular string') is False - assert is_file_or_bytes(12345) is False - - @pytest.mark.parametrize( ('content_type_enum', 'content_type', 'expected_result'), [ @@ -115,7 +104,7 @@ async def test_force_remove(tmp_path: Path) -> None: assert test_file_path.exists() is False # Remove the file if it exists - with open(test_file_path, 'a', encoding='utf-8'): # noqa: ASYNC230 + with open(test_file_path, 'a', encoding='utf-8'): pass assert test_file_path.exists() is True await force_remove(test_file_path) @@ -134,11 +123,11 @@ async def test_force_rename(tmp_path: Path) -> None: # Will remove dst_dir if it exists (also covers normal case) # Create the src_dir with a file in it src_dir.mkdir() - with open(src_file, 'a', encoding='utf-8'): # noqa: ASYNC230 + with open(src_file, 'a', encoding='utf-8'): pass # Create the dst_dir with a file in it dst_dir.mkdir() - with open(dst_file, 'a', encoding='utf-8'): # noqa: ASYNC230 + with open(dst_file, 'a', encoding='utf-8'): pass assert src_file.exists() is True assert dst_file.exists() is True diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index b7ac06d124..a749d43f2e 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -13,12 +13,10 @@ from uvicorn.config import Config from crawlee 
import service_locator -from crawlee.configuration import Configuration from crawlee.fingerprint_suite._browserforge_adapter import get_available_header_network from crawlee.http_clients import CurlImpersonateHttpClient, HttpxHttpClient from crawlee.proxy_configuration import ProxyInfo -from crawlee.storage_clients import MemoryStorageClient -from crawlee.storages import KeyValueStore, _creation_management +from crawlee.storages import KeyValueStore from tests.unit.server import TestServer, app, serve_in_thread if TYPE_CHECKING: @@ -64,14 +62,6 @@ def _prepare_test_env() -> None: service_locator._event_manager = None service_locator._storage_client = None - # Clear creation-related caches to ensure no state is carried over between tests. - monkeypatch.setattr(_creation_management, '_cache_dataset_by_id', {}) - monkeypatch.setattr(_creation_management, '_cache_dataset_by_name', {}) - monkeypatch.setattr(_creation_management, '_cache_kvs_by_id', {}) - monkeypatch.setattr(_creation_management, '_cache_kvs_by_name', {}) - monkeypatch.setattr(_creation_management, '_cache_rq_by_id', {}) - monkeypatch.setattr(_creation_management, '_cache_rq_by_name', {}) - # Verify that the test environment was set up correctly. assert os.environ.get('CRAWLEE_STORAGE_DIR') == str(tmp_path) assert service_locator._configuration_was_retrieved is False @@ -103,6 +93,8 @@ def _set_crawler_log_level(pytestconfig: pytest.Config, monkeypatch: pytest.Monk monkeypatch.setattr(_log_config, 'get_configured_log_level', lambda: getattr(logging, loglevel.upper())) + + @pytest.fixture async def proxy_info(unused_tcp_port: int) -> ProxyInfo: username = 'user' @@ -149,18 +141,6 @@ async def disabled_proxy(proxy_info: ProxyInfo) -> AsyncGenerator[ProxyInfo, Non yield proxy_info -@pytest.fixture -def memory_storage_client(tmp_path: Path) -> MemoryStorageClient: - """A fixture for testing the memory storage client and its resource clients.""" - config = Configuration( - persist_storage=True, - write_metadata=True, - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - ) - - return MemoryStorageClient.from_config(config) - - @pytest.fixture(scope='session') def header_network() -> dict: return get_available_header_network() diff --git a/tests/unit/storage_clients/_file_system/test_dataset_client.py b/tests/unit/storage_clients/_file_system/test_dataset_client.py new file mode 100644 index 0000000000..ae17746e06 --- /dev/null +++ b/tests/unit/storage_clients/_file_system/test_dataset_client.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +import asyncio +import json +from datetime import datetime +from pathlib import Path +from typing import TYPE_CHECKING + +import pytest + +from crawlee._consts import METADATA_FILENAME +from crawlee.storage_clients._file_system._dataset_client import FileSystemDatasetClient +from crawlee.storage_clients.models import DatasetItemsListPage + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + +pytestmark = pytest.mark.only + + +@pytest.fixture +async def dataset_client(tmp_path: Path) -> AsyncGenerator[FileSystemDatasetClient, None]: + """A fixture for a file system dataset client.""" + client = await FileSystemDatasetClient.open(name='test_dataset', storage_dir=tmp_path) + yield client + await client.drop() + + +async def test_open_creates_new_dataset(tmp_path: Path) -> None: + """Test that open() creates a new dataset with proper metadata when it doesn't exist.""" + client = await FileSystemDatasetClient.open(name='new_dataset', storage_dir=tmp_path) + + # Verify 
client properties + assert client.id is not None + assert client.name == 'new_dataset' + assert client.item_count == 0 + assert isinstance(client.created_at, datetime) + assert isinstance(client.accessed_at, datetime) + assert isinstance(client.modified_at, datetime) + + # Verify files were created + assert client.path_to_dataset.exists() + assert client.path_to_metadata.exists() + + # Verify metadata content + with client.path_to_metadata.open() as f: + metadata = json.load(f) + assert metadata['id'] == client.id + assert metadata['name'] == 'new_dataset' + assert metadata['item_count'] == 0 + + +async def test_open_existing_dataset(dataset_client: FileSystemDatasetClient, tmp_path: Path) -> None: + """Test that open() loads an existing dataset correctly.""" + # Open the same dataset again + reopened_client = await FileSystemDatasetClient.open(name=dataset_client.name, storage_dir=tmp_path) + + # Verify client properties + assert dataset_client.id == reopened_client.id + assert dataset_client.name == reopened_client.name + assert dataset_client.item_count == reopened_client.item_count + + # Verify clients (python) ids + assert id(dataset_client) == id(reopened_client) + + +async def test_open_with_id_raises_error(tmp_path: Path) -> None: + """Test that open() raises an error when an ID is provided.""" + with pytest.raises(ValueError, match='not supported for file system storage client'): + await FileSystemDatasetClient.open(id='some-id', storage_dir=tmp_path) + + +async def test_push_data_single_item(dataset_client: FileSystemDatasetClient) -> None: + """Test pushing a single item to the dataset.""" + item = {'key': 'value', 'number': 42} + await dataset_client.push_data(item) + + # Verify item count was updated + assert dataset_client.item_count == 1 + + all_files = list(dataset_client.path_to_dataset.glob('*.json')) + assert len(all_files) == 2 # 1 data file + 1 metadata file + + # Verify item was persisted + data_files = [item for item in all_files if item.name != METADATA_FILENAME] + assert len(data_files) == 1 + + # Verify file content + with Path(data_files[0]).open() as f: + saved_item = json.load(f) + assert saved_item == item + + +async def test_push_data_multiple_items(dataset_client: FileSystemDatasetClient) -> None: + """Test pushing multiple items to the dataset.""" + items = [{'id': 1, 'name': 'Item 1'}, {'id': 2, 'name': 'Item 2'}, {'id': 3, 'name': 'Item 3'}] + await dataset_client.push_data(items) + + # Verify item count was updated + assert dataset_client.item_count == 3 + + all_files = list(dataset_client.path_to_dataset.glob('*.json')) + assert len(all_files) == 4 # 3 data files + 1 metadata file + + # Verify items were saved to files + data_files = [f for f in all_files if f.name != METADATA_FILENAME] + assert len(data_files) == 3 + + +async def test_get_data_empty_dataset(dataset_client: FileSystemDatasetClient) -> None: + """Test getting data from an empty dataset.""" + result = await dataset_client.get_data() + + assert isinstance(result, DatasetItemsListPage) + assert result.count == 0 + assert result.total == 0 + assert result.items == [] + + +async def test_get_data_with_items(dataset_client: FileSystemDatasetClient) -> None: + """Test getting data from a dataset with items.""" + # Add some items + items = [{'id': 1, 'name': 'Item 1'}, {'id': 2, 'name': 'Item 2'}, {'id': 3, 'name': 'Item 3'}] + await dataset_client.push_data(items) + + # Get all items + result = await dataset_client.get_data() + + assert result.count == 3 + assert result.total == 3 + assert 
len(result.items) == 3 + assert result.items[0]['id'] == 1 + assert result.items[1]['id'] == 2 + assert result.items[2]['id'] == 3 + + +async def test_get_data_with_pagination(dataset_client: FileSystemDatasetClient) -> None: + """Test getting data with pagination.""" + # Add some items + items = [{'id': i} for i in range(1, 11)] # 10 items + await dataset_client.push_data(items) + + # Test offset + result = await dataset_client.get_data(offset=3) + assert result.count == 7 + assert result.offset == 3 + assert result.items[0]['id'] == 4 + + # Test limit + result = await dataset_client.get_data(limit=5) + assert result.count == 5 + assert result.limit == 5 + assert result.items[-1]['id'] == 5 + + # Test both offset and limit + result = await dataset_client.get_data(offset=2, limit=3) + assert result.count == 3 + assert result.offset == 2 + assert result.limit == 3 + assert result.items[0]['id'] == 3 + assert result.items[-1]['id'] == 5 + + +async def test_get_data_descending_order(dataset_client: FileSystemDatasetClient) -> None: + """Test getting data in descending order.""" + # Add some items + items = [{'id': i} for i in range(1, 6)] # 5 items + await dataset_client.push_data(items) + + # Get items in descending order + result = await dataset_client.get_data(desc=True) + + assert result.desc is True + assert result.items[0]['id'] == 5 + assert result.items[-1]['id'] == 1 + + +async def test_get_data_skip_empty(dataset_client: FileSystemDatasetClient) -> None: + """Test getting data with skip_empty option.""" + # Add some items including an empty one + items = [ + {'id': 1, 'name': 'Item 1'}, + {}, # Empty item + {'id': 3, 'name': 'Item 3'}, + ] + await dataset_client.push_data(items) + + # Get all items + result = await dataset_client.get_data() + assert result.count == 3 + + # Get non-empty items + result = await dataset_client.get_data(skip_empty=True) + assert result.count == 2 + assert all(item != {} for item in result.items) + + +async def test_iterate(dataset_client: FileSystemDatasetClient) -> None: + """Test iterating over dataset items.""" + # Add some items + items = [{'id': i} for i in range(1, 6)] # 5 items + await dataset_client.push_data(items) + + # Iterate over all items + collected_items = [item async for item in dataset_client.iterate()] + + assert len(collected_items) == 5 + assert collected_items[0]['id'] == 1 + assert collected_items[-1]['id'] == 5 + + +async def test_iterate_with_options(dataset_client: FileSystemDatasetClient) -> None: + """Test iterating with various options.""" + # Add some items + items = [{'id': i} for i in range(1, 11)] # 10 items + await dataset_client.push_data(items) + + # Test with offset and limit + collected_items = [item async for item in dataset_client.iterate(offset=3, limit=3)] + + assert len(collected_items) == 3 + assert collected_items[0]['id'] == 4 + assert collected_items[-1]['id'] == 6 + + # Test with descending order + collected_items = [] + async for item in dataset_client.iterate(desc=True, limit=3): + collected_items.append(item) + + assert len(collected_items) == 3 + assert collected_items[0]['id'] == 10 + assert collected_items[-1]['id'] == 8 + + +async def test_drop(tmp_path: Path) -> None: + """Test dropping a dataset.""" + # Create a dataset and add an item + client = await FileSystemDatasetClient.open(name='to_drop', storage_dir=tmp_path) + await client.push_data({'test': 'data'}) + + # Verify the dataset directory exists + assert client.path_to_dataset.exists() + + # Drop the dataset + await client.drop() + + # 
Verify the dataset directory was removed + assert not client.path_to_dataset.exists() + + +async def test_metadata_updates(dataset_client: FileSystemDatasetClient) -> None: + """Test that metadata is updated correctly after operations.""" + # Record initial timestamps + initial_created = dataset_client.created_at + initial_accessed = dataset_client.accessed_at + initial_modified = dataset_client.modified_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates accessed_at + await dataset_client.get_data() + + # Verify timestamps + assert dataset_client.created_at == initial_created + assert dataset_client.accessed_at > initial_accessed + assert dataset_client.modified_at == initial_modified + + accessed_after_get = dataset_client.accessed_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates modified_at + await dataset_client.push_data({'new': 'item'}) + + # Verify timestamps again + assert dataset_client.created_at == initial_created + assert dataset_client.modified_at > initial_modified + assert dataset_client.accessed_at > accessed_after_get diff --git a/tests/unit/storage_clients/_file_system/test_key_value_store_client.py b/tests/unit/storage_clients/_file_system/test_key_value_store_client.py new file mode 100644 index 0000000000..4f1431ea59 --- /dev/null +++ b/tests/unit/storage_clients/_file_system/test_key_value_store_client.py @@ -0,0 +1,338 @@ +from __future__ import annotations + +import asyncio +import json +from datetime import datetime +from typing import TYPE_CHECKING + +import pytest + +from crawlee._consts import METADATA_FILENAME +from crawlee.storage_clients._file_system._key_value_store_client import FileSystemKeyValueStoreClient + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + from pathlib import Path + +pytestmark = pytest.mark.only + + +@pytest.fixture +async def kvs_client(tmp_path: Path) -> AsyncGenerator[FileSystemKeyValueStoreClient, None]: + """A fixture for a file system key-value store client.""" + client = await FileSystemKeyValueStoreClient.open(name='test_kvs', storage_dir=tmp_path) + yield client + await client.drop() + + +async def test_open_creates_new_kvs(tmp_path: Path) -> None: + """Test that open() creates a new key-value store with proper metadata when it doesn't exist.""" + client = await FileSystemKeyValueStoreClient.open(name='new_kvs', storage_dir=tmp_path) + + # Verify client properties + assert client.id is not None + assert client.name == 'new_kvs' + assert isinstance(client.created_at, datetime) + assert isinstance(client.accessed_at, datetime) + assert isinstance(client.modified_at, datetime) + + # Verify files were created + assert client.path_to_kvs.exists() + assert client.path_to_metadata.exists() + + # Verify metadata content + with client.path_to_metadata.open() as f: + metadata = json.load(f) + assert metadata['id'] == client.id + assert metadata['name'] == 'new_kvs' + + +async def test_open_existing_kvs(kvs_client: FileSystemKeyValueStoreClient, tmp_path: Path) -> None: + """Test that open() loads an existing key-value store correctly.""" + # Open the same key-value store again + reopened_client = await FileSystemKeyValueStoreClient.open(name=kvs_client.name, storage_dir=tmp_path) + + # Verify client properties + assert kvs_client.id == reopened_client.id + assert kvs_client.name == reopened_client.name + + # Verify clients (python) ids - should be the same object due to caching + 
assert id(kvs_client) == id(reopened_client) + + +async def test_open_with_id_raises_error(tmp_path: Path) -> None: + """Test that open() raises an error when an ID is provided.""" + with pytest.raises(ValueError, match='not supported for file system storage client'): + await FileSystemKeyValueStoreClient.open(id='some-id', storage_dir=tmp_path) + + +async def test_set_get_value_string(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test setting and getting a string value.""" + # Set a value + test_key = 'test-key' + test_value = 'Hello, world!' + await kvs_client.set_value(key=test_key, value=test_value) + + # Check if the file was created + key_path = kvs_client.path_to_kvs / test_key + key_metadata_path = kvs_client.path_to_kvs / f'{test_key}.{METADATA_FILENAME}' + assert key_path.exists() + assert key_metadata_path.exists() + + # Check file content + content = key_path.read_text(encoding='utf-8') + assert content == test_value + + # Check record metadata + with key_metadata_path.open() as f: + metadata = json.load(f) + assert metadata['key'] == test_key + assert metadata['content_type'] == 'text/plain; charset=utf-8' + assert metadata['size'] == len(test_value.encode('utf-8')) + + # Get the value + record = await kvs_client.get_value(key=test_key) + assert record is not None + assert record.key == test_key + assert record.value == test_value + assert record.content_type == 'text/plain; charset=utf-8' + assert record.size == len(test_value.encode('utf-8')) + + +async def test_set_get_value_json(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test setting and getting a JSON value.""" + # Set a value + test_key = 'test-json' + test_value = {'name': 'John', 'age': 30, 'items': [1, 2, 3]} + await kvs_client.set_value(key=test_key, value=test_value) + + # Get the value + record = await kvs_client.get_value(key=test_key) + assert record is not None + assert record.key == test_key + assert record.value == test_value + assert 'application/json' in record.content_type + + +async def test_set_get_value_bytes(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test setting and getting binary data.""" + # Set a value + test_key = 'test-binary' + test_value = b'\x00\x01\x02\x03\x04' + await kvs_client.set_value(key=test_key, value=test_value) + + # Get the value + record = await kvs_client.get_value(key=test_key) + assert record is not None + assert record.key == test_key + assert record.value == test_value + assert record.content_type == 'application/octet-stream' + assert record.size == len(test_value) + + +async def test_set_value_explicit_content_type(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test setting a value with an explicit content type.""" + test_key = 'test-explicit-content-type' + test_value = 'Hello, world!' 
+ explicit_content_type = 'text/html; charset=utf-8' + + await kvs_client.set_value(key=test_key, value=test_value, content_type=explicit_content_type) + + record = await kvs_client.get_value(key=test_key) + assert record is not None + assert record.content_type == explicit_content_type + + +async def test_get_nonexistent_value(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test getting a value that doesn't exist.""" + record = await kvs_client.get_value(key='nonexistent-key') + assert record is None + + +async def test_overwrite_value(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test overwriting an existing value.""" + test_key = 'test-overwrite' + + # Set initial value + initial_value = 'Initial value' + await kvs_client.set_value(key=test_key, value=initial_value) + + # Overwrite with new value + new_value = 'New value' + await kvs_client.set_value(key=test_key, value=new_value) + + # Verify the updated value + record = await kvs_client.get_value(key=test_key) + assert record is not None + assert record.value == new_value + + +async def test_delete_value(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test deleting a value.""" + test_key = 'test-delete' + test_value = 'Delete me' + + # Set a value + await kvs_client.set_value(key=test_key, value=test_value) + + # Verify it exists + key_path = kvs_client.path_to_kvs / test_key + metadata_path = kvs_client.path_to_kvs / f'{test_key}.{METADATA_FILENAME}' + assert key_path.exists() + assert metadata_path.exists() + + # Delete the value + await kvs_client.delete_value(key=test_key) + + # Verify files were deleted + assert not key_path.exists() + assert not metadata_path.exists() + + # Verify value is no longer retrievable + record = await kvs_client.get_value(key=test_key) + assert record is None + + +async def test_delete_nonexistent_value(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test deleting a value that doesn't exist.""" + # Should not raise an error + await kvs_client.delete_value(key='nonexistent-key') + + +async def test_iterate_keys_empty_store(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test iterating over keys in an empty store.""" + keys = [key async for key in kvs_client.iterate_keys()] + assert len(keys) == 0 + + +async def test_iterate_keys(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test iterating over keys.""" + # Add some values + await kvs_client.set_value(key='key1', value='value1') + await kvs_client.set_value(key='key2', value='value2') + await kvs_client.set_value(key='key3', value='value3') + + # Iterate over keys + keys = [key.key async for key in kvs_client.iterate_keys()] + assert len(keys) == 3 + assert sorted(keys) == ['key1', 'key2', 'key3'] + + +async def test_iterate_keys_with_limit(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test iterating over keys with a limit.""" + # Add some values + await kvs_client.set_value(key='key1', value='value1') + await kvs_client.set_value(key='key2', value='value2') + await kvs_client.set_value(key='key3', value='value3') + + # Iterate with limit + keys = [key.key async for key in kvs_client.iterate_keys(limit=2)] + assert len(keys) == 2 + + +async def test_iterate_keys_with_exclusive_start_key(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test iterating over keys with an exclusive start key.""" + # Add some values with alphabetical keys + await kvs_client.set_value(key='a-key', value='value-a') + await kvs_client.set_value(key='b-key', value='value-b') + await 
kvs_client.set_value(key='c-key', value='value-c') + await kvs_client.set_value(key='d-key', value='value-d') + + # Iterate with exclusive start key + keys = [key.key async for key in kvs_client.iterate_keys(exclusive_start_key='b-key')] + assert len(keys) == 2 + assert 'c-key' in keys + assert 'd-key' in keys + assert 'a-key' not in keys + assert 'b-key' not in keys + + +async def test_drop(tmp_path: Path) -> None: + """Test dropping a key-value store.""" + # Create a store and add a value + client = await FileSystemKeyValueStoreClient.open(name='to_drop', storage_dir=tmp_path) + await client.set_value(key='test', value='test-value') + + # Verify the store directory exists + kvs_path = client.path_to_kvs + assert kvs_path.exists() + + # Drop the store + await client.drop() + + # Verify the directory was removed + assert not kvs_path.exists() + + +async def test_metadata_updates(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that metadata is updated correctly after operations.""" + # Record initial timestamps + initial_created = kvs_client.created_at + initial_accessed = kvs_client.accessed_at + initial_modified = kvs_client.modified_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates accessed_at + await kvs_client.get_value(key='nonexistent') + + # Verify timestamps + assert kvs_client.created_at == initial_created + assert kvs_client.accessed_at > initial_accessed + assert kvs_client.modified_at == initial_modified + + accessed_after_get = kvs_client.accessed_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates modified_at + await kvs_client.set_value(key='new-key', value='new-value') + + # Verify timestamps again + assert kvs_client.created_at == initial_created + assert kvs_client.modified_at > initial_modified + assert kvs_client.accessed_at > accessed_after_get + + +async def test_get_public_url_not_supported(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that get_public_url raises NotImplementedError.""" + with pytest.raises(NotImplementedError, match='Public URLs are not supported'): + await kvs_client.get_public_url(key='any-key') + + +async def test_infer_mime_type(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test MIME type inference for different value types.""" + # Test string + assert kvs_client._infer_mime_type('text') == 'text/plain; charset=utf-8' + + # Test JSON + assert kvs_client._infer_mime_type({'key': 'value'}) == 'application/json; charset=utf-8' + assert kvs_client._infer_mime_type([1, 2, 3]) == 'application/json; charset=utf-8' + + # Test binary + assert kvs_client._infer_mime_type(b'binary data') == 'application/octet-stream' + + # Test other types + assert kvs_client._infer_mime_type(123) == 'application/octet-stream' + + +async def test_concurrent_operations(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test concurrent operations on the key-value store.""" + + # Create multiple tasks to set different values concurrently + async def set_value(key: str, value: str) -> None: + await kvs_client.set_value(key=key, value=value) + + tasks = [asyncio.create_task(set_value(f'concurrent-key-{i}', f'value-{i}')) for i in range(10)] + + # Wait for all tasks to complete + await asyncio.gather(*tasks) + + # Verify all values were set correctly + for i in range(10): + key = f'concurrent-key-{i}' + record = await kvs_client.get_value(key=key) + assert record is not None + assert 
record.value == f'value-{i}' diff --git a/tests/unit/storage_clients/_memory/test_creation_management.py b/tests/unit/storage_clients/_memory/test_creation_management.py deleted file mode 100644 index 88a5e9e283..0000000000 --- a/tests/unit/storage_clients/_memory/test_creation_management.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -import json -from pathlib import Path -from unittest.mock import AsyncMock, patch - -import pytest - -from crawlee._consts import METADATA_FILENAME -from crawlee.storage_clients._memory._creation_management import persist_metadata_if_enabled - - -async def test_persist_metadata_skips_when_disabled(tmp_path: Path) -> None: - await persist_metadata_if_enabled(data={'key': 'value'}, entity_directory=str(tmp_path), write_metadata=False) - assert not list(tmp_path.iterdir()) # The directory should be empty since write_metadata is False - - -async def test_persist_metadata_creates_files_and_directories_when_enabled(tmp_path: Path) -> None: - data = {'key': 'value'} - entity_directory = Path(tmp_path, 'new_dir') - await persist_metadata_if_enabled(data=data, entity_directory=str(entity_directory), write_metadata=True) - assert entity_directory.exists() is True # Check if directory was created - assert (entity_directory / METADATA_FILENAME).is_file() # Check if file was created - - -async def test_persist_metadata_correctly_writes_data(tmp_path: Path) -> None: - data = {'key': 'value'} - entity_directory = Path(tmp_path, 'data_dir') - await persist_metadata_if_enabled(data=data, entity_directory=str(entity_directory), write_metadata=True) - metadata_path = entity_directory / METADATA_FILENAME - with open(metadata_path) as f: # noqa: ASYNC230 - content = f.read() - assert json.loads(content) == data # Check if correct data was written - - -async def test_persist_metadata_rewrites_data_with_error(tmp_path: Path) -> None: - init_data = {'key': 'very_long_value'} - update_data = {'key': 'short_value'} - error_data = {'key': 'error'} - - entity_directory = Path(tmp_path, 'data_dir') - metadata_path = entity_directory / METADATA_FILENAME - - # write metadata with init_data - await persist_metadata_if_enabled(data=init_data, entity_directory=str(entity_directory), write_metadata=True) - - # rewrite metadata with new_data - await persist_metadata_if_enabled(data=update_data, entity_directory=str(entity_directory), write_metadata=True) - with open(metadata_path) as f: # noqa: ASYNC230 - content = f.read() - assert json.loads(content) == update_data # Check if correct data was rewritten - - # raise interrupt between opening a file and writing - module_for_patch = 'crawlee.storage_clients._memory._creation_management.json_dumps' - with patch(module_for_patch, AsyncMock(side_effect=KeyboardInterrupt())), pytest.raises(KeyboardInterrupt): - await persist_metadata_if_enabled(data=error_data, entity_directory=str(entity_directory), write_metadata=True) - with open(metadata_path) as f: # noqa: ASYNC230 - content = f.read() - assert content == '' # The file is empty after an error diff --git a/tests/unit/storage_clients/_memory/test_dataset_client.py b/tests/unit/storage_clients/_memory/test_dataset_client.py deleted file mode 100644 index 472d11a8b3..0000000000 --- a/tests/unit/storage_clients/_memory/test_dataset_client.py +++ /dev/null @@ -1,148 +0,0 @@ -from __future__ import annotations - -import asyncio -from pathlib import Path -from typing import TYPE_CHECKING - -import pytest - -if TYPE_CHECKING: - from crawlee.storage_clients import 
MemoryStorageClient - from crawlee.storage_clients._memory import DatasetClient - - -@pytest.fixture -async def dataset_client(memory_storage_client: MemoryStorageClient) -> DatasetClient: - datasets_client = memory_storage_client.datasets() - dataset_info = await datasets_client.get_or_create(name='test') - return memory_storage_client.dataset(dataset_info.id) - - -async def test_nonexistent(memory_storage_client: MemoryStorageClient) -> None: - dataset_client = memory_storage_client.dataset(id='nonexistent-id') - assert await dataset_client.get() is None - with pytest.raises(ValueError, match='Dataset with id "nonexistent-id" does not exist.'): - await dataset_client.update(name='test-update') - - with pytest.raises(ValueError, match='Dataset with id "nonexistent-id" does not exist.'): - await dataset_client.list_items() - - with pytest.raises(ValueError, match='Dataset with id "nonexistent-id" does not exist.'): - await dataset_client.push_items([{'abc': 123}]) - await dataset_client.delete() - - -async def test_not_implemented(dataset_client: DatasetClient) -> None: - with pytest.raises(NotImplementedError, match='This method is not supported in memory storage.'): - await dataset_client.stream_items() - with pytest.raises(NotImplementedError, match='This method is not supported in memory storage.'): - await dataset_client.get_items_as_bytes() - - -async def test_get(dataset_client: DatasetClient) -> None: - await asyncio.sleep(0.1) - info = await dataset_client.get() - assert info is not None - assert info.id == dataset_client.id - assert info.accessed_at != info.created_at - - -async def test_update(dataset_client: DatasetClient) -> None: - new_dataset_name = 'test-update' - await dataset_client.push_items({'abc': 123}) - - old_dataset_info = await dataset_client.get() - assert old_dataset_info is not None - old_dataset_directory = Path(dataset_client._memory_storage_client.datasets_directory, old_dataset_info.name or '') - new_dataset_directory = Path(dataset_client._memory_storage_client.datasets_directory, new_dataset_name) - assert (old_dataset_directory / '000000001.json').exists() is True - assert (new_dataset_directory / '000000001.json').exists() is False - - await asyncio.sleep(0.1) - updated_dataset_info = await dataset_client.update(name=new_dataset_name) - assert (old_dataset_directory / '000000001.json').exists() is False - assert (new_dataset_directory / '000000001.json').exists() is True - # Only modified_at and accessed_at should be different - assert old_dataset_info.created_at == updated_dataset_info.created_at - assert old_dataset_info.modified_at != updated_dataset_info.modified_at - assert old_dataset_info.accessed_at != updated_dataset_info.accessed_at - - # Should fail with the same name - with pytest.raises(ValueError, match='Dataset with name "test-update" already exists.'): - await dataset_client.update(name=new_dataset_name) - - -async def test_delete(dataset_client: DatasetClient) -> None: - await dataset_client.push_items({'abc': 123}) - dataset_info = await dataset_client.get() - assert dataset_info is not None - dataset_directory = Path(dataset_client._memory_storage_client.datasets_directory, dataset_info.name or '') - assert (dataset_directory / '000000001.json').exists() is True - await dataset_client.delete() - assert (dataset_directory / '000000001.json').exists() is False - # Does not crash when called again - await dataset_client.delete() - - -async def test_push_items(dataset_client: DatasetClient) -> None: - await 
dataset_client.push_items('{"test": "JSON from a string"}') - await dataset_client.push_items({'abc': {'def': {'ghi': '123'}}}) - await dataset_client.push_items(['{"test-json-parse": "JSON from a string"}' for _ in range(10)]) - await dataset_client.push_items([{'test-dict': i} for i in range(10)]) - - list_page = await dataset_client.list_items() - assert list_page.items[0]['test'] == 'JSON from a string' - assert list_page.items[1]['abc']['def']['ghi'] == '123' - assert list_page.items[11]['test-json-parse'] == 'JSON from a string' - assert list_page.items[21]['test-dict'] == 9 - assert list_page.count == 22 - - -async def test_list_items(dataset_client: DatasetClient) -> None: - item_count = 100 - used_offset = 10 - used_limit = 50 - await dataset_client.push_items([{'id': i} for i in range(item_count)]) - # Test without any parameters - list_default = await dataset_client.list_items() - assert list_default.count == item_count - assert list_default.offset == 0 - assert list_default.items[0]['id'] == 0 - assert list_default.desc is False - # Test offset - list_offset_10 = await dataset_client.list_items(offset=used_offset) - assert list_offset_10.count == item_count - used_offset - assert list_offset_10.offset == used_offset - assert list_offset_10.total == item_count - assert list_offset_10.items[0]['id'] == used_offset - # Test limit - list_limit_50 = await dataset_client.list_items(limit=used_limit) - assert list_limit_50.count == used_limit - assert list_limit_50.limit == used_limit - assert list_limit_50.total == item_count - # Test desc - list_desc_true = await dataset_client.list_items(desc=True) - assert list_desc_true.items[0]['id'] == 99 - assert list_desc_true.desc is True - - -async def test_iterate_items(dataset_client: DatasetClient) -> None: - item_count = 100 - await dataset_client.push_items([{'id': i} for i in range(item_count)]) - actual_items = [] - async for item in dataset_client.iterate_items(): - assert 'id' in item - actual_items.append(item) - assert len(actual_items) == item_count - assert actual_items[0]['id'] == 0 - assert actual_items[99]['id'] == 99 - - -async def test_reuse_dataset(dataset_client: DatasetClient, memory_storage_client: MemoryStorageClient) -> None: - item_count = 10 - await dataset_client.push_items([{'id': i} for i in range(item_count)]) - - memory_storage_client.datasets_handled = [] # purge datasets loaded to test create_dataset_from_directory - datasets_client = memory_storage_client.datasets() - dataset_info = await datasets_client.get_or_create(name='test') - assert dataset_info.item_count == item_count diff --git a/tests/unit/storage_clients/_memory/test_dataset_collection_client.py b/tests/unit/storage_clients/_memory/test_dataset_collection_client.py deleted file mode 100644 index d71b7e8f68..0000000000 --- a/tests/unit/storage_clients/_memory/test_dataset_collection_client.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -if TYPE_CHECKING: - from crawlee.storage_clients import MemoryStorageClient - from crawlee.storage_clients._memory import DatasetCollectionClient - - -@pytest.fixture -def datasets_client(memory_storage_client: MemoryStorageClient) -> DatasetCollectionClient: - return memory_storage_client.datasets() - - -async def test_get_or_create(datasets_client: DatasetCollectionClient) -> None: - dataset_name = 'test' - # A new dataset gets created - dataset_info = await datasets_client.get_or_create(name=dataset_name) - assert dataset_info.name == 
dataset_name - - # Another get_or_create call returns the same dataset - dataset_info_existing = await datasets_client.get_or_create(name=dataset_name) - assert dataset_info.id == dataset_info_existing.id - assert dataset_info.name == dataset_info_existing.name - assert dataset_info.created_at == dataset_info_existing.created_at - - -async def test_list(datasets_client: DatasetCollectionClient) -> None: - dataset_list_1 = await datasets_client.list() - assert dataset_list_1.count == 0 - - dataset_info = await datasets_client.get_or_create(name='dataset') - dataset_list_2 = await datasets_client.list() - - assert dataset_list_2.count == 1 - assert dataset_list_2.items[0].name == dataset_info.name - - # Test sorting behavior - newer_dataset_info = await datasets_client.get_or_create(name='newer-dataset') - dataset_list_sorting = await datasets_client.list() - assert dataset_list_sorting.count == 2 - assert dataset_list_sorting.items[0].name == dataset_info.name - assert dataset_list_sorting.items[1].name == newer_dataset_info.name diff --git a/tests/unit/storage_clients/_memory/test_key_value_store_client.py b/tests/unit/storage_clients/_memory/test_key_value_store_client.py deleted file mode 100644 index c7813b5b84..0000000000 --- a/tests/unit/storage_clients/_memory/test_key_value_store_client.py +++ /dev/null @@ -1,442 +0,0 @@ -from __future__ import annotations - -import asyncio -import base64 -import json -from datetime import datetime, timezone -from pathlib import Path -from typing import TYPE_CHECKING - -import pytest - -from crawlee._consts import METADATA_FILENAME -from crawlee._utils.crypto import crypto_random_object_id -from crawlee._utils.data_processing import maybe_parse_body -from crawlee._utils.file import json_dumps -from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecordMetadata - -if TYPE_CHECKING: - from crawlee.storage_clients import MemoryStorageClient - from crawlee.storage_clients._memory import KeyValueStoreClient - -TINY_PNG = base64.b64decode( - s='iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVQYV2NgYAAAAAMAAWgmWQ0AAAAASUVORK5CYII=', -) -TINY_BYTES = b'\x12\x34\x56\x78\x90\xab\xcd\xef' -TINY_DATA = {'a': 'b'} -TINY_TEXT = 'abcd' - - -@pytest.fixture -async def key_value_store_client(memory_storage_client: MemoryStorageClient) -> KeyValueStoreClient: - key_value_stores_client = memory_storage_client.key_value_stores() - kvs_info = await key_value_stores_client.get_or_create(name='test') - return memory_storage_client.key_value_store(kvs_info.id) - - -async def test_nonexistent(memory_storage_client: MemoryStorageClient) -> None: - kvs_client = memory_storage_client.key_value_store(id='nonexistent-id') - assert await kvs_client.get() is None - - with pytest.raises(ValueError, match='Key-value store with id "nonexistent-id" does not exist.'): - await kvs_client.update(name='test-update') - - with pytest.raises(ValueError, match='Key-value store with id "nonexistent-id" does not exist.'): - await kvs_client.list_keys() - - with pytest.raises(ValueError, match='Key-value store with id "nonexistent-id" does not exist.'): - await kvs_client.set_record('test', {'abc': 123}) - - with pytest.raises(ValueError, match='Key-value store with id "nonexistent-id" does not exist.'): - await kvs_client.get_record('test') - - with pytest.raises(ValueError, match='Key-value store with id "nonexistent-id" does not exist.'): - await kvs_client.get_record_as_bytes('test') - - with pytest.raises(ValueError, match='Key-value store with id 
"nonexistent-id" does not exist.'): - await kvs_client.delete_record('test') - - await kvs_client.delete() - - -async def test_not_implemented(key_value_store_client: KeyValueStoreClient) -> None: - with pytest.raises(NotImplementedError, match='This method is not supported in memory storage.'): - await key_value_store_client.stream_record('test') - - -async def test_get(key_value_store_client: KeyValueStoreClient) -> None: - await asyncio.sleep(0.1) - info = await key_value_store_client.get() - assert info is not None - assert info.id == key_value_store_client.id - assert info.accessed_at != info.created_at - - -async def test_update(key_value_store_client: KeyValueStoreClient) -> None: - new_kvs_name = 'test-update' - await key_value_store_client.set_record('test', {'abc': 123}) - old_kvs_info = await key_value_store_client.get() - assert old_kvs_info is not None - old_kvs_directory = Path( - key_value_store_client._memory_storage_client.key_value_stores_directory, old_kvs_info.name or '' - ) - new_kvs_directory = Path(key_value_store_client._memory_storage_client.key_value_stores_directory, new_kvs_name) - assert (old_kvs_directory / 'test.json').exists() is True - assert (new_kvs_directory / 'test.json').exists() is False - - await asyncio.sleep(0.1) - updated_kvs_info = await key_value_store_client.update(name=new_kvs_name) - assert (old_kvs_directory / 'test.json').exists() is False - assert (new_kvs_directory / 'test.json').exists() is True - # Only modified_at and accessed_at should be different - assert old_kvs_info.created_at == updated_kvs_info.created_at - assert old_kvs_info.modified_at != updated_kvs_info.modified_at - assert old_kvs_info.accessed_at != updated_kvs_info.accessed_at - - # Should fail with the same name - with pytest.raises(ValueError, match='Key-value store with name "test-update" already exists.'): - await key_value_store_client.update(name=new_kvs_name) - - -async def test_delete(key_value_store_client: KeyValueStoreClient) -> None: - await key_value_store_client.set_record('test', {'abc': 123}) - kvs_info = await key_value_store_client.get() - assert kvs_info is not None - kvs_directory = Path(key_value_store_client._memory_storage_client.key_value_stores_directory, kvs_info.name or '') - assert (kvs_directory / 'test.json').exists() is True - await key_value_store_client.delete() - assert (kvs_directory / 'test.json').exists() is False - # Does not crash when called again - await key_value_store_client.delete() - - -async def test_list_keys_empty(key_value_store_client: KeyValueStoreClient) -> None: - keys = await key_value_store_client.list_keys() - assert len(keys.items) == 0 - assert keys.count == 0 - assert keys.is_truncated is False - - -async def test_list_keys(key_value_store_client: KeyValueStoreClient) -> None: - record_count = 4 - used_limit = 2 - used_exclusive_start_key = 'a' - await key_value_store_client.set_record('b', 'test') - await key_value_store_client.set_record('a', 'test') - await key_value_store_client.set_record('d', 'test') - await key_value_store_client.set_record('c', 'test') - - # Default settings - keys = await key_value_store_client.list_keys() - assert keys.items[0].key == 'a' - assert keys.items[3].key == 'd' - assert keys.count == record_count - assert keys.is_truncated is False - # Test limit - keys_limit_2 = await key_value_store_client.list_keys(limit=used_limit) - assert keys_limit_2.count == record_count - assert keys_limit_2.limit == used_limit - assert keys_limit_2.items[1].key == 'b' - # Test exclusive start key - 
keys_exclusive_start = await key_value_store_client.list_keys(exclusive_start_key=used_exclusive_start_key, limit=2) - assert keys_exclusive_start.exclusive_start_key == used_exclusive_start_key - assert keys_exclusive_start.is_truncated is True - assert keys_exclusive_start.next_exclusive_start_key == 'c' - assert keys_exclusive_start.items[0].key == 'b' - assert keys_exclusive_start.items[-1].key == keys_exclusive_start.next_exclusive_start_key - - -async def test_get_and_set_record(tmp_path: Path, key_value_store_client: KeyValueStoreClient) -> None: - # Test setting dict record - dict_record_key = 'test-dict' - await key_value_store_client.set_record(dict_record_key, {'test': 123}) - dict_record_info = await key_value_store_client.get_record(dict_record_key) - assert dict_record_info is not None - assert 'application/json' in str(dict_record_info.content_type) - assert dict_record_info.value['test'] == 123 - - # Test setting str record - str_record_key = 'test-str' - await key_value_store_client.set_record(str_record_key, 'test') - str_record_info = await key_value_store_client.get_record(str_record_key) - assert str_record_info is not None - assert 'text/plain' in str(str_record_info.content_type) - assert str_record_info.value == 'test' - - # Test setting explicit json record but use str as value, i.e. json dumps is skipped - explicit_json_key = 'test-json' - await key_value_store_client.set_record(explicit_json_key, '{"test": "explicit string"}', 'application/json') - bytes_record_info = await key_value_store_client.get_record(explicit_json_key) - assert bytes_record_info is not None - assert 'application/json' in str(bytes_record_info.content_type) - assert bytes_record_info.value['test'] == 'explicit string' - - # Test using bytes - bytes_key = 'test-json' - bytes_value = b'testing bytes set_record' - await key_value_store_client.set_record(bytes_key, bytes_value, 'unknown') - bytes_record_info = await key_value_store_client.get_record(bytes_key) - assert bytes_record_info is not None - assert 'unknown' in str(bytes_record_info.content_type) - assert bytes_record_info.value == bytes_value - assert bytes_record_info.value.decode('utf-8') == bytes_value.decode('utf-8') - - # Test using file descriptor - with open(tmp_path / 'test.json', 'w+', encoding='utf-8') as f: # noqa: ASYNC230 - f.write('Test') - with pytest.raises(NotImplementedError, match='File-like values are not supported in local memory storage'): - await key_value_store_client.set_record('file', f) - - -async def test_get_record_as_bytes(key_value_store_client: KeyValueStoreClient) -> None: - record_key = 'test' - record_value = 'testing' - await key_value_store_client.set_record(record_key, record_value) - record_info = await key_value_store_client.get_record_as_bytes(record_key) - assert record_info is not None - assert record_info.value == record_value.encode('utf-8') - - -async def test_delete_record(key_value_store_client: KeyValueStoreClient) -> None: - record_key = 'test' - await key_value_store_client.set_record(record_key, 'test') - await key_value_store_client.delete_record(record_key) - # Does not crash when called again - await key_value_store_client.delete_record(record_key) - - -@pytest.mark.parametrize( - ('input_data', 'expected_output'), - [ - ( - {'key': 'image', 'value': TINY_PNG, 'contentType': None}, - {'filename': 'image', 'key': 'image', 'contentType': 'application/octet-stream'}, - ), - ( - {'key': 'image', 'value': TINY_PNG, 'contentType': 'image/png'}, - {'filename': 'image.png', 'key': 
'image', 'contentType': 'image/png'}, - ), - ( - {'key': 'image.png', 'value': TINY_PNG, 'contentType': None}, - {'filename': 'image.png', 'key': 'image.png', 'contentType': 'application/octet-stream'}, - ), - ( - {'key': 'image.png', 'value': TINY_PNG, 'contentType': 'image/png'}, - {'filename': 'image.png', 'key': 'image.png', 'contentType': 'image/png'}, - ), - ( - {'key': 'data', 'value': TINY_DATA, 'contentType': None}, - {'filename': 'data.json', 'key': 'data', 'contentType': 'application/json'}, - ), - ( - {'key': 'data', 'value': TINY_DATA, 'contentType': 'application/json'}, - {'filename': 'data.json', 'key': 'data', 'contentType': 'application/json'}, - ), - ( - {'key': 'data.json', 'value': TINY_DATA, 'contentType': None}, - {'filename': 'data.json', 'key': 'data.json', 'contentType': 'application/json'}, - ), - ( - {'key': 'data.json', 'value': TINY_DATA, 'contentType': 'application/json'}, - {'filename': 'data.json', 'key': 'data.json', 'contentType': 'application/json'}, - ), - ( - {'key': 'text', 'value': TINY_TEXT, 'contentType': None}, - {'filename': 'text.txt', 'key': 'text', 'contentType': 'text/plain'}, - ), - ( - {'key': 'text', 'value': TINY_TEXT, 'contentType': 'text/plain'}, - {'filename': 'text.txt', 'key': 'text', 'contentType': 'text/plain'}, - ), - ( - {'key': 'text.txt', 'value': TINY_TEXT, 'contentType': None}, - {'filename': 'text.txt', 'key': 'text.txt', 'contentType': 'text/plain'}, - ), - ( - {'key': 'text.txt', 'value': TINY_TEXT, 'contentType': 'text/plain'}, - {'filename': 'text.txt', 'key': 'text.txt', 'contentType': 'text/plain'}, - ), - ], -) -async def test_writes_correct_metadata( - memory_storage_client: MemoryStorageClient, - input_data: dict, - expected_output: dict, -) -> None: - key_value_store_name = crypto_random_object_id() - - # Get KVS client - kvs_info = await memory_storage_client.key_value_stores().get_or_create(name=key_value_store_name) - kvs_client = memory_storage_client.key_value_store(kvs_info.id) - - # Write the test input item to the store - await kvs_client.set_record( - key=input_data['key'], - value=input_data['value'], - content_type=input_data['contentType'], - ) - - # Check that everything was written correctly, both the data and metadata - storage_path = Path(memory_storage_client.key_value_stores_directory, key_value_store_name) - item_path = Path(storage_path, expected_output['filename']) - item_metadata_path = storage_path / f'{expected_output["filename"]}.__metadata__.json' - - assert item_path.exists() - assert item_metadata_path.exists() - - # Test the actual value of the item - with open(item_path, 'rb') as item_file: # noqa: ASYNC230 - actual_value = maybe_parse_body(item_file.read(), expected_output['contentType']) - assert actual_value == input_data['value'] - - # Test the actual metadata of the item - with open(item_metadata_path, encoding='utf-8') as metadata_file: # noqa: ASYNC230 - json_content = json.load(metadata_file) - metadata = KeyValueStoreRecordMetadata(**json_content) - assert metadata.key == expected_output['key'] - assert expected_output['contentType'] in metadata.content_type - - -@pytest.mark.parametrize( - ('input_data', 'expected_output'), - [ - ( - {'filename': 'image', 'value': TINY_PNG, 'metadata': None}, - {'key': 'image', 'filename': 'image', 'contentType': 'application/octet-stream'}, - ), - ( - {'filename': 'image.png', 'value': TINY_PNG, 'metadata': None}, - {'key': 'image', 'filename': 'image.png', 'contentType': 'image/png'}, - ), - ( - { - 'filename': 'image', - 'value': TINY_PNG, 
- 'metadata': {'key': 'image', 'contentType': 'application/octet-stream'}, - }, - {'key': 'image', 'contentType': 'application/octet-stream'}, - ), - ( - {'filename': 'image', 'value': TINY_PNG, 'metadata': {'key': 'image', 'contentType': 'image/png'}}, - {'key': 'image', 'filename': 'image', 'contentType': 'image/png'}, - ), - ( - { - 'filename': 'image.png', - 'value': TINY_PNG, - 'metadata': {'key': 'image.png', 'contentType': 'application/octet-stream'}, - }, - {'key': 'image.png', 'contentType': 'application/octet-stream'}, - ), - ( - {'filename': 'image.png', 'value': TINY_PNG, 'metadata': {'key': 'image.png', 'contentType': 'image/png'}}, - {'key': 'image.png', 'contentType': 'image/png'}, - ), - ( - {'filename': 'image.png', 'value': TINY_PNG, 'metadata': {'key': 'image', 'contentType': 'image/png'}}, - {'key': 'image', 'contentType': 'image/png'}, - ), - ( - {'filename': 'input', 'value': TINY_BYTES, 'metadata': None}, - {'key': 'input', 'contentType': 'application/octet-stream'}, - ), - ( - {'filename': 'input.json', 'value': TINY_DATA, 'metadata': None}, - {'key': 'input', 'contentType': 'application/json'}, - ), - ( - {'filename': 'input.txt', 'value': TINY_TEXT, 'metadata': None}, - {'key': 'input', 'contentType': 'text/plain'}, - ), - ( - {'filename': 'input.bin', 'value': TINY_BYTES, 'metadata': None}, - {'key': 'input', 'contentType': 'application/octet-stream'}, - ), - ( - { - 'filename': 'input', - 'value': TINY_BYTES, - 'metadata': {'key': 'input', 'contentType': 'application/octet-stream'}, - }, - {'key': 'input', 'contentType': 'application/octet-stream'}, - ), - ( - { - 'filename': 'input.json', - 'value': TINY_DATA, - 'metadata': {'key': 'input', 'contentType': 'application/json'}, - }, - {'key': 'input', 'contentType': 'application/json'}, - ), - ( - {'filename': 'input.txt', 'value': TINY_TEXT, 'metadata': {'key': 'input', 'contentType': 'text/plain'}}, - {'key': 'input', 'contentType': 'text/plain'}, - ), - ( - { - 'filename': 'input.bin', - 'value': TINY_BYTES, - 'metadata': {'key': 'input', 'contentType': 'application/octet-stream'}, - }, - {'key': 'input', 'contentType': 'application/octet-stream'}, - ), - ], -) -async def test_reads_correct_metadata( - memory_storage_client: MemoryStorageClient, - input_data: dict, - expected_output: dict, -) -> None: - key_value_store_name = crypto_random_object_id() - - # Ensure the directory for the store exists - storage_path = Path(memory_storage_client.key_value_stores_directory, key_value_store_name) - storage_path.mkdir(exist_ok=True, parents=True) - - store_metadata = KeyValueStoreMetadata( - id=crypto_random_object_id(), - name='', - accessed_at=datetime.now(timezone.utc), - created_at=datetime.now(timezone.utc), - modified_at=datetime.now(timezone.utc), - ) - - # Write the store metadata to disk - storage_metadata_path = storage_path / METADATA_FILENAME - with open(storage_metadata_path, mode='wb') as f: # noqa: ASYNC230 - f.write(store_metadata.model_dump_json().encode('utf-8')) - - # Write the test input item to the disk - item_path = storage_path / input_data['filename'] - with open(item_path, 'wb') as item_file: # noqa: ASYNC230 - if isinstance(input_data['value'], bytes): - item_file.write(input_data['value']) - elif isinstance(input_data['value'], str): - item_file.write(input_data['value'].encode('utf-8')) - else: - s = await json_dumps(input_data['value']) - item_file.write(s.encode('utf-8')) - - # Optionally write the metadata to disk if there is some - if input_data['metadata'] is not None: - 
storage_metadata_path = storage_path / f'{input_data["filename"]}.__metadata__.json' - with open(storage_metadata_path, 'w', encoding='utf-8') as metadata_file: # noqa: ASYNC230 - s = await json_dumps( - { - 'key': input_data['metadata']['key'], - 'contentType': input_data['metadata']['contentType'], - } - ) - metadata_file.write(s) - - # Create the key-value store client to load the items from disk - store_details = await memory_storage_client.key_value_stores().get_or_create(name=key_value_store_name) - key_value_store_client = memory_storage_client.key_value_store(store_details.id) - - # Read the item from the store and check if it is as expected - actual_record = await key_value_store_client.get_record(expected_output['key']) - assert actual_record is not None - - assert actual_record.key == expected_output['key'] - assert actual_record.content_type == expected_output['contentType'] - assert actual_record.value == input_data['value'] diff --git a/tests/unit/storage_clients/_memory/test_key_value_store_collection_client.py b/tests/unit/storage_clients/_memory/test_key_value_store_collection_client.py deleted file mode 100644 index 41b289eb06..0000000000 --- a/tests/unit/storage_clients/_memory/test_key_value_store_collection_client.py +++ /dev/null @@ -1,42 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -if TYPE_CHECKING: - from crawlee.storage_clients import MemoryStorageClient - from crawlee.storage_clients._memory import KeyValueStoreCollectionClient - - -@pytest.fixture -def key_value_stores_client(memory_storage_client: MemoryStorageClient) -> KeyValueStoreCollectionClient: - return memory_storage_client.key_value_stores() - - -async def test_get_or_create(key_value_stores_client: KeyValueStoreCollectionClient) -> None: - kvs_name = 'test' - # A new kvs gets created - kvs_info = await key_value_stores_client.get_or_create(name=kvs_name) - assert kvs_info.name == kvs_name - - # Another get_or_create call returns the same kvs - kvs_info_existing = await key_value_stores_client.get_or_create(name=kvs_name) - assert kvs_info.id == kvs_info_existing.id - assert kvs_info.name == kvs_info_existing.name - assert kvs_info.created_at == kvs_info_existing.created_at - - -async def test_list(key_value_stores_client: KeyValueStoreCollectionClient) -> None: - assert (await key_value_stores_client.list()).count == 0 - kvs_info = await key_value_stores_client.get_or_create(name='kvs') - kvs_list = await key_value_stores_client.list() - assert kvs_list.count == 1 - assert kvs_list.items[0].name == kvs_info.name - - # Test sorting behavior - newer_kvs_info = await key_value_stores_client.get_or_create(name='newer-kvs') - kvs_list_sorting = await key_value_stores_client.list() - assert kvs_list_sorting.count == 2 - assert kvs_list_sorting.items[0].name == kvs_info.name - assert kvs_list_sorting.items[1].name == newer_kvs_info.name diff --git a/tests/unit/storage_clients/_memory/test_memory_storage_e2e.py b/tests/unit/storage_clients/_memory/test_memory_storage_e2e.py deleted file mode 100644 index c79fa66792..0000000000 --- a/tests/unit/storage_clients/_memory/test_memory_storage_e2e.py +++ /dev/null @@ -1,130 +0,0 @@ -from __future__ import annotations - -from datetime import datetime, timezone -from typing import Callable - -import pytest - -from crawlee import Request, service_locator -from crawlee.storages._key_value_store import KeyValueStore -from crawlee.storages._request_queue import RequestQueue - - 
-@pytest.mark.parametrize('purge_on_start', [True, False]) -async def test_actor_memory_storage_client_key_value_store_e2e( - monkeypatch: pytest.MonkeyPatch, - purge_on_start: bool, # noqa: FBT001 - prepare_test_env: Callable[[], None], -) -> None: - """This test simulates two clean runs using memory storage. - The second run attempts to access data created by the first one. - We run 2 configurations with different `purge_on_start`.""" - # Configure purging env var - monkeypatch.setenv('CRAWLEE_PURGE_ON_START', f'{int(purge_on_start)}') - # Store old storage client so we have the object reference for comparison - old_client = service_locator.get_storage_client() - - old_default_kvs = await KeyValueStore.open() - old_non_default_kvs = await KeyValueStore.open(name='non-default') - # Create data in default and non-default key-value store - await old_default_kvs.set_value('test', 'default value') - await old_non_default_kvs.set_value('test', 'non-default value') - - # We simulate another clean run, we expect the memory storage to read from the local data directory - # Default storages are purged based on purge_on_start parameter. - prepare_test_env() - - # Check if we're using a different memory storage instance - assert old_client is not service_locator.get_storage_client() - default_kvs = await KeyValueStore.open() - assert default_kvs is not old_default_kvs - non_default_kvs = await KeyValueStore.open(name='non-default') - assert non_default_kvs is not old_non_default_kvs - default_value = await default_kvs.get_value('test') - - if purge_on_start: - assert default_value is None - else: - assert default_value == 'default value' - - assert await non_default_kvs.get_value('test') == 'non-default value' - - -@pytest.mark.parametrize('purge_on_start', [True, False]) -async def test_actor_memory_storage_client_request_queue_e2e( - monkeypatch: pytest.MonkeyPatch, - purge_on_start: bool, # noqa: FBT001 - prepare_test_env: Callable[[], None], -) -> None: - """This test simulates two clean runs using memory storage. - The second run attempts to access data created by the first one. - We run 2 configurations with different `purge_on_start`.""" - # Configure purging env var - monkeypatch.setenv('CRAWLEE_PURGE_ON_START', f'{int(purge_on_start)}') - - # Add some requests to the default queue - default_queue = await RequestQueue.open() - for i in range(6): - # [0, 3] <- nothing special - # [1, 4] <- forefront=True - # [2, 5] <- handled=True - request_url = f'http://example.com/{i}' - forefront = i % 3 == 1 - was_handled = i % 3 == 2 - await default_queue.add_request( - Request.from_url( - unique_key=str(i), - url=request_url, - handled_at=datetime.now(timezone.utc) if was_handled else None, - payload=b'test', - ), - forefront=forefront, - ) - - # We simulate another clean run, we expect the memory storage to read from the local data directory - # Default storages are purged based on purge_on_start parameter. 
- prepare_test_env() - - # Add some more requests to the default queue - default_queue = await RequestQueue.open() - for i in range(6, 12): - # [6, 9] <- nothing special - # [7, 10] <- forefront=True - # [8, 11] <- handled=True - request_url = f'http://example.com/{i}' - forefront = i % 3 == 1 - was_handled = i % 3 == 2 - await default_queue.add_request( - Request.from_url( - unique_key=str(i), - url=request_url, - handled_at=datetime.now(timezone.utc) if was_handled else None, - payload=b'test', - ), - forefront=forefront, - ) - - queue_info = await default_queue.get_info() - assert queue_info is not None - - # If the queue was purged between the runs, only the requests from the second run should be present, - # in the right order - if purge_on_start: - assert queue_info.total_request_count == 6 - assert queue_info.handled_request_count == 2 - - expected_pending_request_order = [10, 7, 6, 9] - # If the queue was NOT purged between the runs, all the requests should be in the queue in the right order - else: - assert queue_info.total_request_count == 12 - assert queue_info.handled_request_count == 4 - - expected_pending_request_order = [10, 7, 4, 1, 0, 3, 6, 9] - - actual_requests = list[Request]() - while req := await default_queue.fetch_next_request(): - actual_requests.append(req) - - assert [int(req.unique_key) for req in actual_requests] == expected_pending_request_order - assert [req.url for req in actual_requests] == [f'http://example.com/{req.unique_key}' for req in actual_requests] - assert [req.payload for req in actual_requests] == [b'test' for _ in actual_requests] diff --git a/tests/unit/storage_clients/_memory/test_request_queue_client.py b/tests/unit/storage_clients/_memory/test_request_queue_client.py deleted file mode 100644 index feffacbbd8..0000000000 --- a/tests/unit/storage_clients/_memory/test_request_queue_client.py +++ /dev/null @@ -1,249 +0,0 @@ -from __future__ import annotations - -import asyncio -from datetime import datetime, timezone -from pathlib import Path -from typing import TYPE_CHECKING - -import pytest - -from crawlee import Request -from crawlee._request import RequestState - -if TYPE_CHECKING: - from crawlee.storage_clients import MemoryStorageClient - from crawlee.storage_clients._memory import RequestQueueClient - - -@pytest.fixture -async def request_queue_client(memory_storage_client: MemoryStorageClient) -> RequestQueueClient: - request_queues_client = memory_storage_client.request_queues() - rq_info = await request_queues_client.get_or_create(name='test') - return memory_storage_client.request_queue(rq_info.id) - - -async def test_nonexistent(memory_storage_client: MemoryStorageClient) -> None: - request_queue_client = memory_storage_client.request_queue(id='nonexistent-id') - assert await request_queue_client.get() is None - with pytest.raises(ValueError, match='Request queue with id "nonexistent-id" does not exist.'): - await request_queue_client.update(name='test-update') - await request_queue_client.delete() - - -async def test_get(request_queue_client: RequestQueueClient) -> None: - await asyncio.sleep(0.1) - info = await request_queue_client.get() - assert info is not None - assert info.id == request_queue_client.id - assert info.accessed_at != info.created_at - - -async def test_update(request_queue_client: RequestQueueClient) -> None: - new_rq_name = 'test-update' - request = Request.from_url('https://apify.com') - await request_queue_client.add_request(request) - old_rq_info = await request_queue_client.get() - assert old_rq_info is not 
None - assert old_rq_info.name is not None - old_rq_directory = Path( - request_queue_client._memory_storage_client.request_queues_directory, - old_rq_info.name, - ) - new_rq_directory = Path(request_queue_client._memory_storage_client.request_queues_directory, new_rq_name) - assert (old_rq_directory / 'fvwscO2UJLdr10B.json').exists() is True - assert (new_rq_directory / 'fvwscO2UJLdr10B.json').exists() is False - - await asyncio.sleep(0.1) - updated_rq_info = await request_queue_client.update(name=new_rq_name) - assert (old_rq_directory / 'fvwscO2UJLdr10B.json').exists() is False - assert (new_rq_directory / 'fvwscO2UJLdr10B.json').exists() is True - # Only modified_at and accessed_at should be different - assert old_rq_info.created_at == updated_rq_info.created_at - assert old_rq_info.modified_at != updated_rq_info.modified_at - assert old_rq_info.accessed_at != updated_rq_info.accessed_at - - # Should fail with the same name - with pytest.raises(ValueError, match='Request queue with name "test-update" already exists'): - await request_queue_client.update(name=new_rq_name) - - -async def test_delete(request_queue_client: RequestQueueClient) -> None: - await request_queue_client.add_request(Request.from_url('https://apify.com')) - rq_info = await request_queue_client.get() - assert rq_info is not None - - rq_directory = Path(request_queue_client._memory_storage_client.request_queues_directory, str(rq_info.name)) - assert (rq_directory / 'fvwscO2UJLdr10B.json').exists() is True - - await request_queue_client.delete() - assert (rq_directory / 'fvwscO2UJLdr10B.json').exists() is False - - # Does not crash when called again - await request_queue_client.delete() - - -async def test_list_head(request_queue_client: RequestQueueClient) -> None: - await request_queue_client.add_request(Request.from_url('https://apify.com')) - await request_queue_client.add_request(Request.from_url('https://example.com')) - list_head = await request_queue_client.list_head() - assert len(list_head.items) == 2 - - for item in list_head.items: - assert item.id is not None - - -async def test_request_state_serialization(request_queue_client: RequestQueueClient) -> None: - request = Request.from_url('https://crawlee.dev', payload=b'test') - request.state = RequestState.UNPROCESSED - - await request_queue_client.add_request(request) - - result = await request_queue_client.list_head() - assert len(result.items) == 1 - assert result.items[0] == request - - got_request = await request_queue_client.get_request(request.id) - - assert request == got_request - - -async def test_add_record(request_queue_client: RequestQueueClient) -> None: - processed_request_forefront = await request_queue_client.add_request( - Request.from_url('https://apify.com'), - forefront=True, - ) - processed_request_not_forefront = await request_queue_client.add_request( - Request.from_url('https://example.com'), - forefront=False, - ) - - assert processed_request_forefront.id is not None - assert processed_request_not_forefront.id is not None - assert processed_request_forefront.was_already_handled is False - assert processed_request_not_forefront.was_already_handled is False - - rq_info = await request_queue_client.get() - assert rq_info is not None - assert rq_info.pending_request_count == rq_info.total_request_count == 2 - assert rq_info.handled_request_count == 0 - - -async def test_get_record(request_queue_client: RequestQueueClient) -> None: - request_url = 'https://apify.com' - processed_request = await 
request_queue_client.add_request(Request.from_url(request_url)) - - request = await request_queue_client.get_request(processed_request.id) - assert request is not None - assert request.url == request_url - - # Non-existent id - assert (await request_queue_client.get_request('non-existent id')) is None - - -async def test_update_record(request_queue_client: RequestQueueClient) -> None: - processed_request = await request_queue_client.add_request(Request.from_url('https://apify.com')) - request = await request_queue_client.get_request(processed_request.id) - assert request is not None - - rq_info_before_update = await request_queue_client.get() - assert rq_info_before_update is not None - assert rq_info_before_update.pending_request_count == 1 - assert rq_info_before_update.handled_request_count == 0 - - request.handled_at = datetime.now(timezone.utc) - request_update_info = await request_queue_client.update_request(request) - - assert request_update_info.was_already_handled is False - - rq_info_after_update = await request_queue_client.get() - assert rq_info_after_update is not None - assert rq_info_after_update.pending_request_count == 0 - assert rq_info_after_update.handled_request_count == 1 - - -async def test_delete_record(request_queue_client: RequestQueueClient) -> None: - processed_request_pending = await request_queue_client.add_request( - Request.from_url( - url='https://apify.com', - unique_key='pending', - ), - ) - - processed_request_handled = await request_queue_client.add_request( - Request.from_url( - url='https://apify.com', - unique_key='handled', - handled_at=datetime.now(timezone.utc), - ), - ) - - rq_info_before_delete = await request_queue_client.get() - assert rq_info_before_delete is not None - assert rq_info_before_delete.pending_request_count == 1 - - await request_queue_client.delete_request(processed_request_pending.id) - rq_info_after_first_delete = await request_queue_client.get() - assert rq_info_after_first_delete is not None - assert rq_info_after_first_delete.pending_request_count == 0 - assert rq_info_after_first_delete.handled_request_count == 1 - - await request_queue_client.delete_request(processed_request_handled.id) - rq_info_after_second_delete = await request_queue_client.get() - assert rq_info_after_second_delete is not None - assert rq_info_after_second_delete.pending_request_count == 0 - assert rq_info_after_second_delete.handled_request_count == 0 - - # Does not crash when called again - await request_queue_client.delete_request(processed_request_pending.id) - - -async def test_forefront(request_queue_client: RequestQueueClient) -> None: - # this should create a queue with requests in this order: - # Handled: - # 2, 5, 8 - # Not handled: - # 7, 4, 1, 0, 3, 6 - for i in range(9): - request_url = f'http://example.com/{i}' - forefront = i % 3 == 1 - was_handled = i % 3 == 2 - await request_queue_client.add_request( - Request.from_url( - url=request_url, - unique_key=str(i), - handled_at=datetime.now(timezone.utc) if was_handled else None, - ), - forefront=forefront, - ) - - # Check that the queue head (unhandled items) is in the right order - queue_head = await request_queue_client.list_head() - req_unique_keys = [req.unique_key for req in queue_head.items] - assert req_unique_keys == ['7', '4', '1', '0', '3', '6'] - - # Mark request #1 as handled - await request_queue_client.update_request( - Request.from_url( - url='http://example.com/1', - unique_key='1', - handled_at=datetime.now(timezone.utc), - ), - ) - # Move request #3 to forefront - await 
request_queue_client.update_request( - Request.from_url(url='http://example.com/3', unique_key='3'), - forefront=True, - ) - - # Check that the queue head (unhandled items) is in the right order after the updates - queue_head = await request_queue_client.list_head() - req_unique_keys = [req.unique_key for req in queue_head.items] - assert req_unique_keys == ['3', '7', '4', '0', '6'] - - -async def test_add_duplicate_record(request_queue_client: RequestQueueClient) -> None: - processed_request = await request_queue_client.add_request(Request.from_url('https://apify.com')) - processed_request_duplicate = await request_queue_client.add_request(Request.from_url('https://apify.com')) - - assert processed_request.id == processed_request_duplicate.id - assert processed_request_duplicate.was_already_present is True diff --git a/tests/unit/storage_clients/_memory/test_request_queue_collection_client.py b/tests/unit/storage_clients/_memory/test_request_queue_collection_client.py deleted file mode 100644 index fa10889f83..0000000000 --- a/tests/unit/storage_clients/_memory/test_request_queue_collection_client.py +++ /dev/null @@ -1,42 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -if TYPE_CHECKING: - from crawlee.storage_clients import MemoryStorageClient - from crawlee.storage_clients._memory import RequestQueueCollectionClient - - -@pytest.fixture -def request_queues_client(memory_storage_client: MemoryStorageClient) -> RequestQueueCollectionClient: - return memory_storage_client.request_queues() - - -async def test_get_or_create(request_queues_client: RequestQueueCollectionClient) -> None: - rq_name = 'test' - # A new request queue gets created - rq_info = await request_queues_client.get_or_create(name=rq_name) - assert rq_info.name == rq_name - - # Another get_or_create call returns the same request queue - rq_existing = await request_queues_client.get_or_create(name=rq_name) - assert rq_info.id == rq_existing.id - assert rq_info.name == rq_existing.name - assert rq_info.created_at == rq_existing.created_at - - -async def test_list(request_queues_client: RequestQueueCollectionClient) -> None: - assert (await request_queues_client.list()).count == 0 - rq_info = await request_queues_client.get_or_create(name='dataset') - rq_list = await request_queues_client.list() - assert rq_list.count == 1 - assert rq_list.items[0].name == rq_info.name - - # Test sorting behavior - newer_rq_info = await request_queues_client.get_or_create(name='newer-dataset') - rq_list_sorting = await request_queues_client.list() - assert rq_list_sorting.count == 2 - assert rq_list_sorting.items[0].name == rq_info.name - assert rq_list_sorting.items[1].name == newer_rq_info.name diff --git a/tests/unit/storages/test_key_value_store.py b/tests/unit/storages/test_key_value_store.py index 955d483546..dc82c412c2 100644 --- a/tests/unit/storages/test_key_value_store.py +++ b/tests/unit/storages/test_key_value_store.py @@ -134,7 +134,7 @@ async def test_get_public_url(key_value_store: KeyValueStore) -> None: url = urlparse(public_url) path = url.netloc if url.netloc else url.path - with open(path) as f: # noqa: ASYNC230 + with open(path) as f: content = await asyncio.to_thread(f.read) assert content == 'static' From 8df87f9caa58db368783f6ff8b6aa4ae02052132 Mon Sep 17 00:00:00 2001 From: Vlada Dusek Date: Sat, 12 Apr 2025 14:19:50 +0200 Subject: [PATCH 08/22] Memory storage clients and their tests --- src/crawlee/_utils/file.py | 25 ++ src/crawlee/crawlers/_basic/_basic_crawler.py | 
4 +- .../storage_clients/_base/_dataset_client.py | 69 ++--- .../_base/_key_value_store_client.py | 65 ++--- .../_file_system/_dataset_client.py | 93 +++--- .../_file_system/_key_value_store_client.py | 74 ++--- .../_memory/_dataset_client.py | 156 +++++++--- .../_memory/_key_value_store_client.py | 202 ++++++++++++- src/crawlee/storages/_dataset.py | 10 +- .../test_adaptive_playwright_crawler.py | 2 +- ...et_client.py => test_fs_dataset_client.py} | 24 +- ..._store_client.py => test_fs_kvs_client.py} | 56 ++-- .../_memory/test_memory_dataset_client.py | 276 ++++++++++++++++++ .../_memory/test_memory_kvs_client.py | 237 +++++++++++++++ tests/unit/storages/test_dataset.py | 2 +- 15 files changed, 1035 insertions(+), 260 deletions(-) rename tests/unit/storage_clients/_file_system/{test_dataset_client.py => test_fs_dataset_client.py} (91%) rename tests/unit/storage_clients/_file_system/{test_key_value_store_client.py => test_fs_kvs_client.py} (85%) create mode 100644 tests/unit/storage_clients/_memory/test_memory_dataset_client.py create mode 100644 tests/unit/storage_clients/_memory/test_memory_kvs_client.py diff --git a/src/crawlee/_utils/file.py b/src/crawlee/_utils/file.py index c74bdbb771..6a2100dd87 100644 --- a/src/crawlee/_utils/file.py +++ b/src/crawlee/_utils/file.py @@ -103,6 +103,31 @@ async def json_dumps(obj: Any) -> str: return await asyncio.to_thread(json.dumps, obj, ensure_ascii=False, indent=2, default=str) +def infer_mime_type(value: Any) -> str: + """Infer the MIME content type from the value. + + Args: + value: The value to infer the content type from. + + Returns: + The inferred MIME content type. + """ + # If the value is bytes (or bytearray), return binary content type. + if isinstance(value, (bytes, bytearray)): + return 'application/octet-stream' + + # If the value is a dict or list, assume JSON. + if isinstance(value, (dict, list)): + return 'application/json; charset=utf-8' + + # If the value is a string, assume plain text. + if isinstance(value, str): + return 'text/plain; charset=utf-8' + + # Default fallback. + return 'application/octet-stream' + + async def export_json_to_stream( iterator: AsyncIterator[dict], dst: TextIO, diff --git a/src/crawlee/crawlers/_basic/_basic_crawler.py b/src/crawlee/crawlers/_basic/_basic_crawler.py index bb79761c92..f87234abad 100644 --- a/src/crawlee/crawlers/_basic/_basic_crawler.py +++ b/src/crawlee/crawlers/_basic/_basic_crawler.py @@ -714,9 +714,9 @@ async def export_data( dst = path.open('w', newline='') if path.suffix == '.csv': - await export_csv_to_stream(dataset.iterate(), dst) + await export_csv_to_stream(dataset.iterate_items(), dst) elif path.suffix == '.json': - await export_json_to_stream(dataset.iterate(), dst) + await export_json_to_stream(dataset.iterate_items(), dst) else: raise ValueError(f'Unsupported file extension: {path.suffix}') diff --git a/src/crawlee/storage_clients/_base/_dataset_client.py b/src/crawlee/storage_clients/_base/_dataset_client.py index b9b6767310..f9086f4d9c 100644 --- a/src/crawlee/storage_clients/_base/_dataset_client.py +++ b/src/crawlee/storage_clients/_base/_dataset_client.py @@ -14,81 +14,74 @@ from crawlee.storage_clients.models import DatasetItemsListPage -# Properties: -# - id -# - name -# - created_at -# - accessed_at -# - modified_at -# - item_count - -# Methods: -# - open -# - drop -# - push_data -# - get_data -# - iterate - - @docs_group('Abstract classes') class DatasetClient(ABC): - """An abstract class for dataset resource clients. 
+    """An abstract class for dataset storage clients.
+
+    Dataset clients provide an interface for accessing and manipulating dataset storage. They handle
+    operations like adding and getting dataset items across different storage backends.
 
-    These clients are specific to the type of resource they manage and operate under a designated storage
-    client, like a memory storage client.
+    Storage clients are specific to the type of storage they manage (`Dataset`, `KeyValueStore`,
+    `RequestQueue`), and can operate with various storage systems including memory, file system,
+    databases, and cloud storage solutions.
+
+    This abstract class defines the interface that all specific dataset clients must implement.
     """
 
     @property
     @abstractmethod
     def id(self) -> str:
-        """The ID of the dataset."""
+        """The ID of the dataset, a unique identifier, typically a UUID or similar value."""
 
     @property
     @abstractmethod
     def name(self) -> str | None:
-        """The name of the dataset."""
+        """The optional human-readable name of the dataset."""
 
     @property
     @abstractmethod
     def created_at(self) -> datetime:
-        """The time at which the dataset was created."""
+        """Timestamp when the dataset was first created, remains unchanged."""
 
     @property
     @abstractmethod
     def accessed_at(self) -> datetime:
-        """The time at which the dataset was last accessed."""
+        """Timestamp of last access to the dataset, updated on read or write operations."""
 
     @property
     @abstractmethod
     def modified_at(self) -> datetime:
-        """The time at which the dataset was last modified."""
+        """Timestamp of last modification of the dataset, updated when new data are added."""
 
     @property
     @abstractmethod
     def item_count(self) -> int:
-        """The number of items in the dataset."""
+        """Total count of data items stored in the dataset."""
 
     @classmethod
     @abstractmethod
     async def open(
         cls,
         *,
-        id: str | None,
-        name: str | None,
-        storage_dir: Path,
+        id: str | None = None,
+        name: str | None = None,
+        storage_dir: Path | None = None,
     ) -> DatasetClient:
         """Open existing or create a new dataset client.
 
-        If a dataset with the given name already exists, the appropriate dataset client is returned.
+        If a dataset with the given name or ID already exists, the appropriate dataset client is returned.
         Otherwise, a new dataset is created and client for it is returned.
 
+        The backend method for the `Dataset.open` call.
+
         Args:
-            id: The ID of the dataset.
-            name: The name of the dataset.
-            storage_dir: The path to the storage directory. If the client persists data, it should use this directory.
+            id: The ID of the dataset. If not provided, an ID may be generated.
+            name: The name of the dataset. If not provided, a default name may be used.
+            storage_dir: The path to the storage directory. If the client persists data,
+                it should use this directory. May be ignored by non-persistent implementations.
 
         Returns:
-            A dataset client.
+            A dataset client instance.
         """
 
     @abstractmethod
@@ -99,7 +92,7 @@ async def drop(self) -> None:
         """
 
     @abstractmethod
-    async def push_data(self, *, data: list[Any] | dict[str, Any]) -> None:
+    async def push_data(self, data: list[Any] | dict[str, Any]) -> None:
         """Push data to the dataset.
 
         The backend method for the `Dataset.push_data` call.
@@ -121,13 +114,13 @@ async def get_data(
         flatten: list[str] | None = None,
         view: str | None = None,
     ) -> DatasetItemsListPage:
-        """Get data from the dataset.
+        """Get data from the dataset with various filtering options.
 
         The backend method for the `Dataset.get_data` call.
""" @abstractmethod - async def iterate( + async def iterate_items( self, *, offset: int = 0, @@ -140,9 +133,9 @@ async def iterate( skip_empty: bool = False, skip_hidden: bool = False, ) -> AsyncIterator[dict]: - """Iterate over the dataset. + """Iterate over the dataset items with filtering options. - The backend method for the `Dataset.iterate` call. + The backend method for the `Dataset.iterate_items` call. """ # This syntax is to make mypy properly work with abstract AsyncIterator. # https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators diff --git a/src/crawlee/storage_clients/_base/_key_value_store_client.py b/src/crawlee/storage_clients/_base/_key_value_store_client.py index 097b5fbf8f..50b7175745 100644 --- a/src/crawlee/storage_clients/_base/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_base/_key_value_store_client.py @@ -10,82 +10,73 @@ from datetime import datetime from pathlib import Path - from crawlee.storage_clients.models import ( - KeyValueStoreRecord, - KeyValueStoreRecordMetadata, - ) - -# Properties: -# - id -# - name -# - created_at -# - accessed_at -# - modified_at - -# Methods: -# - open -# - drop -# - get_value -# - set_value -# - delete_value -# - iterate_keys -# - get_public_url + from crawlee.storage_clients.models import KeyValueStoreRecord, KeyValueStoreRecordMetadata @docs_group('Abstract classes') class KeyValueStoreClient(ABC): - """An abstract class for key-value store (KVS) resource clients. + """An abstract class for key-value store (KVS) storage clients. - These clients are specific to the type of resource they manage and operate under a designated storage - client, like a memory storage client. + Key-value stores clients provide an interface for accessing and manipulating KVS storage. They handle + operations like getting, setting, deleting KVS values across different storage backends. + + Storage clients are specific to the type of storage they manage (`Dataset`, `KeyValueStore`, + `RequestQueue`), and can operate with various storage systems including memory, file system, + databases, and cloud storage solutions. + + This abstract class defines the interface that all specific KVS clients must implement. """ @property @abstractmethod def id(self) -> str: - """The ID of the key-value store.""" + """The unique identifier of the key-value store (typically a UUID).""" @property @abstractmethod def name(self) -> str | None: - """The name of the key-value store.""" + """The optional human-readable name for the KVS.""" @property @abstractmethod def created_at(self) -> datetime: - """The time at which the key-value store was created.""" + """Timestamp when the KVS was first created, remains unchanged.""" @property @abstractmethod def accessed_at(self) -> datetime: - """The time at which the key-value store was last accessed.""" + """Timestamp of last access to the KVS, updated on read or write operations.""" @property @abstractmethod def modified_at(self) -> datetime: - """The time at which the key-value store was last modified.""" + """Timestamp of last modification of the KVS, updated when new data are added, updated or deleted.""" @classmethod @abstractmethod async def open( cls, *, - id: str | None, - name: str | None, - storage_dir: Path, + id: str | None = None, + name: str | None = None, + storage_dir: Path | None = None, ) -> KeyValueStoreClient: """Open existing or create a new key-value store client. - If a key-value store with the given name already exists, the appropriate key-value store client is returned. 
- Otherwise, a new key-value store is created and client for it is returned. + If a key-value store with the given name or ID already exists, the appropriate + key-value store client is returned. Otherwise, a new key-value store is created + and a client for it is returned. + + The backend method for the `KeyValueStoreClient.open` call. Args: - id: The ID of the key-value store. - name: The name of the key-value store. - storage_dir: The path to the storage directory. If the client persists data, it should use this directory. + id: The ID of the key-value store. If not provided, an ID may be generated. + name: The name of the key-value store. If not provided a default name may be used. + storage_dir: The path to the storage directory. If the client persists data, + it should use this directory. May be ignored by non-persistent implementations. Returns: - A key-value store client. + A key-value store client instance. """ @abstractmethod diff --git a/src/crawlee/storage_clients/_file_system/_dataset_client.py b/src/crawlee/storage_clients/_file_system/_dataset_client.py index 3c0b46d5c2..b0a665cee7 100644 --- a/src/crawlee/storage_clients/_file_system/_dataset_client.py +++ b/src/crawlee/storage_clients/_file_system/_dataset_client.py @@ -5,6 +5,7 @@ import shutil from datetime import datetime, timezone from logging import getLogger +from pathlib import Path from typing import TYPE_CHECKING from pydantic import ValidationError @@ -18,7 +19,6 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator - from pathlib import Path from typing import Any logger = getLogger(__name__) @@ -28,14 +28,15 @@ class FileSystemDatasetClient(DatasetClient): - """A file system storage implementation of the dataset client. + """A file system implementation of the dataset client. - This client stores dataset items as individual JSON files in a subdirectory. - The metadata of the dataset (timestamps, item count, etc.) is stored in a metadata file. + This client persists data to the file system, making it suitable for scenarios where data needs + to survive process restarts. Each dataset item is stored as a separate JSON file with a numeric + filename, allowing for easy ordering and pagination. """ _DEFAULT_NAME = 'default' - """The name of the unnamed dataset.""" + """The default name for the dataset when no name is provided.""" _STORAGE_SUBDIR = 'datasets' """The name of the subdirectory where datasets are stored.""" @@ -69,7 +70,7 @@ def __init__( self._storage_dir = storage_dir - # Internal attributes. + # Internal attributes self._lock = asyncio.Lock() """A lock to ensure that only one file operation is performed at a time.""" @@ -120,21 +121,8 @@ async def open( *, id: str | None = None, name: str | None = None, - storage_dir: Path, + storage_dir: Path | None = None, ) -> FileSystemDatasetClient: - """Open an existing dataset client or create a new one if it does not exist. - - If the dataset directory exists, this method reconstructs the client from the metadata file. - Otherwise, a new dataset client is created with a new unique ID. - - Args: - id: The dataset ID. - name: The dataset name; if not provided, defaults to the default name. - storage_dir: The base directory for storage. - - Returns: - A new instance of the file system dataset client. - """ if id: raise ValueError( 'Opening a dataset by "id" is not supported for file system storage client, use "name" instead.' @@ -144,8 +132,11 @@ async def open( # Check if the client is already cached by name. 
if name in _cache_by_name: - return _cache_by_name[name] + client = _cache_by_name[name] + await client._update_metadata(update_accessed_at=True) # noqa: SLF001 + return client + storage_dir = storage_dir or Path.cwd() dataset_path = storage_dir / cls._STORAGE_SUBDIR / name metadata_path = dataset_path / METADATA_FILENAME @@ -242,11 +233,21 @@ async def get_data( view: str | None = None, ) -> DatasetItemsListPage: # Check for unsupported arguments and log a warning if found. - unsupported_args = [clean, fields, omit, unwind, skip_hidden, flatten, view] - invalid = [arg for arg in unsupported_args if arg not in (False, None)] - if invalid: + unsupported_args = { + 'clean': clean, + 'fields': fields, + 'omit': omit, + 'unwind': unwind, + 'skip_hidden': skip_hidden, + 'flatten': flatten, + 'view': view, + } + unsupported = {k: v for k, v in unsupported_args.items() if v not in (False, None)} + + if unsupported: logger.warning( - f'The arguments {invalid} of get_data are not supported by the {self.__class__.__name__} client.' + f'The arguments {list(unsupported.keys())} of get_data are not supported by the ' + f'{self.__class__.__name__} client.' ) # If the dataset directory does not exist, log a warning and return an empty page. @@ -303,7 +304,7 @@ async def get_data( ) @override - async def iterate( + async def iterate_items( self, *, offset: int = 0, @@ -317,11 +318,19 @@ async def iterate( skip_hidden: bool = False, ) -> AsyncIterator[dict]: # Check for unsupported arguments and log a warning if found. - unsupported_args = [clean, fields, omit, unwind, skip_hidden] - invalid = [arg for arg in unsupported_args if arg not in (False, None)] - if invalid: + unsupported_args = { + 'clean': clean, + 'fields': fields, + 'omit': omit, + 'unwind': unwind, + 'skip_hidden': skip_hidden, + } + unsupported = {k: v for k, v in unsupported_args.items() if v not in (False, None)} + + if unsupported: logger.warning( - f'The arguments {invalid} of iterate_items are not supported by the {self.__class__.__name__} client.' + f'The arguments {list(unsupported.keys())} of iterate are not supported ' + f'by the {self.__class__.__name__} client.' ) # If the dataset directory does not exist, log a warning and return immediately. @@ -374,9 +383,12 @@ async def _update_metadata( """ now = datetime.now(timezone.utc) - self._metadata.accessed_at = now if update_accessed_at else self.accessed_at - self._metadata.modified_at = now if update_modified_at else self.modified_at - self._metadata.item_count = new_item_count if new_item_count else self.item_count + if update_accessed_at: + self._metadata.accessed_at = now + if update_modified_at: + self._metadata.modified_at = now + if new_item_count is not None: + self._metadata.item_count = new_item_count # Ensure the parent directory for the metadata file exists. await asyncio.to_thread(self.path_to_metadata.parent.mkdir, parents=True, exist_ok=True) @@ -388,8 +400,12 @@ async def _update_metadata( async def _push_item(self, item: dict[str, Any], item_id: int) -> None: """Push a single item to the dataset. - This method increments the item count, writes the item as a JSON file with a zero-padded filename, - and updates the metadata. + This method writes the item as a JSON file with a zero-padded numeric filename + that reflects its position in the dataset sequence. + + Args: + item: The data item to add to the dataset. + item_id: The sequential ID to use for this item's filename. """ # Acquire the lock to perform file operations safely. 
async with self._lock: @@ -407,8 +423,11 @@ async def _push_item(self, item: dict[str, Any], item_id: int) -> None: async def _get_sorted_data_files(self) -> list[Path]: """Retrieve and return a sorted list of data files in the dataset directory. - The files are sorted numerically based on the filename (without extension). - The metadata file is excluded. + The files are sorted numerically based on the filename (without extension), + which corresponds to the order items were added to the dataset. + + Returns: + A list of `Path` objects pointing to data files, sorted by numeric filename. """ # Retrieve and sort all JSON files in the dataset directory numerically. files = await asyncio.to_thread( diff --git a/src/crawlee/storage_clients/_file_system/_key_value_store_client.py b/src/crawlee/storage_clients/_file_system/_key_value_store_client.py index 129c8aeb49..aa347d50d4 100644 --- a/src/crawlee/storage_clients/_file_system/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_file_system/_key_value_store_client.py @@ -5,12 +5,14 @@ import shutil from datetime import datetime, timezone from logging import getLogger +from pathlib import Path from typing import TYPE_CHECKING, Any from pydantic import ValidationError from typing_extensions import override from crawlee._utils.crypto import crypto_random_object_id +from crawlee._utils.file import infer_mime_type from crawlee.storage_clients._base import KeyValueStoreClient from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecord, KeyValueStoreRecordMetadata @@ -18,7 +20,7 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator - from pathlib import Path + logger = getLogger(__name__) @@ -27,13 +29,18 @@ class FileSystemKeyValueStoreClient(KeyValueStoreClient): - """A file system key-value store (KVS) implementation.""" + """A file system implementation of the key-value store client. + + This client persists data to the file system, making it suitable for scenarios where data needs + to survive process restarts. Each key-value pair is stored as a separate file, with its metadata + in an accompanying file. + """ _DEFAULT_NAME = 'default' - """The name of the unnamed KVS.""" + """The default name for the unnamed key-value store.""" _STORAGE_SUBDIR = 'key_value_stores' - """The name of the subdirectory where KVSs are stored.""" + """The name of the subdirectory where key-value stores are stored.""" def __init__( self, @@ -59,7 +66,7 @@ def __init__( self._storage_dir = storage_dir - # Internal attributes. + # Internal attributes self._lock = asyncio.Lock() """A lock to ensure that only one file operation is performed at a time.""" @@ -105,21 +112,8 @@ async def open( *, id: str | None = None, name: str | None = None, - storage_dir: Path, + storage_dir: Path | None = None, ) -> FileSystemKeyValueStoreClient: - """Open an existing key-value store client or create a new one if it does not exist. - - If the key-value store directory exists, this method reconstructs the client from the metadata file. - Otherwise, a new key-value store client is created with a new unique ID. - - Args: - id: The key-value store ID. - name: The key-value store name; if not provided, defaults to the default name. - storage_dir: The base directory for storage. - - Returns: - A new instance of the file system key-value store client. - """ if id: raise ValueError( 'Opening a key-value store by "id" is not supported for file system storage client, use "name" instead.' 
@@ -131,6 +125,7 @@ async def open( if name in _cache_by_name: return _cache_by_name[name] + storage_dir = storage_dir or Path.cwd() kvs_path = storage_dir / cls._STORAGE_SUBDIR / name metadata_path = kvs_path / METADATA_FILENAME @@ -257,7 +252,7 @@ async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: @override async def set_value(self, *, key: str, value: Any, content_type: str | None = None) -> None: - content_type = content_type or self._infer_mime_type(value) + content_type = content_type or infer_mime_type(value) # Serialize the value to bytes. if 'application/json' in content_type: @@ -272,8 +267,7 @@ async def set_value(self, *, key: str, value: Any, content_type: str | None = No record_path = self.path_to_kvs / key - # Get the metadata. - # Calculate the size of the value in bytes + # Prepare the metadata size = len(value_bytes) record_metadata = KeyValueStoreRecordMetadata(key=key, content_type=content_type, size=size) record_metadata_filepath = record_path.with_name(f'{record_path.name}.{METADATA_FILENAME}') @@ -283,10 +277,10 @@ async def set_value(self, *, key: str, value: Any, content_type: str | None = No # Ensure the key-value store directory exists. await asyncio.to_thread(self.path_to_kvs.mkdir, parents=True, exist_ok=True) - # Dump the value to the file. + # Write the value to the file. await asyncio.to_thread(record_path.write_bytes, value_bytes) - # Dump the record metadata to the file. + # Write the record metadata to the file. await asyncio.to_thread( record_metadata_filepath.write_text, record_metadata_content, @@ -332,7 +326,7 @@ async def iterate_keys( count = 0 async with self._lock: - # Get all files in the KVS directory + # Get all files in the KVS directory, sorted alphabetically files = sorted(await asyncio.to_thread(list, self.path_to_kvs.glob('*'))) for file_path in files: @@ -387,8 +381,10 @@ async def _update_metadata( """ now = datetime.now(timezone.utc) - self._metadata.accessed_at = now if update_accessed_at else self._metadata.accessed_at - self._metadata.modified_at = now if update_modified_at else self._metadata.modified_at + if update_accessed_at: + self._metadata.accessed_at = now + if update_modified_at: + self._metadata.modified_at = now # Ensure the parent directory for the metadata file exists. await asyncio.to_thread(self.path_to_metadata.parent.mkdir, parents=True, exist_ok=True) @@ -396,27 +392,3 @@ async def _update_metadata( # Dump the serialized metadata to the file. data = await json_dumps(self._metadata.model_dump()) await asyncio.to_thread(self.path_to_metadata.write_text, data, encoding='utf-8') - - def _infer_mime_type(self, value: Any) -> str: - """Infer the MIME content type from the value. - - Args: - value: The value to infer the content type from. - - Returns: - The inferred MIME content type. - """ - # If the value is bytes (or bytearray), return binary content type. - if isinstance(value, (bytes, bytearray)): - return 'application/octet-stream' - - # If the value is a dict or list, assume JSON. - if isinstance(value, (dict, list)): - return 'application/json; charset=utf-8' - - # If the value is a string, assume plain text. - if isinstance(value, str): - return 'text/plain; charset=utf-8' - - # Default fallback. 
- return 'application/octet-stream' diff --git a/src/crawlee/storage_clients/_memory/_dataset_client.py b/src/crawlee/storage_clients/_memory/_dataset_client.py index 6ffa22f028..558619c5f0 100644 --- a/src/crawlee/storage_clients/_memory/_dataset_client.py +++ b/src/crawlee/storage_clients/_memory/_dataset_client.py @@ -8,7 +8,7 @@ from crawlee._utils.crypto import crypto_random_object_id from crawlee.storage_clients._base import DatasetClient -from crawlee.storage_clients.models import DatasetItemsListPage +from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata if TYPE_CHECKING: from collections.abc import AsyncIterator @@ -16,12 +16,16 @@ logger = getLogger(__name__) +_cache_by_name = dict[str, 'MemoryDatasetClient']() +"""A dictionary to cache clients by their names.""" + class MemoryDatasetClient(DatasetClient): """A memory implementation of the dataset client. - This client stores dataset items in memory using a dictionary. - No data is persisted to the file system. + This client stores dataset items in memory using a list. No data is persisted, which means + all data is lost when the process terminates. This implementation is mainly useful for testing + and development purposes where persistence is not required. """ _DEFAULT_NAME = 'default' @@ -37,62 +41,76 @@ def __init__( modified_at: datetime, item_count: int, ) -> None: - """Initialize a new instance of the memory-only dataset client. + """Initialize a new instance. Preferably use the `MemoryDatasetClient.open` class method to create a new instance. """ - self._id = id - self._name = name - self._created_at = created_at - self._accessed_at = accessed_at - self._modified_at = modified_at - self._item_count = item_count - - # Dictionary to hold dataset items; keys are zero-padded strings. + self._metadata = DatasetMetadata( + id=id, + name=name, + created_at=created_at, + accessed_at=accessed_at, + modified_at=modified_at, + item_count=item_count, + ) + + # List to hold dataset items self._records = list[dict[str, Any]]() @override @property def id(self) -> str: - return self._id + return self._metadata.id @override @property - def name(self) -> str | None: - return self._name + def name(self) -> str: + return self._metadata.name @override @property def created_at(self) -> datetime: - return self._created_at + return self._metadata.created_at @override @property def accessed_at(self) -> datetime: - return self._accessed_at + return self._metadata.accessed_at @override @property def modified_at(self) -> datetime: - return self._modified_at + return self._metadata.modified_at @override @property def item_count(self) -> int: - return self._item_count + return self._metadata.item_count @override @classmethod async def open( cls, - id: str | None, - name: str | None, - storage_dir: Path, # Ignored in the memory-only implementation. + *, + id: str | None = None, + name: str | None = None, + storage_dir: Path | None = None, ) -> MemoryDatasetClient: + if storage_dir is not None: + logger.warning('The `storage_dir` argument is not used in the memory dataset client.') + name = name or cls._DEFAULT_NAME + + # Check if the client is already cached by name. 
+ if name in _cache_by_name: + client = _cache_by_name[name] + await client._update_metadata(update_accessed_at=True) # noqa: SLF001 + return client + dataset_id = id or crypto_random_object_id() now = datetime.now(timezone.utc) - return cls( + + client = cls( id=dataset_id, name=name, created_at=now, @@ -101,19 +119,37 @@ async def open( item_count=0, ) + # Cache the client by name + _cache_by_name[name] = client + + return client + @override async def drop(self) -> None: self._records.clear() - self._item_count = 0 + self._metadata.item_count = 0 + + # Remove the client from the cache + if self.name in _cache_by_name: + del _cache_by_name[self.name] @override async def push_data(self, data: list[Any] | dict[str, Any]) -> None: + new_item_count = self.item_count + if isinstance(data, list): for item in data: + new_item_count += 1 await self._push_item(item) else: + new_item_count += 1 await self._push_item(data) - await self._update_metadata(update_accessed_at=True, update_modified_at=True) + + await self._update_metadata( + update_accessed_at=True, + update_modified_at=True, + new_item_count=new_item_count, + ) @override async def get_data( @@ -131,20 +167,40 @@ async def get_data( flatten: list[str] | None = None, view: str | None = None, ) -> DatasetItemsListPage: - unsupported_args = [clean, fields, omit, unwind, skip_hidden, flatten, view] - invalid = [arg for arg in unsupported_args if arg not in (False, None)] - if invalid: + # Check for unsupported arguments and log a warning if found + unsupported_args = { + 'clean': clean, + 'fields': fields, + 'omit': omit, + 'unwind': unwind, + 'skip_hidden': skip_hidden, + 'flatten': flatten, + 'view': view, + } + unsupported = {k: v for k, v in unsupported_args.items() if v not in (False, None)} + + if unsupported: logger.warning( - f'The arguments {invalid} of get_data are not supported by the {self.__class__.__name__} client.' + f'The arguments {list(unsupported.keys())} of get_data are not supported ' + f'by the {self.__class__.__name__} client.' ) total = len(self._records) items = self._records.copy() + + # Apply skip_empty filter if requested + if skip_empty: + items = [item for item in items if item] + + # Apply sorting if desc: items = list(reversed(items)) + # Apply pagination sliced_items = items[offset : (offset + limit) if limit is not None else total] + await self._update_metadata(update_accessed_at=True) + return DatasetItemsListPage( count=len(sliced_items), offset=offset, @@ -155,7 +211,7 @@ async def get_data( ) @override - async def iterate( + async def iterate_items( self, *, offset: int = 0, @@ -168,18 +224,32 @@ async def iterate( skip_empty: bool = False, skip_hidden: bool = False, ) -> AsyncIterator[dict]: - unsupported_args = [clean, fields, omit, unwind, skip_hidden] - invalid = [arg for arg in unsupported_args if arg not in (False, None)] - if invalid: + # Check for unsupported arguments and log a warning if found + unsupported_args = { + 'clean': clean, + 'fields': fields, + 'omit': omit, + 'unwind': unwind, + 'skip_hidden': skip_hidden, + } + unsupported = {k: v for k, v in unsupported_args.items() if v not in (False, None)} + + if unsupported: logger.warning( - f'The arguments {invalid} of iterate are not supported by the {self.__class__.__name__} client.' + f'The arguments {list(unsupported.keys())} of iterate are not supported ' + f'by the {self.__class__.__name__} client.' 
) items = self._records.copy() + + # Apply sorting if desc: items = list(reversed(items)) + # Apply pagination sliced_items = items[offset : (offset + limit) if limit is not None else len(items)] + + # Yield items one by one for item in sliced_items: if skip_empty and not item: continue @@ -190,22 +260,30 @@ async def iterate( async def _update_metadata( self, *, + new_item_count: int | None = None, update_accessed_at: bool = False, update_modified_at: bool = False, ) -> None: - """Update the dataset metadata file with current information. + """Update the dataset metadata with current information. Args: + new_item_count: If provided, update the item count to this value. update_accessed_at: If True, update the `accessed_at` timestamp to the current time. update_modified_at: If True, update the `modified_at` timestamp to the current time. """ now = datetime.now(timezone.utc) + if update_accessed_at: - self._accessed_at = now + self._metadata.accessed_at = now if update_modified_at: - self._modified_at = now + self._metadata.modified_at = now + if new_item_count: + self._metadata.item_count = new_item_count async def _push_item(self, item: dict[str, Any]) -> None: - """Push a single item to the dataset.""" - self._item_count += 1 + """Push a single item to the dataset. + + Args: + item: The data item to add to the dataset. + """ self._records.append(item) diff --git a/src/crawlee/storage_clients/_memory/_key_value_store_client.py b/src/crawlee/storage_clients/_memory/_key_value_store_client.py index 0dd8d9a0a5..dcaec9f458 100644 --- a/src/crawlee/storage_clients/_memory/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_memory/_key_value_store_client.py @@ -1,11 +1,211 @@ from __future__ import annotations +import sys +from datetime import datetime, timezone from logging import getLogger +from typing import TYPE_CHECKING, Any +from typing_extensions import override + +from crawlee._utils.crypto import crypto_random_object_id +from crawlee._utils.file import infer_mime_type from crawlee.storage_clients._base import KeyValueStoreClient +from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecord, KeyValueStoreRecordMetadata + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + from pathlib import Path logger = getLogger(__name__) +_cache_by_name = dict[str, 'MemoryKeyValueStoreClient']() +"""A dictionary to cache clients by their names.""" + class MemoryKeyValueStoreClient(KeyValueStoreClient): - pass + """A memory implementation of the key-value store client. + + This client stores key-value store pairs in memory using a dictionary. No data is persisted, + which means all data is lost when the process terminates. This implementation is mainly useful + for testing and development purposes where persistence is not required. + """ + + _DEFAULT_NAME = 'default' + """The default name for the key-value store when no name is provided.""" + + def __init__( + self, + *, + id: str, + name: str, + created_at: datetime, + accessed_at: datetime, + modified_at: datetime, + ) -> None: + """Initialize a new instance. + + Preferably use the `MemoryKeyValueStoreClient.open` class method to create a new instance. 
+ """ + self._metadata = KeyValueStoreMetadata( + id=id, + name=name, + created_at=created_at, + accessed_at=accessed_at, + modified_at=modified_at, + ) + + # Dictionary to hold key-value records with metadata + self._store = dict[str, KeyValueStoreRecord]() + + @override + @property + def id(self) -> str: + return self._metadata.id + + @override + @property + def name(self) -> str: + return self._metadata.name + + @override + @property + def created_at(self) -> datetime: + return self._metadata.created_at + + @override + @property + def accessed_at(self) -> datetime: + return self._metadata.accessed_at + + @override + @property + def modified_at(self) -> datetime: + return self._metadata.modified_at + + @override + @classmethod + async def open( + cls, + *, + id: str | None = None, + name: str | None = None, + storage_dir: Path | None = None, + ) -> MemoryKeyValueStoreClient: + if storage_dir is not None: + logger.warning('The `storage_dir` argument is not used in the memory key-value store client.') + + name = name or cls._DEFAULT_NAME + + # Check if the client is already cached by name + if name in _cache_by_name: + client = _cache_by_name[name] + await client._update_metadata(update_accessed_at=True) # noqa: SLF001 + return client + + # If specific id is provided, use it; otherwise, generate a new one + id = id or crypto_random_object_id() + now = datetime.now(timezone.utc) + + client = cls( + id=id, + name=name, + created_at=now, + accessed_at=now, + modified_at=now, + ) + + # Cache the client by name + _cache_by_name[name] = client + + return client + + @override + async def drop(self) -> None: + # Clear all data + self._store.clear() + + # Remove from cache + if self.name in _cache_by_name: + del _cache_by_name[self.name] + + @override + async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: + await self._update_metadata(update_accessed_at=True) + + # Return None if key doesn't exist + return self._store.get(key, None) + + @override + async def set_value(self, *, key: str, value: Any, content_type: str | None = None) -> None: + content_type = content_type or infer_mime_type(value) + size = sys.getsizeof(value) + + # Create and store the record + record = KeyValueStoreRecord( + key=key, + value=value, + content_type=content_type, + size=size, + ) + + self._store[key] = record + + await self._update_metadata(update_accessed_at=True, update_modified_at=True) + + @override + async def delete_value(self, *, key: str) -> None: + if key in self._store: + del self._store[key] + await self._update_metadata(update_accessed_at=True, update_modified_at=True) + + @override + async def iterate_keys( + self, + *, + exclusive_start_key: str | None = None, + limit: int | None = None, + ) -> AsyncIterator[KeyValueStoreRecordMetadata]: + await self._update_metadata(update_accessed_at=True) + + # Get all keys, sorted alphabetically + keys = sorted(self._store.keys()) + + # Apply exclusive_start_key filter if provided + if exclusive_start_key is not None: + keys = [k for k in keys if k > exclusive_start_key] + + # Apply limit if provided + if limit is not None: + keys = keys[:limit] + + # Yield metadata for each key + for key in keys: + record = self._store[key] + yield KeyValueStoreRecordMetadata( + key=key, + content_type=record.content_type, + size=record.size, + ) + + @override + async def get_public_url(self, *, key: str) -> str: + raise NotImplementedError('Public URLs are not supported for memory key-value stores.') + + async def _update_metadata( + self, + *, + update_accessed_at: 
bool = False, + update_modified_at: bool = False, + ) -> None: + """Update the key-value store metadata with current information. + + Args: + update_accessed_at: If True, update the `accessed_at` timestamp to the current time. + update_modified_at: If True, update the `modified_at` timestamp to the current time. + """ + now = datetime.now(timezone.utc) + + if update_accessed_at: + self._metadata.accessed_at = now + if update_modified_at: + self._metadata.modified_at = now diff --git a/src/crawlee/storages/_dataset.py b/src/crawlee/storages/_dataset.py index 112addfcf1..4f5659eae0 100644 --- a/src/crawlee/storages/_dataset.py +++ b/src/crawlee/storages/_dataset.py @@ -43,7 +43,7 @@ # - drop # - push_data # - get_data -# - iterate +# - iterate_items # - export_to # - export_to_json # - export_to_csv @@ -219,7 +219,7 @@ async def get_data( view=view, ) - async def iterate( + async def iterate_items( self, *, offset: int = 0, @@ -254,7 +254,7 @@ async def iterate( An asynchronous iterator of dictionary objects, each representing a dataset item after applying the specified filters and transformations. """ - async for item in self._client.iterate( + async for item in self._client.iterate_items( offset=offset, limit=limit, clean=clean, @@ -313,7 +313,7 @@ async def export_to_json( ) -> None: kvs = await KeyValueStore.open(id=to_key_value_store_id, name=to_key_value_store_name) dst = StringIO() - await export_json_to_stream(self.iterate(), dst, **kwargs) + await export_json_to_stream(self.iterate_items(), dst, **kwargs) await kvs.set_value(key, dst.getvalue(), 'application/json') async def export_to_csv( @@ -325,5 +325,5 @@ async def export_to_csv( ) -> None: kvs = await KeyValueStore.open(id=to_key_value_store_id, name=to_key_value_store_name) dst = StringIO() - await export_csv_to_stream(self.iterate(), dst, **kwargs) + await export_csv_to_stream(self.iterate_items(), dst, **kwargs) await kvs.set_value(key, dst.getvalue(), 'text/csv') diff --git a/tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py b/tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py index 8af2509db9..89b028ed81 100644 --- a/tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py +++ b/tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py @@ -456,7 +456,7 @@ async def request_handler(context: AdaptivePlaywrightCrawlingContext) -> None: await crawler.run(test_urls[:1]) dataset = await crawler.get_dataset() - stored_results = [item async for item in dataset.iterate()] + stored_results = [item async for item in dataset.iterate_items()] if error_in_pw_crawler: assert stored_results == [] diff --git a/tests/unit/storage_clients/_file_system/test_dataset_client.py b/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py similarity index 91% rename from tests/unit/storage_clients/_file_system/test_dataset_client.py rename to tests/unit/storage_clients/_file_system/test_fs_dataset_client.py index ae17746e06..247830d435 100644 --- a/tests/unit/storage_clients/_file_system/test_dataset_client.py +++ b/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py @@ -108,7 +108,7 @@ async def test_push_data_multiple_items(dataset_client: FileSystemDatasetClient) async def test_get_data_empty_dataset(dataset_client: FileSystemDatasetClient) -> None: - """Test getting data from an empty dataset.""" + """Test getting data from an empty dataset returns empty list.""" result = await dataset_client.get_data() assert isinstance(result, 
DatasetItemsListPage) @@ -118,7 +118,7 @@ async def test_get_data_empty_dataset(dataset_client: FileSystemDatasetClient) - async def test_get_data_with_items(dataset_client: FileSystemDatasetClient) -> None: - """Test getting data from a dataset with items.""" + """Test getting data from a dataset returns all items in order with correct properties.""" # Add some items items = [{'id': 1, 'name': 'Item 1'}, {'id': 2, 'name': 'Item 2'}, {'id': 3, 'name': 'Item 3'}] await dataset_client.push_data(items) @@ -135,7 +135,7 @@ async def test_get_data_with_items(dataset_client: FileSystemDatasetClient) -> N async def test_get_data_with_pagination(dataset_client: FileSystemDatasetClient) -> None: - """Test getting data with pagination.""" + """Test getting data with offset and limit parameters for pagination implementation.""" # Add some items items = [{'id': i} for i in range(1, 11)] # 10 items await dataset_client.push_data(items) @@ -162,7 +162,7 @@ async def test_get_data_with_pagination(dataset_client: FileSystemDatasetClient) async def test_get_data_descending_order(dataset_client: FileSystemDatasetClient) -> None: - """Test getting data in descending order.""" + """Test getting data in descending order reverses the item order.""" # Add some items items = [{'id': i} for i in range(1, 6)] # 5 items await dataset_client.push_data(items) @@ -176,7 +176,7 @@ async def test_get_data_descending_order(dataset_client: FileSystemDatasetClient async def test_get_data_skip_empty(dataset_client: FileSystemDatasetClient) -> None: - """Test getting data with skip_empty option.""" + """Test getting data with skip_empty option filters out empty items when True.""" # Add some items including an empty one items = [ {'id': 1, 'name': 'Item 1'}, @@ -196,13 +196,13 @@ async def test_get_data_skip_empty(dataset_client: FileSystemDatasetClient) -> N async def test_iterate(dataset_client: FileSystemDatasetClient) -> None: - """Test iterating over dataset items.""" + """Test iterating over dataset items yields each item in the original order.""" # Add some items items = [{'id': i} for i in range(1, 6)] # 5 items await dataset_client.push_data(items) # Iterate over all items - collected_items = [item async for item in dataset_client.iterate()] + collected_items = [item async for item in dataset_client.iterate_items()] assert len(collected_items) == 5 assert collected_items[0]['id'] == 1 @@ -210,13 +210,13 @@ async def test_iterate(dataset_client: FileSystemDatasetClient) -> None: async def test_iterate_with_options(dataset_client: FileSystemDatasetClient) -> None: - """Test iterating with various options.""" + """Test iterating with offset, limit and desc parameters works the same as with get_data().""" # Add some items items = [{'id': i} for i in range(1, 11)] # 10 items await dataset_client.push_data(items) # Test with offset and limit - collected_items = [item async for item in dataset_client.iterate(offset=3, limit=3)] + collected_items = [item async for item in dataset_client.iterate_items(offset=3, limit=3)] assert len(collected_items) == 3 assert collected_items[0]['id'] == 4 @@ -224,7 +224,7 @@ async def test_iterate_with_options(dataset_client: FileSystemDatasetClient) -> # Test with descending order collected_items = [] - async for item in dataset_client.iterate(desc=True, limit=3): + async for item in dataset_client.iterate_items(desc=True, limit=3): collected_items.append(item) assert len(collected_items) == 3 @@ -233,7 +233,7 @@ async def test_iterate_with_options(dataset_client: FileSystemDatasetClient) 
-> async def test_drop(tmp_path: Path) -> None: - """Test dropping a dataset.""" + """Test dropping a dataset removes the entire dataset directory from disk.""" # Create a dataset and add an item client = await FileSystemDatasetClient.open(name='to_drop', storage_dir=tmp_path) await client.push_data({'test': 'data'}) @@ -249,7 +249,7 @@ async def test_drop(tmp_path: Path) -> None: async def test_metadata_updates(dataset_client: FileSystemDatasetClient) -> None: - """Test that metadata is updated correctly after operations.""" + """Test that metadata timestamps are updated correctly after read and write operations.""" # Record initial timestamps initial_created = dataset_client.created_at initial_accessed = dataset_client.accessed_at diff --git a/tests/unit/storage_clients/_file_system/test_key_value_store_client.py b/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py similarity index 85% rename from tests/unit/storage_clients/_file_system/test_key_value_store_client.py rename to tests/unit/storage_clients/_file_system/test_fs_kvs_client.py index 4f1431ea59..7f20fbfea1 100644 --- a/tests/unit/storage_clients/_file_system/test_key_value_store_client.py +++ b/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py @@ -19,14 +19,14 @@ @pytest.fixture async def kvs_client(tmp_path: Path) -> AsyncGenerator[FileSystemKeyValueStoreClient, None]: - """A fixture for a file system key-value store client.""" + """Fixture that provides a fresh file system key-value store client using a temporary directory.""" client = await FileSystemKeyValueStoreClient.open(name='test_kvs', storage_dir=tmp_path) yield client await client.drop() async def test_open_creates_new_kvs(tmp_path: Path) -> None: - """Test that open() creates a new key-value store with proper metadata when it doesn't exist.""" + """Test that open() creates a new key-value store with proper metadata and files on disk.""" client = await FileSystemKeyValueStoreClient.open(name='new_kvs', storage_dir=tmp_path) # Verify client properties @@ -48,7 +48,7 @@ async def test_open_creates_new_kvs(tmp_path: Path) -> None: async def test_open_existing_kvs(kvs_client: FileSystemKeyValueStoreClient, tmp_path: Path) -> None: - """Test that open() loads an existing key-value store correctly.""" + """Test that open() loads an existing key-value store with matching properties.""" # Open the same key-value store again reopened_client = await FileSystemKeyValueStoreClient.open(name=kvs_client.name, storage_dir=tmp_path) @@ -61,13 +61,13 @@ async def test_open_existing_kvs(kvs_client: FileSystemKeyValueStoreClient, tmp_ async def test_open_with_id_raises_error(tmp_path: Path) -> None: - """Test that open() raises an error when an ID is provided.""" + """Test that open() raises an error when an ID is provided (unsupported for file system client).""" with pytest.raises(ValueError, match='not supported for file system storage client'): await FileSystemKeyValueStoreClient.open(id='some-id', storage_dir=tmp_path) async def test_set_get_value_string(kvs_client: FileSystemKeyValueStoreClient) -> None: - """Test setting and getting a string value.""" + """Test setting and getting a string value with correct file creation and metadata.""" # Set a value test_key = 'test-key' test_value = 'Hello, world!' 
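Note: the renamed file-system KVS tests around this hunk exercise the open/set/get/drop surface against a temporary directory, with each record stored as a value file plus a JSON metadata sidecar inside <storage_dir>/key_value_stores/<name>/. A minimal usage sketch of that client, assuming only the module path and methods visible in this patch (the exact metadata filename constant is defined elsewhere and not shown here):

    import asyncio
    import tempfile
    from pathlib import Path

    from crawlee.storage_clients._file_system._key_value_store_client import FileSystemKeyValueStoreClient


    async def main() -> None:
        storage_dir = Path(tempfile.mkdtemp())

        # open() creates <storage_dir>/key_value_stores/demo/ and caches the client by name.
        kvs = await FileSystemKeyValueStoreClient.open(name='demo', storage_dir=storage_dir)

        await kvs.set_value(key='greeting', value='Hello, world!')

        # The record file and its metadata sidecar now live under path_to_kvs.
        print(sorted(p.name for p in kvs.path_to_kvs.iterdir()))

        record = await kvs.get_value(key='greeting')
        assert record is not None
        assert record.value == 'Hello, world!'

        # drop() removes the whole store directory again.
        await kvs.drop()


    asyncio.run(main())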
@@ -100,7 +100,7 @@ async def test_set_get_value_string(kvs_client: FileSystemKeyValueStoreClient) - async def test_set_get_value_json(kvs_client: FileSystemKeyValueStoreClient) -> None: - """Test setting and getting a JSON value.""" + """Test setting and getting a JSON value with correct serialization and deserialization.""" # Set a value test_key = 'test-json' test_value = {'name': 'John', 'age': 30, 'items': [1, 2, 3]} @@ -115,7 +115,7 @@ async def test_set_get_value_json(kvs_client: FileSystemKeyValueStoreClient) -> async def test_set_get_value_bytes(kvs_client: FileSystemKeyValueStoreClient) -> None: - """Test setting and getting binary data.""" + """Test setting and getting binary data without corruption and with correct content type.""" # Set a value test_key = 'test-binary' test_value = b'\x00\x01\x02\x03\x04' @@ -131,7 +131,7 @@ async def test_set_get_value_bytes(kvs_client: FileSystemKeyValueStoreClient) -> async def test_set_value_explicit_content_type(kvs_client: FileSystemKeyValueStoreClient) -> None: - """Test setting a value with an explicit content type.""" + """Test that an explicitly provided content type overrides the automatically inferred one.""" test_key = 'test-explicit-content-type' test_value = 'Hello, world!' explicit_content_type = 'text/html; charset=utf-8' @@ -144,13 +144,13 @@ async def test_set_value_explicit_content_type(kvs_client: FileSystemKeyValueSto async def test_get_nonexistent_value(kvs_client: FileSystemKeyValueStoreClient) -> None: - """Test getting a value that doesn't exist.""" + """Test that attempting to get a non-existent key returns None.""" record = await kvs_client.get_value(key='nonexistent-key') assert record is None async def test_overwrite_value(kvs_client: FileSystemKeyValueStoreClient) -> None: - """Test overwriting an existing value.""" + """Test that an existing value can be overwritten and the updated value is retrieved correctly.""" test_key = 'test-overwrite' # Set initial value @@ -168,7 +168,7 @@ async def test_overwrite_value(kvs_client: FileSystemKeyValueStoreClient) -> Non async def test_delete_value(kvs_client: FileSystemKeyValueStoreClient) -> None: - """Test deleting a value.""" + """Test that deleting a value removes its files from disk and makes it irretrievable.""" test_key = 'test-delete' test_value = 'Delete me' @@ -194,19 +194,19 @@ async def test_delete_value(kvs_client: FileSystemKeyValueStoreClient) -> None: async def test_delete_nonexistent_value(kvs_client: FileSystemKeyValueStoreClient) -> None: - """Test deleting a value that doesn't exist.""" + """Test that attempting to delete a non-existent key is a no-op and doesn't raise errors.""" # Should not raise an error await kvs_client.delete_value(key='nonexistent-key') async def test_iterate_keys_empty_store(kvs_client: FileSystemKeyValueStoreClient) -> None: - """Test iterating over keys in an empty store.""" + """Test that iterating over an empty store yields no keys.""" keys = [key async for key in kvs_client.iterate_keys()] assert len(keys) == 0 async def test_iterate_keys(kvs_client: FileSystemKeyValueStoreClient) -> None: - """Test iterating over keys.""" + """Test that all keys can be iterated over and are returned in sorted order.""" # Add some values await kvs_client.set_value(key='key1', value='value1') await kvs_client.set_value(key='key2', value='value2') @@ -219,7 +219,7 @@ async def test_iterate_keys(kvs_client: FileSystemKeyValueStoreClient) -> None: async def test_iterate_keys_with_limit(kvs_client: FileSystemKeyValueStoreClient) -> None: - 
"""Test iterating over keys with a limit.""" + """Test that the limit parameter returns only the specified number of keys.""" # Add some values await kvs_client.set_value(key='key1', value='value1') await kvs_client.set_value(key='key2', value='value2') @@ -231,7 +231,7 @@ async def test_iterate_keys_with_limit(kvs_client: FileSystemKeyValueStoreClient async def test_iterate_keys_with_exclusive_start_key(kvs_client: FileSystemKeyValueStoreClient) -> None: - """Test iterating over keys with an exclusive start key.""" + """Test that exclusive_start_key parameter returns only keys after it alphabetically.""" # Add some values with alphabetical keys await kvs_client.set_value(key='a-key', value='value-a') await kvs_client.set_value(key='b-key', value='value-b') @@ -248,7 +248,7 @@ async def test_iterate_keys_with_exclusive_start_key(kvs_client: FileSystemKeyVa async def test_drop(tmp_path: Path) -> None: - """Test dropping a key-value store.""" + """Test that drop removes the entire store directory from disk.""" # Create a store and add a value client = await FileSystemKeyValueStoreClient.open(name='to_drop', storage_dir=tmp_path) await client.set_value(key='test', value='test-value') @@ -265,7 +265,7 @@ async def test_drop(tmp_path: Path) -> None: async def test_metadata_updates(kvs_client: FileSystemKeyValueStoreClient) -> None: - """Test that metadata is updated correctly after operations.""" + """Test that read/write operations properly update accessed_at and modified_at timestamps.""" # Record initial timestamps initial_created = kvs_client.created_at initial_accessed = kvs_client.accessed_at @@ -297,29 +297,13 @@ async def test_metadata_updates(kvs_client: FileSystemKeyValueStoreClient) -> No async def test_get_public_url_not_supported(kvs_client: FileSystemKeyValueStoreClient) -> None: - """Test that get_public_url raises NotImplementedError.""" + """Test that get_public_url raises NotImplementedError for the file system implementation.""" with pytest.raises(NotImplementedError, match='Public URLs are not supported'): await kvs_client.get_public_url(key='any-key') -async def test_infer_mime_type(kvs_client: FileSystemKeyValueStoreClient) -> None: - """Test MIME type inference for different value types.""" - # Test string - assert kvs_client._infer_mime_type('text') == 'text/plain; charset=utf-8' - - # Test JSON - assert kvs_client._infer_mime_type({'key': 'value'}) == 'application/json; charset=utf-8' - assert kvs_client._infer_mime_type([1, 2, 3]) == 'application/json; charset=utf-8' - - # Test binary - assert kvs_client._infer_mime_type(b'binary data') == 'application/octet-stream' - - # Test other types - assert kvs_client._infer_mime_type(123) == 'application/octet-stream' - - async def test_concurrent_operations(kvs_client: FileSystemKeyValueStoreClient) -> None: - """Test concurrent operations on the key-value store.""" + """Test that multiple concurrent set operations can be performed safely with correct results.""" # Create multiple tasks to set different values concurrently async def set_value(key: str, value: str) -> None: diff --git a/tests/unit/storage_clients/_memory/test_memory_dataset_client.py b/tests/unit/storage_clients/_memory/test_memory_dataset_client.py new file mode 100644 index 0000000000..0554a9d3cb --- /dev/null +++ b/tests/unit/storage_clients/_memory/test_memory_dataset_client.py @@ -0,0 +1,276 @@ +from __future__ import annotations + +import asyncio +from datetime import datetime +from typing import TYPE_CHECKING + +import pytest + +from 
crawlee.storage_clients._memory._dataset_client import MemoryDatasetClient, _cache_by_name +from crawlee.storage_clients.models import DatasetItemsListPage + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + +pytestmark = pytest.mark.only + + +@pytest.fixture +async def dataset_client() -> AsyncGenerator[MemoryDatasetClient, None]: + """Fixture that provides a fresh memory dataset client for each test.""" + # Clear any existing dataset clients in the cache + _cache_by_name.clear() + + client = await MemoryDatasetClient.open(name='test_dataset') + yield client + await client.drop() + + +async def test_open_creates_new_dataset() -> None: + """Test that open() creates a new dataset with proper metadata and adds it to the cache.""" + client = await MemoryDatasetClient.open(name='new_dataset') + + # Verify client properties + assert client.id is not None + assert client.name == 'new_dataset' + assert client.item_count == 0 + assert isinstance(client.created_at, datetime) + assert isinstance(client.accessed_at, datetime) + assert isinstance(client.modified_at, datetime) + + # Verify the client was cached + assert 'new_dataset' in _cache_by_name + + +async def test_open_existing_dataset(dataset_client: MemoryDatasetClient) -> None: + """Test that open() loads an existing dataset with matching properties.""" + # Open the same dataset again + reopened_client = await MemoryDatasetClient.open(name=dataset_client.name) + + # Verify client properties + assert dataset_client.id == reopened_client.id + assert dataset_client.name == reopened_client.name + assert dataset_client.item_count == reopened_client.item_count + + # Verify clients (python) ids + assert id(dataset_client) == id(reopened_client) + + +async def test_open_with_id_and_name() -> None: + """Test that open() can be used with both id and name parameters.""" + client = await MemoryDatasetClient.open(id='some-id', name='some-name') + assert client.id == 'some-id' + assert client.name == 'some-name' + + +async def test_push_data_single_item(dataset_client: MemoryDatasetClient) -> None: + """Test pushing a single item to the dataset and verifying it was stored correctly.""" + item = {'key': 'value', 'number': 42} + await dataset_client.push_data(item) + + # Verify item count was updated + assert dataset_client.item_count == 1 + + # Verify item was stored + result = await dataset_client.get_data() + assert result.count == 1 + assert result.items[0] == item + + +async def test_push_data_multiple_items(dataset_client: MemoryDatasetClient) -> None: + """Test pushing multiple items to the dataset and verifying they were stored correctly.""" + items = [ + {'id': 1, 'name': 'Item 1'}, + {'id': 2, 'name': 'Item 2'}, + {'id': 3, 'name': 'Item 3'}, + ] + await dataset_client.push_data(items) + + # Verify item count was updated + assert dataset_client.item_count == 3 + + # Verify items were stored + result = await dataset_client.get_data() + assert result.count == 3 + assert result.items == items + + +async def test_get_data_empty_dataset(dataset_client: MemoryDatasetClient) -> None: + """Test that getting data from an empty dataset returns empty results with correct metadata.""" + result = await dataset_client.get_data() + + assert isinstance(result, DatasetItemsListPage) + assert result.count == 0 + assert result.total == 0 + assert result.items == [] + + +async def test_get_data_with_items(dataset_client: MemoryDatasetClient) -> None: + """Test that all items pushed to the dataset can be retrieved with correct metadata.""" + # Add some 
items + items = [ + {'id': 1, 'name': 'Item 1'}, + {'id': 2, 'name': 'Item 2'}, + {'id': 3, 'name': 'Item 3'}, + ] + await dataset_client.push_data(items) + + # Get all items + result = await dataset_client.get_data() + + assert result.count == 3 + assert result.total == 3 + assert len(result.items) == 3 + assert result.items[0]['id'] == 1 + assert result.items[1]['id'] == 2 + assert result.items[2]['id'] == 3 + + +async def test_get_data_with_pagination(dataset_client: MemoryDatasetClient) -> None: + """Test that offset and limit parameters work correctly for dataset pagination.""" + # Add some items + items = [{'id': i} for i in range(1, 11)] # 10 items + await dataset_client.push_data(items) + + # Test offset + result = await dataset_client.get_data(offset=3) + assert result.count == 7 + assert result.offset == 3 + assert result.items[0]['id'] == 4 + + # Test limit + result = await dataset_client.get_data(limit=5) + assert result.count == 5 + assert result.limit == 5 + assert result.items[-1]['id'] == 5 + + # Test both offset and limit + result = await dataset_client.get_data(offset=2, limit=3) + assert result.count == 3 + assert result.offset == 2 + assert result.limit == 3 + assert result.items[0]['id'] == 3 + assert result.items[-1]['id'] == 5 + + +async def test_get_data_descending_order(dataset_client: MemoryDatasetClient) -> None: + """Test that the desc parameter correctly reverses the order of returned items.""" + # Add some items + items = [{'id': i} for i in range(1, 6)] # 5 items + await dataset_client.push_data(items) + + # Get items in descending order + result = await dataset_client.get_data(desc=True) + + assert result.desc is True + assert result.items[0]['id'] == 5 + assert result.items[-1]['id'] == 1 + + +async def test_get_data_skip_empty(dataset_client: MemoryDatasetClient) -> None: + """Test that the skip_empty parameter correctly filters out empty items.""" + # Add some items including an empty one + items = [ + {'id': 1, 'name': 'Item 1'}, + {}, # Empty item + {'id': 3, 'name': 'Item 3'}, + ] + await dataset_client.push_data(items) + + # Get all items + result = await dataset_client.get_data() + assert result.count == 3 + + # Get non-empty items + result = await dataset_client.get_data(skip_empty=True) + assert result.count == 2 + assert all(item != {} for item in result.items) + + +async def test_iterate(dataset_client: MemoryDatasetClient) -> None: + """Test that iterate_items yields each item in the dataset in the correct order.""" + # Add some items + items = [{'id': i} for i in range(1, 6)] # 5 items + await dataset_client.push_data(items) + + # Iterate over all items + collected_items = [item async for item in dataset_client.iterate_items()] + + assert len(collected_items) == 5 + assert collected_items[0]['id'] == 1 + assert collected_items[-1]['id'] == 5 + + +async def test_iterate_with_options(dataset_client: MemoryDatasetClient) -> None: + """Test that iterate_items respects offset, limit, and desc parameters.""" + # Add some items + items = [{'id': i} for i in range(1, 11)] # 10 items + await dataset_client.push_data(items) + + # Test with offset and limit + collected_items = [item async for item in dataset_client.iterate_items(offset=3, limit=3)] + + assert len(collected_items) == 3 + assert collected_items[0]['id'] == 4 + assert collected_items[-1]['id'] == 6 + + # Test with descending order + collected_items = [] + async for item in dataset_client.iterate_items(desc=True, limit=3): + collected_items.append(item) + + assert len(collected_items) == 3 + 
assert collected_items[0]['id'] == 10 + assert collected_items[-1]['id'] == 8 + + +async def test_drop(dataset_client: MemoryDatasetClient) -> None: + """Test that drop removes the dataset from cache and resets its state.""" + # Add an item to the dataset + await dataset_client.push_data({'test': 'data'}) + + # Verify the dataset exists in the cache + assert dataset_client.name in _cache_by_name + + # Drop the dataset + await dataset_client.drop() + + # Verify the dataset was removed from the cache + assert dataset_client.name not in _cache_by_name + + # Verify the dataset is empty + assert dataset_client.item_count == 0 + result = await dataset_client.get_data() + assert result.count == 0 + + +async def test_metadata_updates(dataset_client: MemoryDatasetClient) -> None: + """Test that read/write operations properly update accessed_at and modified_at timestamps.""" + # Record initial timestamps + initial_created = dataset_client.created_at + initial_accessed = dataset_client.accessed_at + initial_modified = dataset_client.modified_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates accessed_at + await dataset_client.get_data() + + # Verify timestamps + assert dataset_client.created_at == initial_created + assert dataset_client.accessed_at > initial_accessed + assert dataset_client.modified_at == initial_modified + + accessed_after_get = dataset_client.accessed_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates modified_at + await dataset_client.push_data({'new': 'item'}) + + # Verify timestamps again + assert dataset_client.created_at == initial_created + assert dataset_client.modified_at > initial_modified + assert dataset_client.accessed_at > accessed_after_get diff --git a/tests/unit/storage_clients/_memory/test_memory_kvs_client.py b/tests/unit/storage_clients/_memory/test_memory_kvs_client.py new file mode 100644 index 0000000000..ce8e889573 --- /dev/null +++ b/tests/unit/storage_clients/_memory/test_memory_kvs_client.py @@ -0,0 +1,237 @@ +from __future__ import annotations + +import asyncio +from datetime import datetime +from typing import TYPE_CHECKING, Any + +import pytest + +from crawlee.storage_clients._memory._key_value_store_client import MemoryKeyValueStoreClient, _cache_by_name +from crawlee.storage_clients.models import KeyValueStoreRecordMetadata + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + +pytestmark = pytest.mark.only + + +@pytest.fixture +async def kvs_client() -> AsyncGenerator[MemoryKeyValueStoreClient, None]: + """Fixture that provides a fresh memory key-value store client for each test.""" + # Clear any existing key-value store clients in the cache + _cache_by_name.clear() + + client = await MemoryKeyValueStoreClient.open(name='test_kvs') + yield client + await client.drop() + + +async def test_open_creates_new_store() -> None: + """Test that open() creates a new key-value store with proper metadata and adds it to the cache.""" + client = await MemoryKeyValueStoreClient.open(name='new_kvs') + + # Verify client properties + assert client.id is not None + assert client.name == 'new_kvs' + assert isinstance(client.created_at, datetime) + assert isinstance(client.accessed_at, datetime) + assert isinstance(client.modified_at, datetime) + + # Verify the client was cached + assert 'new_kvs' in _cache_by_name + + +async def test_open_existing_store(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that 
open() loads an existing key-value store with matching properties.""" + # Open the same key-value store again + reopened_client = await MemoryKeyValueStoreClient.open(name=kvs_client.name) + + # Verify client properties + assert kvs_client.id == reopened_client.id + assert kvs_client.name == reopened_client.name + + # Verify clients (python) ids + assert id(kvs_client) == id(reopened_client) + + +async def test_open_with_id_and_name() -> None: + """Test that open() can be used with both id and name parameters.""" + client = await MemoryKeyValueStoreClient.open(id='some-id', name='some-name') + assert client.id == 'some-id' + assert client.name == 'some-name' + + +@pytest.mark.parametrize( + ('key', 'value', 'expected_content_type'), + [ + pytest.param('string_key', 'string value', 'text/plain; charset=utf-8', id='string'), + pytest.param('dict_key', {'name': 'test', 'value': 42}, 'application/json; charset=utf-8', id='dictionary'), + pytest.param('list_key', [1, 2, 3], 'application/json; charset=utf-8', id='list'), + pytest.param('bytes_key', b'binary data', 'application/octet-stream', id='bytes'), + ], +) +async def test_set_get_value( + kvs_client: MemoryKeyValueStoreClient, + key: str, + value: Any, + expected_content_type: str, +) -> None: + """Test storing and retrieving different types of values with correct content types.""" + # Set value + await kvs_client.set_value(key=key, value=value) + + # Get and verify value + record = await kvs_client.get_value(key=key) + assert record is not None + assert record.key == key + assert record.value == value + assert record.content_type == expected_content_type + + +async def test_get_nonexistent_value(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that attempting to get a non-existent key returns None.""" + record = await kvs_client.get_value(key='nonexistent') + assert record is None + + +async def test_set_value_with_explicit_content_type(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that an explicitly provided content type overrides the automatically inferred one.""" + value = 'This could be XML' + content_type = 'application/xml' + + await kvs_client.set_value(key='xml_key', value=value, content_type=content_type) + + record = await kvs_client.get_value(key='xml_key') + assert record is not None + assert record.value == value + assert record.content_type == content_type + + +async def test_delete_value(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that a stored value can be deleted and is no longer retrievable after deletion.""" + # Set a value + await kvs_client.set_value(key='delete_me', value='to be deleted') + + # Verify it exists + record = await kvs_client.get_value(key='delete_me') + assert record is not None + + # Delete it + await kvs_client.delete_value(key='delete_me') + + # Verify it's gone + record = await kvs_client.get_value(key='delete_me') + assert record is None + + +async def test_delete_nonexistent_value(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that attempting to delete a non-existent key is a no-op and doesn't raise errors.""" + # Should not raise an error + await kvs_client.delete_value(key='nonexistent') + + +async def test_iterate_keys(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that all keys can be iterated over and are returned in sorted order with correct metadata.""" + # Set some values + items = { + 'a_key': 'value A', + 'b_key': 'value B', + 'c_key': 'value C', + 'd_key': 'value D', + } + + for key, value in items.items(): + await 
kvs_client.set_value(key=key, value=value) + + # Get all keys + metadata_list = [metadata async for metadata in kvs_client.iterate_keys()] + + # Verify keys are returned in sorted order + assert len(metadata_list) == 4 + assert [m.key for m in metadata_list] == sorted(items.keys()) + assert all(isinstance(m, KeyValueStoreRecordMetadata) for m in metadata_list) + + +async def test_iterate_keys_with_exclusive_start_key(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that exclusive_start_key parameter returns only keys after it alphabetically.""" + # Set some values + for key in ['a_key', 'b_key', 'c_key', 'd_key', 'e_key']: + await kvs_client.set_value(key=key, value=f'value for {key}') + + # Get keys starting after 'b_key' + metadata_list = [metadata async for metadata in kvs_client.iterate_keys(exclusive_start_key='b_key')] + + # Verify only keys after 'b_key' are returned + assert len(metadata_list) == 3 + assert [m.key for m in metadata_list] == ['c_key', 'd_key', 'e_key'] + + +async def test_iterate_keys_with_limit(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that the limit parameter returns only the specified number of keys.""" + # Set some values + for key in ['a_key', 'b_key', 'c_key', 'd_key', 'e_key']: + await kvs_client.set_value(key=key, value=f'value for {key}') + + # Get first 3 keys + metadata_list = [metadata async for metadata in kvs_client.iterate_keys(limit=3)] + + # Verify only the first 3 keys are returned + assert len(metadata_list) == 3 + assert [m.key for m in metadata_list] == ['a_key', 'b_key', 'c_key'] + + +async def test_drop(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that drop removes the store from cache and clears all data.""" + # Add some values to the store + await kvs_client.set_value(key='test', value='data') + + # Verify the store exists in the cache + assert kvs_client.name in _cache_by_name + + # Drop the store + await kvs_client.drop() + + # Verify the store was removed from the cache + assert kvs_client.name not in _cache_by_name + + # Verify the store is empty + record = await kvs_client.get_value(key='test') + assert record is None + + +async def test_get_public_url(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that get_public_url raises NotImplementedError for the memory implementation.""" + with pytest.raises(NotImplementedError): + await kvs_client.get_public_url(key='any-key') + + +async def test_metadata_updates(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that read/write operations properly update accessed_at and modified_at timestamps.""" + # Record initial timestamps + initial_created = kvs_client.created_at + initial_accessed = kvs_client.accessed_at + initial_modified = kvs_client.modified_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates accessed_at + await kvs_client.get_value(key='nonexistent') + + # Verify timestamps + assert kvs_client.created_at == initial_created + assert kvs_client.accessed_at > initial_accessed + assert kvs_client.modified_at == initial_modified + + accessed_after_get = kvs_client.accessed_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates modified_at and accessed_at + await kvs_client.set_value(key='new_key', value='new value') + + # Verify timestamps again + assert kvs_client.created_at == initial_created + assert kvs_client.modified_at > initial_modified + assert kvs_client.accessed_at > accessed_after_get 
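Note: the two new test modules above drive the in-memory clients through the same public surface. Pulled together into one end-to-end sketch (using only the private module paths and method signatures that appear in this patch; nothing below is part of the patch itself):

    import asyncio

    from crawlee.storage_clients._memory._dataset_client import MemoryDatasetClient
    from crawlee.storage_clients._memory._key_value_store_client import MemoryKeyValueStoreClient


    async def main() -> None:
        # Opening the same name twice returns the cached instance.
        dataset = await MemoryDatasetClient.open(name='demo')
        assert dataset is await MemoryDatasetClient.open(name='demo')

        await dataset.push_data([{'id': 1}, {'id': 2}])
        page = await dataset.get_data(desc=True, limit=1)
        assert page.items[0]['id'] == 2

        kvs = await MemoryKeyValueStoreClient.open(name='demo')
        await kvs.set_value(key='greeting', value={'hello': 'world'})
        record = await kvs.get_value(key='greeting')
        assert record is not None
        assert record.content_type == 'application/json; charset=utf-8'

        # drop() clears the stored data and evicts the client from the name cache.
        await dataset.drop()
        await kvs.drop()


    asyncio.run(main())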
diff --git a/tests/unit/storages/test_dataset.py b/tests/unit/storages/test_dataset.py index 2d21eac3b8..f299aee08d 100644 --- a/tests/unit/storages/test_dataset.py +++ b/tests/unit/storages/test_dataset.py @@ -129,7 +129,7 @@ async def test_iterate_items(dataset: Dataset) -> None: idx = 0 await dataset.push_data([{'id': i} for i in range(desired_item_count)]) - async for item in dataset.iterate(): + async for item in dataset.iterate_items(): assert item['id'] == idx idx += 1 From 026cbf9f8a9cc77785833a88155a818a1706ac1a Mon Sep 17 00:00:00 2001 From: Vlada Dusek Date: Mon, 14 Apr 2025 08:59:48 +0200 Subject: [PATCH 09/22] Caching of Dataset and KVS --- .../_file_system/_dataset_client.py | 18 +++--- .../_file_system/_key_value_store_client.py | 18 +++--- .../_memory/_dataset_client.py | 18 +++--- .../_memory/_key_value_store_client.py | 18 +++--- src/crawlee/storages/_dataset.py | 56 +++++++++---------- src/crawlee/storages/_key_value_store.py | 56 +++++++++---------- .../_file_system/test_fs_dataset_client.py | 19 ++++--- .../_file_system/test_fs_kvs_client.py | 20 +++---- .../_memory/test_memory_dataset_client.py | 11 ++-- .../_memory/test_memory_kvs_client.py | 11 ++-- 10 files changed, 120 insertions(+), 125 deletions(-) diff --git a/src/crawlee/storage_clients/_file_system/_dataset_client.py b/src/crawlee/storage_clients/_file_system/_dataset_client.py index b0a665cee7..2381f96cae 100644 --- a/src/crawlee/storage_clients/_file_system/_dataset_client.py +++ b/src/crawlee/storage_clients/_file_system/_dataset_client.py @@ -6,7 +6,7 @@ from datetime import datetime, timezone from logging import getLogger from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar from pydantic import ValidationError from typing_extensions import override @@ -23,9 +23,6 @@ logger = getLogger(__name__) -_cache_by_name = dict[str, 'FileSystemDatasetClient']() -"""A dictionary to cache clients by their names.""" - class FileSystemDatasetClient(DatasetClient): """A file system implementation of the dataset client. @@ -44,6 +41,9 @@ class FileSystemDatasetClient(DatasetClient): _LOCAL_ENTRY_NAME_DIGITS = 9 """Number of digits used for the file names (e.g., 000000019.json).""" + _cache_by_name: ClassVar[dict[str, FileSystemDatasetClient]] = {} + """A dictionary to cache clients by their names.""" + def __init__( self, *, @@ -131,8 +131,8 @@ async def open( name = name or cls._DEFAULT_NAME # Check if the client is already cached by name. - if name in _cache_by_name: - client = _cache_by_name[name] + if name in cls._cache_by_name: + client = cls._cache_by_name[name] await client._update_metadata(update_accessed_at=True) # noqa: SLF001 return client @@ -182,7 +182,7 @@ async def open( await client._update_metadata() # Cache the client by name. - _cache_by_name[name] = client + cls._cache_by_name[name] = client return client @@ -194,8 +194,8 @@ async def drop(self) -> None: await asyncio.to_thread(shutil.rmtree, self.path_to_dataset) # Remove the client from the cache. 
- if self.name in _cache_by_name: - del _cache_by_name[self.name] + if self.name in self.__class__._cache_by_name: # noqa: SLF001 + del self.__class__._cache_by_name[self.name] # noqa: SLF001 @override async def push_data(self, data: list[Any] | dict[str, Any]) -> None: diff --git a/src/crawlee/storage_clients/_file_system/_key_value_store_client.py b/src/crawlee/storage_clients/_file_system/_key_value_store_client.py index aa347d50d4..2c7dc61651 100644 --- a/src/crawlee/storage_clients/_file_system/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_file_system/_key_value_store_client.py @@ -6,7 +6,7 @@ from datetime import datetime, timezone from logging import getLogger from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ClassVar from pydantic import ValidationError from typing_extensions import override @@ -24,9 +24,6 @@ logger = getLogger(__name__) -_cache_by_name = dict[str, 'FileSystemKeyValueStoreClient']() -"""A dictionary to cache clients by their names.""" - class FileSystemKeyValueStoreClient(KeyValueStoreClient): """A file system implementation of the key-value store client. @@ -42,6 +39,9 @@ class FileSystemKeyValueStoreClient(KeyValueStoreClient): _STORAGE_SUBDIR = 'key_value_stores' """The name of the subdirectory where key-value stores are stored.""" + _cache_by_name: ClassVar[dict[str, FileSystemKeyValueStoreClient]] = {} + """A dictionary to cache clients by their names.""" + def __init__( self, *, @@ -122,8 +122,8 @@ async def open( name = name or cls._DEFAULT_NAME # Check if the client is already cached by name. - if name in _cache_by_name: - return _cache_by_name[name] + if name in cls._cache_by_name: + return cls._cache_by_name[name] storage_dir = storage_dir or Path.cwd() kvs_path = storage_dir / cls._STORAGE_SUBDIR / name @@ -169,7 +169,7 @@ async def open( await client._update_metadata() # Cache the client by name. - _cache_by_name[name] = client + cls._cache_by_name[name] = client return client @@ -181,8 +181,8 @@ async def drop(self) -> None: await asyncio.to_thread(shutil.rmtree, self.path_to_kvs) # Remove the client from the cache. - if self.name in _cache_by_name: - del _cache_by_name[self.name] + if self.name in self.__class__._cache_by_name: # noqa: SLF001 + del self.__class__._cache_by_name[self.name] # noqa: SLF001 @override async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: diff --git a/src/crawlee/storage_clients/_memory/_dataset_client.py b/src/crawlee/storage_clients/_memory/_dataset_client.py index 558619c5f0..684c384bd5 100644 --- a/src/crawlee/storage_clients/_memory/_dataset_client.py +++ b/src/crawlee/storage_clients/_memory/_dataset_client.py @@ -2,7 +2,7 @@ from datetime import datetime, timezone from logging import getLogger -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ClassVar from typing_extensions import override @@ -16,9 +16,6 @@ logger = getLogger(__name__) -_cache_by_name = dict[str, 'MemoryDatasetClient']() -"""A dictionary to cache clients by their names.""" - class MemoryDatasetClient(DatasetClient): """A memory implementation of the dataset client. 
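Note: this commit applies the same refactor to every client: the module-level `_cache_by_name` dictionaries are replaced with a `ClassVar` on the client class, and `drop()` evicts the instance from that cache. Reduced to a minimal illustrative sketch (the class name here is made up; only the pattern matches the patch):

    from __future__ import annotations

    from typing import ClassVar


    class CachedClient:
        """Illustrative stand-in for the storage clients touched by this commit."""

        _cache_by_name: ClassVar[dict[str, CachedClient]] = {}
        """A dictionary to cache clients by their names."""

        def __init__(self, name: str) -> None:
            self._name = name

        @classmethod
        async def open(cls, *, name: str) -> CachedClient:
            # Reuse an already-opened instance registered under this name.
            if name in cls._cache_by_name:
                return cls._cache_by_name[name]

            client = cls(name)
            cls._cache_by_name[name] = client
            return client

        async def drop(self) -> None:
            # Dropping also removes the instance from the class-level cache.
            self.__class__._cache_by_name.pop(self._name, None)

Keeping the cache on the class (rather than at module scope) also gives the tests one obvious reset point, SomeClient._cache_by_name.clear(), which the updated fixtures below rely on.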
@@ -31,6 +28,9 @@ class MemoryDatasetClient(DatasetClient): _DEFAULT_NAME = 'default' """The default name for the dataset when no name is provided.""" + _cache_by_name: ClassVar[dict[str, MemoryDatasetClient]] = {} + """A dictionary to cache clients by their names.""" + def __init__( self, *, @@ -102,8 +102,8 @@ async def open( name = name or cls._DEFAULT_NAME # Check if the client is already cached by name. - if name in _cache_by_name: - client = _cache_by_name[name] + if name in cls._cache_by_name: + client = cls._cache_by_name[name] await client._update_metadata(update_accessed_at=True) # noqa: SLF001 return client @@ -120,7 +120,7 @@ async def open( ) # Cache the client by name - _cache_by_name[name] = client + cls._cache_by_name[name] = client return client @@ -130,8 +130,8 @@ async def drop(self) -> None: self._metadata.item_count = 0 # Remove the client from the cache - if self.name in _cache_by_name: - del _cache_by_name[self.name] + if self.name in self.__class__._cache_by_name: + del self.__class__._cache_by_name[self.name] @override async def push_data(self, data: list[Any] | dict[str, Any]) -> None: diff --git a/src/crawlee/storage_clients/_memory/_key_value_store_client.py b/src/crawlee/storage_clients/_memory/_key_value_store_client.py index dcaec9f458..0bb9651ee9 100644 --- a/src/crawlee/storage_clients/_memory/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_memory/_key_value_store_client.py @@ -3,7 +3,7 @@ import sys from datetime import datetime, timezone from logging import getLogger -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ClassVar from typing_extensions import override @@ -18,9 +18,6 @@ logger = getLogger(__name__) -_cache_by_name = dict[str, 'MemoryKeyValueStoreClient']() -"""A dictionary to cache clients by their names.""" - class MemoryKeyValueStoreClient(KeyValueStoreClient): """A memory implementation of the key-value store client. 
@@ -33,6 +30,9 @@ class MemoryKeyValueStoreClient(KeyValueStoreClient): _DEFAULT_NAME = 'default' """The default name for the key-value store when no name is provided.""" + _cache_by_name: ClassVar[dict[str, MemoryKeyValueStoreClient]] = {} + """A dictionary to cache clients by their names.""" + def __init__( self, *, @@ -97,8 +97,8 @@ async def open( name = name or cls._DEFAULT_NAME # Check if the client is already cached by name - if name in _cache_by_name: - client = _cache_by_name[name] + if name in cls._cache_by_name: + client = cls._cache_by_name[name] await client._update_metadata(update_accessed_at=True) # noqa: SLF001 return client @@ -115,7 +115,7 @@ async def open( ) # Cache the client by name - _cache_by_name[name] = client + cls._cache_by_name[name] = client return client @@ -125,8 +125,8 @@ async def drop(self) -> None: self._store.clear() # Remove from cache - if self.name in _cache_by_name: - del _cache_by_name[self.name] + if self.name in self.__class__._cache_by_name: + del self.__class__._cache_by_name[self.name] @override async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: diff --git a/src/crawlee/storages/_dataset.py b/src/crawlee/storages/_dataset.py index 4f5659eae0..7a7c8f4baf 100644 --- a/src/crawlee/storages/_dataset.py +++ b/src/crawlee/storages/_dataset.py @@ -3,7 +3,7 @@ import logging from io import StringIO from pathlib import Path -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, ClassVar, Literal from typing_extensions import override @@ -30,32 +30,6 @@ logger = logging.getLogger(__name__) -# TODO: -# - caching / memoization of Dataset - -# Properties: -# - id -# - name -# - metadata - -# Methods: -# - open -# - drop -# - push_data -# - get_data -# - iterate_items -# - export_to -# - export_to_json -# - export_to_csv - -# Breaking changes: -# - from_storage_object method has been removed - Use the open method with name and/or id instead. -# - get_info -> metadata property -# - storage_object -> metadata property -# - set_metadata method has been removed - Do we want to support it (e.g. for renaming)? -# - write_to_json -> export_to_json -# - write_to_csv -> export_to_csv - @docs_group('Classes') class Dataset(Storage): @@ -90,6 +64,12 @@ class Dataset(Storage): ``` """ + _cache_by_id: ClassVar[dict[str, Dataset]] = {} + """A dictionary to cache datasets by their IDs.""" + + _cache_by_name: ClassVar[dict[str, Dataset]] = {} + """A dictionary to cache datasets by their names.""" + def __init__(self, client: DatasetClient) -> None: """Initialize a new instance. 
@@ -137,6 +117,12 @@ async def open( if id and name: raise ValueError('Only one of "id" or "name" can be specified, not both.') + # Check if dataset is already cached by id or name + if id and id in cls._cache_by_id: + return cls._cache_by_id[id] + if name and name in cls._cache_by_name: + return cls._cache_by_name[name] + configuration = service_locator.get_configuration() if configuration is None else configuration storage_client = service_locator.get_storage_client() if storage_client is None else storage_client purge_on_start = configuration.purge_on_start if purge_on_start is None else purge_on_start @@ -149,10 +135,24 @@ async def open( storage_dir=storage_dir, ) - return cls(client) + dataset = cls(client) + + # Cache the dataset by id and name if available + if dataset.id: + cls._cache_by_id[dataset.id] = dataset + if dataset.name: + cls._cache_by_name[dataset.name] = dataset + + return dataset @override async def drop(self) -> None: + # Remove from cache before dropping + if self.id in self._cache_by_id: + del self._cache_by_id[self.id] + if self.name and self.name in self._cache_by_name: + del self._cache_by_name[self.name] + await self._client.drop() async def push_data(self, data: list[Any] | dict[str, Any]) -> None: diff --git a/src/crawlee/storages/_key_value_store.py b/src/crawlee/storages/_key_value_store.py index 6567adeb8f..5bbe15e08d 100644 --- a/src/crawlee/storages/_key_value_store.py +++ b/src/crawlee/storages/_key_value_store.py @@ -14,38 +14,12 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator - from crawlee._types import JsonSerializable from crawlee.configuration import Configuration from crawlee.storage_clients import StorageClient from crawlee.storage_clients._base import KeyValueStoreClient T = TypeVar('T') -# TODO: -# - caching / memoization of KVS - -# Properties: -# - id -# - name -# - metadata - -# Methods: -# - open -# - drop -# - get_value -# - set_value -# - delete_value (new method) -# - iterate_keys -# - list_keys (new method) -# - get_public_url - -# Breaking changes: -# - from_storage_object method has been removed - Use the open method with name and/or id instead. -# - get_info -> metadata property -# - storage_object -> metadata property -# - set_metadata method has been removed - Do we want to support it (e.g. for renaming)? -# - get_auto_saved_value method has been removed -> It should be managed by the underlying client. -# - persist_autosaved_values method has been removed -> It should be managed by the underlying client. @docs_group('Classes') class KeyValueStore(Storage): @@ -82,9 +56,11 @@ class KeyValueStore(Storage): ``` """ - # Cache for persistent (auto-saved) values - _general_cache: ClassVar[dict[str, dict[str, dict[str, JsonSerializable]]]] = {} - _persist_state_event_started = False + _cache_by_id: ClassVar[dict[str, KeyValueStore]] = {} + """A dictionary to cache key-value stores by their IDs.""" + + _cache_by_name: ClassVar[dict[str, KeyValueStore]] = {} + """A dictionary to cache key-value stores by their names.""" def __init__(self, client: KeyValueStoreClient) -> None: """Initialize a new instance. 
@@ -132,6 +108,12 @@ async def open( if id and name: raise ValueError('Only one of "id" or "name" can be specified, not both.') + # Check if key value store is already cached by id or name + if id and id in cls._cache_by_id: + return cls._cache_by_id[id] + if name and name in cls._cache_by_name: + return cls._cache_by_name[name] + configuration = service_locator.get_configuration() if configuration is None else configuration storage_client = service_locator.get_storage_client() if storage_client is None else storage_client purge_on_start = configuration.purge_on_start if purge_on_start is None else purge_on_start @@ -144,10 +126,24 @@ async def open( storage_dir=storage_dir, ) - return cls(client) + kvs = cls(client) + + # Cache the key value store by id and name if available + if kvs.id: + cls._cache_by_id[kvs.id] = kvs + if kvs.name: + cls._cache_by_name[kvs.name] = kvs + + return kvs @override async def drop(self) -> None: + # Remove from cache before dropping + if self.id in self._cache_by_id: + del self._cache_by_id[self.id] + if self.name and self.name in self._cache_by_name: + del self._cache_by_name[self.name] + await self._client.drop() @overload diff --git a/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py b/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py index 247830d435..0570b4db70 100644 --- a/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py +++ b/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py @@ -21,6 +21,9 @@ @pytest.fixture async def dataset_client(tmp_path: Path) -> AsyncGenerator[FileSystemDatasetClient, None]: """A fixture for a file system dataset client.""" + # Clear any existing dataset clients in the cache + FileSystemDatasetClient._cache_by_name.clear() + client = await FileSystemDatasetClient.open(name='test_dataset', storage_dir=tmp_path) yield client await client.drop() @@ -232,20 +235,18 @@ async def test_iterate_with_options(dataset_client: FileSystemDatasetClient) -> assert collected_items[-1]['id'] == 8 -async def test_drop(tmp_path: Path) -> None: +async def test_drop(dataset_client: FileSystemDatasetClient) -> None: """Test dropping a dataset removes the entire dataset directory from disk.""" - # Create a dataset and add an item - client = await FileSystemDatasetClient.open(name='to_drop', storage_dir=tmp_path) - await client.push_data({'test': 'data'}) + await dataset_client.push_data({'test': 'data'}) - # Verify the dataset directory exists - assert client.path_to_dataset.exists() + assert dataset_client.name in FileSystemDatasetClient._cache_by_name + assert dataset_client.path_to_dataset.exists() # Drop the dataset - await client.drop() + await dataset_client.drop() - # Verify the dataset directory was removed - assert not client.path_to_dataset.exists() + assert dataset_client.name not in FileSystemDatasetClient._cache_by_name + assert not dataset_client.path_to_dataset.exists() async def test_metadata_updates(dataset_client: FileSystemDatasetClient) -> None: diff --git a/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py b/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py index 7f20fbfea1..38e65a16a2 100644 --- a/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py +++ b/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py @@ -20,6 +20,9 @@ @pytest.fixture async def kvs_client(tmp_path: Path) -> AsyncGenerator[FileSystemKeyValueStoreClient, None]: """Fixture that provides a fresh file system key-value store client using a temporary 
directory.""" + # Clear any existing dataset clients in the cache + FileSystemKeyValueStoreClient._cache_by_name.clear() + client = await FileSystemKeyValueStoreClient.open(name='test_kvs', storage_dir=tmp_path) yield client await client.drop() @@ -247,21 +250,18 @@ async def test_iterate_keys_with_exclusive_start_key(kvs_client: FileSystemKeyVa assert 'b-key' not in keys -async def test_drop(tmp_path: Path) -> None: +async def test_drop(kvs_client: FileSystemKeyValueStoreClient) -> None: """Test that drop removes the entire store directory from disk.""" - # Create a store and add a value - client = await FileSystemKeyValueStoreClient.open(name='to_drop', storage_dir=tmp_path) - await client.set_value(key='test', value='test-value') + await kvs_client.set_value(key='test', value='test-value') - # Verify the store directory exists - kvs_path = client.path_to_kvs - assert kvs_path.exists() + assert kvs_client.name in FileSystemKeyValueStoreClient._cache_by_name + assert kvs_client.path_to_kvs.exists() # Drop the store - await client.drop() + await kvs_client.drop() - # Verify the directory was removed - assert not kvs_path.exists() + assert kvs_client.name not in FileSystemKeyValueStoreClient._cache_by_name + assert not kvs_client.path_to_kvs.exists() async def test_metadata_updates(kvs_client: FileSystemKeyValueStoreClient) -> None: diff --git a/tests/unit/storage_clients/_memory/test_memory_dataset_client.py b/tests/unit/storage_clients/_memory/test_memory_dataset_client.py index 0554a9d3cb..6f18c5b6b7 100644 --- a/tests/unit/storage_clients/_memory/test_memory_dataset_client.py +++ b/tests/unit/storage_clients/_memory/test_memory_dataset_client.py @@ -6,7 +6,7 @@ import pytest -from crawlee.storage_clients._memory._dataset_client import MemoryDatasetClient, _cache_by_name +from crawlee.storage_clients._memory._dataset_client import MemoryDatasetClient from crawlee.storage_clients.models import DatasetItemsListPage if TYPE_CHECKING: @@ -19,7 +19,7 @@ async def dataset_client() -> AsyncGenerator[MemoryDatasetClient, None]: """Fixture that provides a fresh memory dataset client for each test.""" # Clear any existing dataset clients in the cache - _cache_by_name.clear() + MemoryDatasetClient._cache_by_name.clear() client = await MemoryDatasetClient.open(name='test_dataset') yield client @@ -39,7 +39,7 @@ async def test_open_creates_new_dataset() -> None: assert isinstance(client.modified_at, datetime) # Verify the client was cached - assert 'new_dataset' in _cache_by_name + assert 'new_dataset' in MemoryDatasetClient._cache_by_name async def test_open_existing_dataset(dataset_client: MemoryDatasetClient) -> None: @@ -226,17 +226,16 @@ async def test_iterate_with_options(dataset_client: MemoryDatasetClient) -> None async def test_drop(dataset_client: MemoryDatasetClient) -> None: """Test that drop removes the dataset from cache and resets its state.""" - # Add an item to the dataset await dataset_client.push_data({'test': 'data'}) # Verify the dataset exists in the cache - assert dataset_client.name in _cache_by_name + assert dataset_client.name in MemoryDatasetClient._cache_by_name # Drop the dataset await dataset_client.drop() # Verify the dataset was removed from the cache - assert dataset_client.name not in _cache_by_name + assert dataset_client.name not in MemoryDatasetClient._cache_by_name # Verify the dataset is empty assert dataset_client.item_count == 0 diff --git a/tests/unit/storage_clients/_memory/test_memory_kvs_client.py 
b/tests/unit/storage_clients/_memory/test_memory_kvs_client.py index ce8e889573..8940764839 100644 --- a/tests/unit/storage_clients/_memory/test_memory_kvs_client.py +++ b/tests/unit/storage_clients/_memory/test_memory_kvs_client.py @@ -6,7 +6,7 @@ import pytest -from crawlee.storage_clients._memory._key_value_store_client import MemoryKeyValueStoreClient, _cache_by_name +from crawlee.storage_clients._memory._key_value_store_client import MemoryKeyValueStoreClient from crawlee.storage_clients.models import KeyValueStoreRecordMetadata if TYPE_CHECKING: @@ -19,13 +19,12 @@ async def kvs_client() -> AsyncGenerator[MemoryKeyValueStoreClient, None]: """Fixture that provides a fresh memory key-value store client for each test.""" # Clear any existing key-value store clients in the cache - _cache_by_name.clear() + MemoryKeyValueStoreClient._cache_by_name.clear() client = await MemoryKeyValueStoreClient.open(name='test_kvs') yield client await client.drop() - async def test_open_creates_new_store() -> None: """Test that open() creates a new key-value store with proper metadata and adds it to the cache.""" client = await MemoryKeyValueStoreClient.open(name='new_kvs') @@ -38,7 +37,7 @@ async def test_open_creates_new_store() -> None: assert isinstance(client.modified_at, datetime) # Verify the client was cached - assert 'new_kvs' in _cache_by_name + assert 'new_kvs' in MemoryKeyValueStoreClient._cache_by_name async def test_open_existing_store(kvs_client: MemoryKeyValueStoreClient) -> None: @@ -186,13 +185,13 @@ async def test_drop(kvs_client: MemoryKeyValueStoreClient) -> None: await kvs_client.set_value(key='test', value='data') # Verify the store exists in the cache - assert kvs_client.name in _cache_by_name + assert kvs_client.name in MemoryKeyValueStoreClient._cache_by_name # Drop the store await kvs_client.drop() # Verify the store was removed from the cache - assert kvs_client.name not in _cache_by_name + assert kvs_client.name not in MemoryKeyValueStoreClient._cache_by_name # Verify the store is empty record = await kvs_client.get_value(key='test') From 8554115dcc233aa7301cf367b1f48d8cd5008286 Mon Sep 17 00:00:00 2001 From: Vlada Dusek Date: Mon, 14 Apr 2025 09:00:00 +0200 Subject: [PATCH 10/22] Storage clients (entrypoints) and their tests --- .../storage_clients/_base/_storage_client.py | 24 +-- .../_file_system/_storage_client.py | 32 ++-- .../_memory/_storage_client.py | 42 +++--- .../_file_system/test_fs_storage_client.py | 142 ++++++++++++++++++ 4 files changed, 191 insertions(+), 49 deletions(-) create mode 100644 tests/unit/storage_clients/_file_system/test_fs_storage_client.py diff --git a/src/crawlee/storage_clients/_base/_storage_client.py b/src/crawlee/storage_clients/_base/_storage_client.py index b12792f202..0d5a67c1fc 100644 --- a/src/crawlee/storage_clients/_base/_storage_client.py +++ b/src/crawlee/storage_clients/_base/_storage_client.py @@ -18,10 +18,10 @@ class StorageClient(ABC): async def open_dataset_client( self, *, - id: str | None, - name: str | None, - purge_on_start: bool, - storage_dir: Path, + id: str | None = None, + name: str | None = None, + purge_on_start: bool = True, + storage_dir: Path | None = None, ) -> DatasetClient: """Open the dataset client.""" @@ -29,10 +29,10 @@ async def open_dataset_client( async def open_key_value_store_client( self, *, - id: str | None, - name: str | None, - purge_on_start: bool, - storage_dir: Path, + id: str | None = None, + name: str | None = None, + purge_on_start: bool = True, + storage_dir: Path | None = None, ) -> 
KeyValueStoreClient: """Open the key-value store client.""" @@ -40,9 +40,9 @@ async def open_key_value_store_client( async def open_request_queue_client( self, *, - id: str | None, - name: str | None, - purge_on_start: bool, - storage_dir: Path, + id: str | None = None, + name: str | None = None, + purge_on_start: bool = True, + storage_dir: Path | None = None, ) -> RequestQueueClient: """Open the request queue client.""" diff --git a/src/crawlee/storage_clients/_file_system/_storage_client.py b/src/crawlee/storage_clients/_file_system/_storage_client.py index 9d3adefb76..5f5dc95abf 100644 --- a/src/crawlee/storage_clients/_file_system/_storage_client.py +++ b/src/crawlee/storage_clients/_file_system/_storage_client.py @@ -21,10 +21,10 @@ class FileSystemStorageClient(StorageClient): async def open_dataset_client( self, *, - id: str | None, - name: str | None, - purge_on_start: bool, - storage_dir: Path, + id: str | None = None, + name: str | None = None, + purge_on_start: bool = True, + storage_dir: Path | None = None, ) -> FileSystemDatasetClient: client = await FileSystemDatasetClient.open(id=id, name=name, storage_dir=storage_dir) @@ -38,10 +38,10 @@ async def open_dataset_client( async def open_key_value_store_client( self, *, - id: str | None, - name: str | None, - purge_on_start: bool, - storage_dir: Path, + id: str | None = None, + name: str | None = None, + purge_on_start: bool = True, + storage_dir: Path | None = None, ) -> FileSystemKeyValueStoreClient: client = await FileSystemKeyValueStoreClient.open(id=id, name=name, storage_dir=storage_dir) @@ -55,15 +55,9 @@ async def open_key_value_store_client( async def open_request_queue_client( self, *, - id: str | None, - name: str | None, - purge_on_start: bool, - storage_dir: Path, + id: str | None = None, + name: str | None = None, + purge_on_start: bool = True, + storage_dir: Path | None = None, ) -> FileSystemRequestQueueClient: - client = await FileSystemRequestQueueClient.open(id=id, name=name, storage_dir=storage_dir) - - if purge_on_start: - await client.drop() - client = await FileSystemRequestQueueClient.open(id=id, name=name, storage_dir=storage_dir) - - return client + pass diff --git a/src/crawlee/storage_clients/_memory/_storage_client.py b/src/crawlee/storage_clients/_memory/_storage_client.py index 4d9d090b38..5ce9b16dd1 100644 --- a/src/crawlee/storage_clients/_memory/_storage_client.py +++ b/src/crawlee/storage_clients/_memory/_storage_client.py @@ -21,37 +21,43 @@ class MemoryStorageClient(StorageClient): async def open_dataset_client( self, *, - id: str | None, - name: str | None, - purge_on_start: bool, - storage_dir: Path, + id: str | None = None, + name: str | None = None, + purge_on_start: bool = True, + storage_dir: Path | None = None ) -> MemoryDatasetClient: - dataset_client = await MemoryDatasetClient.open(id=id, name=name, storage_dir=storage_dir) + client = await MemoryDatasetClient.open(id=id, name=name, storage_dir=storage_dir) if purge_on_start: - await dataset_client.drop() - dataset_client = await MemoryDatasetClient.open(id=id, name=name, storage_dir=storage_dir) + await client.drop() + client = await MemoryDatasetClient.open(id=id, name=name, storage_dir=storage_dir) - return dataset_client + return client @override async def open_key_value_store_client( self, *, - id: str | None, - name: str | None, - purge_on_start: bool, - storage_dir: Path, + id: str | None = None, + name: str | None = None, + purge_on_start: bool = True, + storage_dir: Path | None = None ) -> MemoryKeyValueStoreClient: - 
return MemoryKeyValueStoreClient() + client = await MemoryKeyValueStoreClient.open(id=id, name=name, storage_dir=storage_dir) + + if purge_on_start: + await client.drop() + client = await MemoryKeyValueStoreClient.open(id=id, name=name, storage_dir=storage_dir) + + return client @override async def open_request_queue_client( self, *, - id: str | None, - name: str | None, - purge_on_start: bool, - storage_dir: Path, + id: str | None = None, + name: str | None = None, + purge_on_start: bool = True, + storage_dir: Path | None = None ) -> MemoryRequestQueueClient: - return MemoryRequestQueueClient() + pass diff --git a/tests/unit/storage_clients/_file_system/test_fs_storage_client.py b/tests/unit/storage_clients/_file_system/test_fs_storage_client.py new file mode 100644 index 0000000000..843911de97 --- /dev/null +++ b/tests/unit/storage_clients/_file_system/test_fs_storage_client.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from crawlee.storage_clients._file_system._dataset_client import FileSystemDatasetClient +from crawlee.storage_clients._file_system._key_value_store_client import FileSystemKeyValueStoreClient +from crawlee.storage_clients._file_system._storage_client import FileSystemStorageClient + +if TYPE_CHECKING: + from pathlib import Path + +pytestmark = pytest.mark.only + + +@pytest.fixture +async def client() -> FileSystemStorageClient: + return FileSystemStorageClient() + + +async def test_open_dataset_client(client: FileSystemStorageClient, tmp_path: Path) -> None: + """Test that open_dataset_client creates a dataset client with correct type and properties.""" + dataset_client = await client.open_dataset_client(name='test-dataset', storage_dir=tmp_path) + + # Verify correct client type and properties + assert isinstance(dataset_client, FileSystemDatasetClient) + assert dataset_client.name == 'test-dataset' + + # Verify directory structure was created + assert dataset_client.path_to_dataset.exists() + + +async def test_dataset_client_purge_on_start(client: FileSystemStorageClient, tmp_path: Path) -> None: + """Test that purge_on_start=True clears existing data in the dataset.""" + # Create dataset and add data + dataset_client1 = await client.open_dataset_client( + name='test-purge-dataset', + storage_dir=tmp_path, + purge_on_start=True, + ) + await dataset_client1.push_data({'item': 'initial data'}) + + # Verify data was added + items = await dataset_client1.get_data() + assert len(items.items) == 1 + + # Reopen + dataset_client2 = await client.open_dataset_client( + name='test-purge-dataset', + storage_dir=tmp_path, + purge_on_start=True, + ) + + # Verify data was purged + items = await dataset_client2.get_data() + assert len(items.items) == 0 + + +async def test_dataset_client_no_purge_on_start(client: FileSystemStorageClient, tmp_path: Path) -> None: + """Test that purge_on_start=False keeps existing data in the dataset.""" + # Create dataset and add data + dataset_client1 = await client.open_dataset_client( + name='test-no-purge-dataset', + storage_dir=tmp_path, + purge_on_start=False, + ) + await dataset_client1.push_data({'item': 'preserved data'}) + + # Reopen + dataset_client2 = await client.open_dataset_client( + name='test-no-purge-dataset', + storage_dir=tmp_path, + purge_on_start=False, + ) + + # Verify data was preserved + items = await dataset_client2.get_data() + assert len(items.items) == 1 + assert items.items[0]['item'] == 'preserved data' + + +async def test_open_kvs_client(client: 
FileSystemStorageClient, tmp_path: Path) -> None: + """Test that open_key_value_store_client creates a KVS client with correct type and properties.""" + kvs_client = await client.open_key_value_store_client(name='test-kvs', storage_dir=tmp_path) + + # Verify correct client type and properties + assert isinstance(kvs_client, FileSystemKeyValueStoreClient) + assert kvs_client.name == 'test-kvs' + + # Verify directory structure was created + assert kvs_client.path_to_kvs.exists() + + +async def test_kvs_client_purge_on_start(client: FileSystemStorageClient, tmp_path: Path) -> None: + """Test that purge_on_start=True clears existing data in the key-value store.""" + # Create KVS and add data + kvs_client1 = await client.open_key_value_store_client( + name='test-purge-kvs', + storage_dir=tmp_path, + purge_on_start=True, + ) + await kvs_client1.set_value(key='test-key', value='initial value') + + # Verify value was set + record = await kvs_client1.get_value(key='test-key') + assert record is not None + assert record.value == 'initial value' + + # Reopen + kvs_client2 = await client.open_key_value_store_client( + name='test-purge-kvs', + storage_dir=tmp_path, + purge_on_start=True, + ) + + # Verify value was purged + record = await kvs_client2.get_value(key='test-key') + assert record is None + + +async def test_kvs_client_no_purge_on_start(client: FileSystemStorageClient, tmp_path: Path) -> None: + """Test that purge_on_start=False keeps existing data in the key-value store.""" + # Create KVS and add data + kvs_client1 = await client.open_key_value_store_client( + name='test-no-purge-kvs', + storage_dir=tmp_path, + purge_on_start=False, + ) + await kvs_client1.set_value(key='test-key', value='preserved value') + + # Reopen + kvs_client2 = await client.open_key_value_store_client( + name='test-no-purge-kvs', + storage_dir=tmp_path, + purge_on_start=False, + ) + + # Verify value was preserved + record = await kvs_client2.get_value(key='test-key') + assert record is not None + assert record.value == 'preserved value' From c051f650bfc2e42ad85da10a03b5590ab8e4416c Mon Sep 17 00:00:00 2001 From: Vlada Dusek Date: Mon, 14 Apr 2025 09:00:05 +0200 Subject: [PATCH 11/22] Dataset and KVS tests --- src/crawlee/configuration.py | 18 +- src/crawlee/storages/_dataset.py | 67 ++- tests/unit/storages/test_dataset.py | 482 +++++++++++++++----- tests/unit/storages/test_key_value_store.py | 458 ++++++++++++------- 4 files changed, 679 insertions(+), 346 deletions(-) diff --git a/src/crawlee/configuration.py b/src/crawlee/configuration.py index de22118816..e3ef39f486 100644 --- a/src/crawlee/configuration.py +++ b/src/crawlee/configuration.py @@ -118,21 +118,7 @@ class Configuration(BaseSettings): ) ), ] = True - """Whether to purge the storage on the start. This option is utilized by the `MemoryStorageClient`.""" - - write_metadata: Annotated[bool, Field(alias='crawlee_write_metadata')] = True - """Whether to write the storage metadata. This option is utilized by the `MemoryStorageClient`.""" - - persist_storage: Annotated[ - bool, - Field( - validation_alias=AliasChoices( - 'apify_persist_storage', - 'crawlee_persist_storage', - ) - ), - ] = True - """Whether to persist the storage. This option is utilized by the `MemoryStorageClient`.""" + """Whether to purge the storage on the start. This option is utilized by the storage clients.""" persist_state_interval: Annotated[ timedelta_ms, @@ -239,7 +225,7 @@ class Configuration(BaseSettings): ), ), ] = './storage' - """The path to the storage directory. 
This option is utilized by the `MemoryStorageClient`.""" + """The path to the storage directory. This option is utilized by the storage clients.""" headless: Annotated[ bool, diff --git a/src/crawlee/storages/_dataset.py b/src/crawlee/storages/_dataset.py index 7a7c8f4baf..099b8ffcf4 100644 --- a/src/crawlee/storages/_dataset.py +++ b/src/crawlee/storages/_dataset.py @@ -3,7 +3,7 @@ import logging from io import StringIO from pathlib import Path -from typing import TYPE_CHECKING, ClassVar, Literal +from typing import TYPE_CHECKING, overload from typing_extensions import override @@ -17,7 +17,7 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator - from typing import Any + from typing import Any, ClassVar, Literal from typing_extensions import Unpack @@ -267,12 +267,33 @@ async def iterate_items( ): yield item + @overload + async def export_to( + self, + key: str, + content_type: Literal['json'], + to_key_value_store_id: str | None = None, + to_key_value_store_name: str | None = None, + **kwargs: Unpack[ExportDataJsonKwargs], + ) -> None: ... + + @overload + async def export_to( + self, + key: str, + content_type: Literal['csv'], + to_key_value_store_id: str | None = None, + to_key_value_store_name: str | None = None, + **kwargs: Unpack[ExportDataCsvKwargs], + ) -> None: ... + async def export_to( self, key: str, content_type: Literal['json', 'csv'] = 'json', to_key_value_store_id: str | None = None, to_key_value_store_name: str | None = None, + **kwargs: Any, ) -> None: """Export the entire dataset into a specified file stored under a key in a key-value store. @@ -288,42 +309,16 @@ async def export_to( Specify only one of ID or name. to_key_value_store_name: Name of the key-value store to save the exported file. Specify only one of ID or name. + kwargs: Additional parameters for the export operation, specific to the chosen content type. 
""" + kvs = await KeyValueStore.open(id=to_key_value_store_id, name=to_key_value_store_name) + dst = StringIO() + if content_type == 'csv': - await self.export_to_csv( - key, - to_key_value_store_id, - to_key_value_store_name, - ) + await export_csv_to_stream(self.iterate_items(), dst, **kwargs) + await kvs.set_value(key, dst.getvalue(), 'text/csv') elif content_type == 'json': - await self.export_to_json( - key, - to_key_value_store_id, - to_key_value_store_name, - ) + await export_json_to_stream(self.iterate_items(), dst, **kwargs) + await kvs.set_value(key, dst.getvalue(), 'application/json') else: raise ValueError('Unsupported content type, expecting CSV or JSON') - - async def export_to_json( - self, - key: str, - to_key_value_store_id: str | None = None, - to_key_value_store_name: str | None = None, - **kwargs: Unpack[ExportDataJsonKwargs], - ) -> None: - kvs = await KeyValueStore.open(id=to_key_value_store_id, name=to_key_value_store_name) - dst = StringIO() - await export_json_to_stream(self.iterate_items(), dst, **kwargs) - await kvs.set_value(key, dst.getvalue(), 'application/json') - - async def export_to_csv( - self, - key: str, - to_key_value_store_id: str | None = None, - to_key_value_store_name: str | None = None, - **kwargs: Unpack[ExportDataCsvKwargs], - ) -> None: - kvs = await KeyValueStore.open(id=to_key_value_store_id, name=to_key_value_store_name) - dst = StringIO() - await export_csv_to_stream(self.iterate_items(), dst, **kwargs) - await kvs.set_value(key, dst.getvalue(), 'text/csv') diff --git a/tests/unit/storages/test_dataset.py b/tests/unit/storages/test_dataset.py index f299aee08d..4a1c8a9e23 100644 --- a/tests/unit/storages/test_dataset.py +++ b/tests/unit/storages/test_dataset.py @@ -1,156 +1,398 @@ +# TODO: Update crawlee_storage_dir args once the Pydantic bug is fixed +# https://github.com/apify/crawlee-python/issues/146 + from __future__ import annotations -from datetime import datetime, timezone from typing import TYPE_CHECKING import pytest -from crawlee import service_locator -from crawlee.storage_clients.models import StorageMetadata +from crawlee.configuration import Configuration +from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient from crawlee.storages import Dataset, KeyValueStore if TYPE_CHECKING: from collections.abc import AsyncGenerator + from pathlib import Path + + from crawlee.storage_clients import StorageClient + +pytestmark = pytest.mark.only + + +@pytest.fixture(params=['memory', 'file_system']) +def storage_client(request: pytest.FixtureRequest) -> StorageClient: + """Parameterized fixture to test with different storage clients.""" + if request.param == 'memory': + return MemoryStorageClient() + + return FileSystemStorageClient() @pytest.fixture -async def dataset() -> AsyncGenerator[Dataset, None]: - dataset = await Dataset.open() +def configuration(tmp_path: Path) -> Configuration: + """Provide a configuration with a temporary storage directory.""" + return Configuration(crawlee_storage_dir=str(tmp_path)) # type: ignore[call-arg] + + +@pytest.fixture +async def dataset( + storage_client: StorageClient, + configuration: Configuration, + tmp_path: Path, +) -> AsyncGenerator[Dataset, None]: + """Fixture that provides a dataset instance for each test.""" + Dataset._cache_by_id.clear() + Dataset._cache_by_name.clear() + + dataset = await Dataset.open( + name='test_dataset', + storage_dir=tmp_path, + storage_client=storage_client, + configuration=configuration, + ) + yield dataset await dataset.drop() -async def 
test_open() -> None: - default_dataset = await Dataset.open() - default_dataset_by_id = await Dataset.open(id=default_dataset.id) +async def test_open_creates_new_dataset( + storage_client: StorageClient, + configuration: Configuration, + tmp_path: Path, +) -> None: + """Test that open() creates a new dataset with proper metadata.""" + dataset = await Dataset.open( + name='new_dataset', + storage_dir=tmp_path, + storage_client=storage_client, + configuration=configuration, + ) - assert default_dataset is default_dataset_by_id + # Verify dataset properties + assert dataset.id is not None + assert dataset.name == 'new_dataset' + assert dataset.metadata.item_count == 0 - dataset_name = 'dummy-name' - named_dataset = await Dataset.open(name=dataset_name) - assert default_dataset is not named_dataset + await dataset.drop() - with pytest.raises(RuntimeError, match='Dataset with id "nonexistent-id" does not exist!'): - await Dataset.open(id='nonexistent-id') - # Test that when you try to open a dataset by ID and you use a name of an existing dataset, - # it doesn't work - with pytest.raises(RuntimeError, match='Dataset with id "dummy-name" does not exist!'): - await Dataset.open(id='dummy-name') +async def test_open_existing_dataset( + dataset: Dataset, + storage_client: StorageClient, + tmp_path: Path, +) -> None: + """Test that open() loads an existing dataset correctly.""" + # Open the same dataset again + reopened_dataset = await Dataset.open( + name=dataset.name, + storage_dir=tmp_path, + storage_client=storage_client, + ) + # Verify dataset properties + assert dataset.id == reopened_dataset.id + assert dataset.name == reopened_dataset.name + assert dataset.metadata.item_count == reopened_dataset.metadata.item_count + + # Verify they are the same object (from cache) + assert id(dataset) == id(reopened_dataset) + + +async def test_open_with_id_and_name( + storage_client: StorageClient, + configuration: Configuration, + tmp_path: Path, +) -> None: + """Test that open() raises an error when both id and name are provided.""" + with pytest.raises(ValueError, match='Only one of "id" or "name" can be specified'): + await Dataset.open( + id='some-id', + name='some-name', + storage_dir=tmp_path, + storage_client=storage_client, + configuration=configuration, + ) + + +async def test_push_data_single_item(dataset: Dataset) -> None: + """Test pushing a single item to the dataset.""" + item = {'key': 'value', 'number': 42} + await dataset.push_data(item) + + # Verify item was stored + result = await dataset.get_data() + assert result.count == 1 + assert result.items[0] == item + + +async def test_push_data_multiple_items(dataset: Dataset) -> None: + """Test pushing multiple items to the dataset.""" + items = [ + {'id': 1, 'name': 'Item 1'}, + {'id': 2, 'name': 'Item 2'}, + {'id': 3, 'name': 'Item 3'}, + ] + await dataset.push_data(items) + + # Verify items were stored + result = await dataset.get_data() + assert result.count == 3 + assert result.items == items + + +async def test_get_data_empty_dataset(dataset: Dataset) -> None: + """Test getting data from an empty dataset returns empty results.""" + result = await dataset.get_data() + + assert result.count == 0 + assert result.total == 0 + assert result.items == [] + + +async def test_get_data_with_pagination(dataset: Dataset) -> None: + """Test getting data with offset and limit parameters for pagination.""" + # Add some items + items = [{'id': i} for i in range(1, 11)] # 10 items + await dataset.push_data(items) + + # Test offset + result = await 
dataset.get_data(offset=3) + assert result.count == 7 + assert result.offset == 3 + assert result.items[0]['id'] == 4 + + # Test limit + result = await dataset.get_data(limit=5) + assert result.count == 5 + assert result.limit == 5 + assert result.items[-1]['id'] == 5 + + # Test both offset and limit + result = await dataset.get_data(offset=2, limit=3) + assert result.count == 3 + assert result.offset == 2 + assert result.limit == 3 + assert result.items[0]['id'] == 3 + assert result.items[-1]['id'] == 5 + + +async def test_get_data_descending_order(dataset: Dataset) -> None: + """Test getting data in descending order reverses the item order.""" + # Add some items + items = [{'id': i} for i in range(1, 6)] # 5 items + await dataset.push_data(items) + + # Get items in descending order + result = await dataset.get_data(desc=True) + + assert result.desc is True + assert result.items[0]['id'] == 5 + assert result.items[-1]['id'] == 1 + + +async def test_get_data_skip_empty(dataset: Dataset) -> None: + """Test getting data with skip_empty option filters out empty items.""" + # Add some items including an empty one + items = [ + {'id': 1, 'name': 'Item 1'}, + {}, # Empty item + {'id': 3, 'name': 'Item 3'}, + ] + await dataset.push_data(items) + + # Get all items + result = await dataset.get_data() + assert result.count == 3 + + # Get non-empty items + result = await dataset.get_data(skip_empty=True) + assert result.count == 2 + assert all(item != {} for item in result.items) -async def test_consistency_accross_two_clients() -> None: - dataset = await Dataset.open(name='my-dataset') - await dataset.push_data({'key': 'value'}) - dataset_by_id = await Dataset.open(id=dataset.id) - await dataset_by_id.push_data({'key2': 'value2'}) +async def test_iterate_items(dataset: Dataset) -> None: + """Test iterating over dataset items yields each item in the correct order.""" + # Add some items + items = [{'id': i} for i in range(1, 6)] # 5 items + await dataset.push_data(items) + + # Iterate over all items + collected_items = [item async for item in dataset.iterate_items()] + + assert len(collected_items) == 5 + assert collected_items[0]['id'] == 1 + assert collected_items[-1]['id'] == 5 + + +async def test_iterate_items_with_options(dataset: Dataset) -> None: + """Test iterating with offset, limit and desc parameters.""" + # Add some items + items = [{'id': i} for i in range(1, 11)] # 10 items + await dataset.push_data(items) + + # Test with offset and limit + collected_items = [item async for item in dataset.iterate_items(offset=3, limit=3)] + + assert len(collected_items) == 3 + assert collected_items[0]['id'] == 4 + assert collected_items[-1]['id'] == 6 + + # Test with descending order + collected_items = [] + async for item in dataset.iterate_items(desc=True, limit=3): + collected_items.append(item) + + assert len(collected_items) == 3 + assert collected_items[0]['id'] == 10 + assert collected_items[-1]['id'] == 8 + + +async def test_drop( + storage_client: StorageClient, + configuration: Configuration, + tmp_path: Path, +) -> None: + """Test dropping a dataset removes it from cache and clears its data.""" + dataset = await Dataset.open( + name='drop_test', + storage_dir=tmp_path, + storage_client=storage_client, + configuration=configuration, + ) + + # Add some data + await dataset.push_data({'test': 'data'}) - assert (await dataset.get_data()).items == [{'key': 'value'}, {'key2': 'value2'}] - assert (await dataset_by_id.get_data()).items == [{'key': 'value'}, {'key2': 'value2'}] + # Verify dataset 
exists in cache + assert dataset.id in Dataset._cache_by_id + if dataset.name: + assert dataset.name in Dataset._cache_by_name + # Drop the dataset await dataset.drop() - with pytest.raises(RuntimeError, match='Storage with provided ID was not found'): - await dataset_by_id.drop() - - -async def test_same_references() -> None: - dataset1 = await Dataset.open() - dataset2 = await Dataset.open() - assert dataset1 is dataset2 - - dataset_name = 'non-default' - dataset_named1 = await Dataset.open(name=dataset_name) - dataset_named2 = await Dataset.open(name=dataset_name) - assert dataset_named1 is dataset_named2 - - -async def test_drop() -> None: - dataset1 = await Dataset.open() - await dataset1.drop() - dataset2 = await Dataset.open() - assert dataset1 is not dataset2 - - -async def test_export(dataset: Dataset) -> None: - expected_csv = 'id,test\r\n0,test\r\n1,test\r\n2,test\r\n' - expected_json = [{'id': 0, 'test': 'test'}, {'id': 1, 'test': 'test'}, {'id': 2, 'test': 'test'}] - desired_item_count = 3 - await dataset.push_data([{'id': i, 'test': 'test'} for i in range(desired_item_count)]) - await dataset.export_to(key='dataset-csv', content_type='csv') - await dataset.export_to(key='dataset-json', content_type='json') - kvs = await KeyValueStore.open() - dataset_csv = await kvs.get_value(key='dataset-csv') - dataset_json = await kvs.get_value(key='dataset-json') - assert dataset_csv == expected_csv - assert dataset_json == expected_json - - -async def test_push_data(dataset: Dataset) -> None: - desired_item_count = 2000 - await dataset.push_data([{'id': i} for i in range(desired_item_count)]) - dataset_info = await dataset.get_info() - assert dataset_info is not None - assert dataset_info.item_count == desired_item_count - list_page = await dataset.get_data(limit=desired_item_count) - assert list_page.items[0]['id'] == 0 - assert list_page.items[-1]['id'] == desired_item_count - 1 - - -async def test_push_data_empty(dataset: Dataset) -> None: - await dataset.push_data([]) - dataset_info = await dataset.get_info() - assert dataset_info is not None - assert dataset_info.item_count == 0 - - -async def test_push_data_singular(dataset: Dataset) -> None: - await dataset.push_data({'id': 1}) - dataset_info = await dataset.get_info() - assert dataset_info is not None - assert dataset_info.item_count == 1 - list_page = await dataset.get_data() - assert list_page.items[0]['id'] == 1 - - -async def test_get_data(dataset: Dataset) -> None: # We don't test everything, that's done in memory storage tests - desired_item_count = 3 - await dataset.push_data([{'id': i} for i in range(desired_item_count)]) - list_page = await dataset.get_data() - assert list_page.count == desired_item_count - assert list_page.desc is False - assert list_page.offset == 0 - assert list_page.items[0]['id'] == 0 - assert list_page.items[-1]['id'] == desired_item_count - 1 + # Verify dataset was removed from cache + assert dataset.id not in Dataset._cache_by_id + if dataset.name: + assert dataset.name not in Dataset._cache_by_name + + # Verify dataset is empty (by creating a new one with the same name) + new_dataset = await Dataset.open( + name='drop_test', + storage_dir=tmp_path, + storage_client=storage_client, + configuration=configuration, + ) -async def test_iterate_items(dataset: Dataset) -> None: - desired_item_count = 3 - idx = 0 - await dataset.push_data([{'id': i} for i in range(desired_item_count)]) + result = await new_dataset.get_data() + assert result.count == 0 + await new_dataset.drop() + + +async def 
test_export_to_json( + dataset: Dataset, + storage_client: StorageClient, + tmp_path: Path, +) -> None: + """Test exporting dataset to JSON format.""" + # Create a key-value store for export + kvs = await KeyValueStore.open( + name='export_kvs', + storage_dir=tmp_path, + storage_client=storage_client, + ) + + # Add some items to the dataset + items = [ + {'id': 1, 'name': 'Item 1'}, + {'id': 2, 'name': 'Item 2'}, + {'id': 3, 'name': 'Item 3'}, + ] + await dataset.push_data(items) + + # Export to JSON + await dataset.export_to( + key='dataset_export.json', + content_type='json', + to_key_value_store_name='export_kvs', + ) - async for item in dataset.iterate_items(): - assert item['id'] == idx - idx += 1 + # Retrieve the exported file + record = await kvs.get_value(key='dataset_export.json') + assert record is not None + + # Verify content has all the items + assert '"id": 1' in record + assert '"id": 2' in record + assert '"id": 3' in record + + await kvs.drop() + + +async def test_export_to_csv( + dataset: Dataset, + storage_client: StorageClient, + tmp_path: Path, +) -> None: + """Test exporting dataset to CSV format.""" + # Create a key-value store for export + kvs = await KeyValueStore.open( + name='export_kvs', + storage_dir=tmp_path, + storage_client=storage_client, + ) + + # Add some items to the dataset + items = [ + {'id': 1, 'name': 'Item 1'}, + {'id': 2, 'name': 'Item 2'}, + {'id': 3, 'name': 'Item 3'}, + ] + await dataset.push_data(items) + + # Export to CSV + await dataset.export_to( + key='dataset_export.csv', + content_type='csv', + to_key_value_store_name='export_kvs', + ) - assert idx == desired_item_count + # Retrieve the exported file + record = await kvs.get_value(key='dataset_export.csv') + assert record is not None + # Verify content has all the items + assert 'id,name' in record + assert '1,Item 1' in record + assert '2,Item 2' in record + assert '3,Item 3' in record -async def test_from_storage_object() -> None: - storage_client = service_locator.get_storage_client() + await kvs.drop() - storage_object = StorageMetadata( - id='dummy-id', - name='dummy-name', - accessed_at=datetime.now(timezone.utc), - created_at=datetime.now(timezone.utc), - modified_at=datetime.now(timezone.utc), - extra_attribute='extra', - ) - dataset = Dataset.from_storage_object(storage_client, storage_object) +async def test_export_to_invalid_content_type(dataset: Dataset) -> None: + """Test exporting dataset with invalid content type raises error.""" + with pytest.raises(ValueError, match='Unsupported content type'): + await dataset.export_to( + key='invalid_export', + content_type='invalid', # type: ignore[call-overload] # Intentionally invalid content type + ) + + +async def test_large_dataset(dataset: Dataset) -> None: + """Test handling a large dataset with many items.""" + items = [{'id': i, 'value': f'value-{i}'} for i in range(100)] + await dataset.push_data(items) + + # Test that all items are retrieved + result = await dataset.get_data(limit=None) + assert result.count == 100 + assert result.total == 100 - assert dataset.id == storage_object.id - assert dataset.name == storage_object.name - assert dataset.storage_object == storage_object - assert storage_object.model_extra.get('extra_attribute') == 'extra' # type: ignore[union-attr] + # Test pagination with large datasets + result = await dataset.get_data(offset=50, limit=25) + assert result.count == 25 + assert result.offset == 50 + assert result.items[0]['id'] == 50 + assert result.items[-1]['id'] == 74 diff --git 
a/tests/unit/storages/test_key_value_store.py b/tests/unit/storages/test_key_value_store.py index dc82c412c2..6312009f81 100644 --- a/tests/unit/storages/test_key_value_store.py +++ b/tests/unit/storages/test_key_value_store.py @@ -1,229 +1,339 @@ +# TODO: Update crawlee_storage_dir args once the Pydantic bug is fixed +# https://github.com/apify/crawlee-python/issues/146 + from __future__ import annotations -import asyncio -from datetime import datetime, timedelta, timezone -from itertools import chain, repeat -from typing import TYPE_CHECKING, cast -from unittest.mock import patch -from urllib.parse import urlparse +import json +from typing import TYPE_CHECKING import pytest -from crawlee import service_locator -from crawlee.events import EventManager -from crawlee.storage_clients.models import StorageMetadata +from crawlee.configuration import Configuration +from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient from crawlee.storages import KeyValueStore if TYPE_CHECKING: from collections.abc import AsyncGenerator + from pathlib import Path - from crawlee._types import JsonSerializable - - -@pytest.fixture -async def mock_event_manager() -> AsyncGenerator[EventManager, None]: - async with EventManager(persist_state_interval=timedelta(milliseconds=50)) as event_manager: - with patch('crawlee.service_locator.get_event_manager', return_value=event_manager): - yield event_manager - - -async def test_open() -> None: - default_key_value_store = await KeyValueStore.open() - default_key_value_store_by_id = await KeyValueStore.open(id=default_key_value_store.id) - - assert default_key_value_store is default_key_value_store_by_id + from crawlee.storage_clients import StorageClient - key_value_store_name = 'dummy-name' - named_key_value_store = await KeyValueStore.open(name=key_value_store_name) - assert default_key_value_store is not named_key_value_store +pytestmark = pytest.mark.only - with pytest.raises(RuntimeError, match='KeyValueStore with id "nonexistent-id" does not exist!'): - await KeyValueStore.open(id='nonexistent-id') - # Test that when you try to open a key-value store by ID and you use a name of an existing key-value store, - # it doesn't work - with pytest.raises(RuntimeError, match='KeyValueStore with id "dummy-name" does not exist!'): - await KeyValueStore.open(id='dummy-name') +@pytest.fixture(params=['memory', 'file_system']) +def storage_client(request: pytest.FixtureRequest) -> StorageClient: + """Parameterized fixture to test with different storage clients.""" + if request.param == 'memory': + return MemoryStorageClient() + return FileSystemStorageClient() -async def test_open_save_storage_object() -> None: - default_key_value_store = await KeyValueStore.open() - assert default_key_value_store.storage_object is not None - assert default_key_value_store.storage_object.id == default_key_value_store.id - - -async def test_consistency_accross_two_clients() -> None: - kvs = await KeyValueStore.open(name='my-kvs') - await kvs.set_value('key', 'value') - - kvs_by_id = await KeyValueStore.open(id=kvs.id) - await kvs_by_id.set_value('key2', 'value2') +@pytest.fixture +def configuration(tmp_path: Path) -> Configuration: + """Provide a configuration with a temporary storage directory.""" + return Configuration(crawlee_storage_dir=str(tmp_path)) # type: ignore[call-arg] - assert (await kvs.get_value('key')) == 'value' - assert (await kvs.get_value('key2')) == 'value2' - assert (await kvs_by_id.get_value('key')) == 'value' - assert (await 
kvs_by_id.get_value('key2')) == 'value2' +@pytest.fixture +async def kvs( + storage_client: StorageClient, + configuration: Configuration, + tmp_path: Path, +) -> AsyncGenerator[KeyValueStore, None]: + """Fixture that provides a key-value store instance for each test.""" + KeyValueStore._cache_by_id.clear() + KeyValueStore._cache_by_name.clear() + + kvs = await KeyValueStore.open( + name='test_kvs', + storage_dir=tmp_path, + storage_client=storage_client, + configuration=configuration, + ) + yield kvs await kvs.drop() - with pytest.raises(RuntimeError, match='Storage with provided ID was not found'): - await kvs_by_id.drop() - -async def test_same_references() -> None: - kvs1 = await KeyValueStore.open() - kvs2 = await KeyValueStore.open() - assert kvs1 is kvs2 - kvs_name = 'non-default' - kvs_named1 = await KeyValueStore.open(name=kvs_name) - kvs_named2 = await KeyValueStore.open(name=kvs_name) - assert kvs_named1 is kvs_named2 - - -async def test_drop() -> None: - kvs1 = await KeyValueStore.open() - await kvs1.drop() - kvs2 = await KeyValueStore.open() - assert kvs1 is not kvs2 - - -async def test_get_set_value(key_value_store: KeyValueStore) -> None: - await key_value_store.set_value('test-str', 'string') - await key_value_store.set_value('test-int', 123) - await key_value_store.set_value('test-dict', {'abc': '123'}) - str_value = await key_value_store.get_value('test-str') - int_value = await key_value_store.get_value('test-int') - dict_value = await key_value_store.get_value('test-dict') - non_existent_value = await key_value_store.get_value('test-non-existent') - assert str_value == 'string' - assert int_value == 123 - assert dict_value['abc'] == '123' - assert non_existent_value is None - - -async def test_for_each_key(key_value_store: KeyValueStore) -> None: - keys = [item.key async for item in key_value_store.iterate_keys()] - assert len(keys) == 0 +async def test_open_creates_new_kvs( + storage_client: StorageClient, + configuration: Configuration, + tmp_path: Path, +) -> None: + """Test that open() creates a new key-value store with proper metadata.""" + kvs = await KeyValueStore.open( + name='new_kvs', + storage_dir=tmp_path, + storage_client=storage_client, + configuration=configuration, + ) - for i in range(2001): - await key_value_store.set_value(str(i).zfill(4), i) - index = 0 - async for item in key_value_store.iterate_keys(): - assert item.key == str(index).zfill(4) - index += 1 - assert index == 2001 + # Verify key-value store properties + assert kvs.id is not None + assert kvs.name == 'new_kvs' + await kvs.drop() -async def test_static_get_set_value(key_value_store: KeyValueStore) -> None: - await key_value_store.set_value('test-static', 'static') - value = await key_value_store.get_value('test-static') - assert value == 'static' +async def test_open_existing_kvs( + kvs: KeyValueStore, + storage_client: StorageClient, + tmp_path: Path, +) -> None: + """Test that open() loads an existing key-value store correctly.""" + # Open the same key-value store again + reopened_kvs = await KeyValueStore.open( + name=kvs.name, + storage_dir=tmp_path, + storage_client=storage_client, + ) -async def test_get_public_url_raises_for_non_existing_key(key_value_store: KeyValueStore) -> None: - with pytest.raises(ValueError, match='was not found'): - await key_value_store.get_public_url('i-do-not-exist') + # Verify key-value store properties + assert kvs.id == reopened_kvs.id + assert kvs.name == reopened_kvs.name + # Verify they are the same object (from cache) + assert id(kvs) == 
id(reopened_kvs) -async def test_get_public_url(key_value_store: KeyValueStore) -> None: - await key_value_store.set_value('test-static', 'static') - public_url = await key_value_store.get_public_url('test-static') - url = urlparse(public_url) - path = url.netloc if url.netloc else url.path +async def test_open_with_id_and_name( + storage_client: StorageClient, + configuration: Configuration, + tmp_path: Path, +) -> None: + """Test that open() raises an error when both id and name are provided.""" + with pytest.raises(ValueError, match='Only one of "id" or "name" can be specified'): + await KeyValueStore.open( + id='some-id', + name='some-name', + storage_dir=tmp_path, + storage_client=storage_client, + configuration=configuration, + ) - with open(path) as f: - content = await asyncio.to_thread(f.read) - assert content == 'static' +async def test_set_get_value(kvs: KeyValueStore) -> None: + """Test setting and getting a value from the key-value store.""" + # Set a value + test_key = 'test-key' + test_value = {'data': 'value', 'number': 42} + await kvs.set_value(test_key, test_value) -async def test_get_auto_saved_value_default_value(key_value_store: KeyValueStore) -> None: - default_value: dict[str, JsonSerializable] = {'hello': 'world'} - value = await key_value_store.get_auto_saved_value('state', default_value) - assert value == default_value + # Get the value + result = await kvs.get_value(test_key) + assert result == test_value -async def test_get_auto_saved_value_cache_value(key_value_store: KeyValueStore) -> None: - default_value: dict[str, JsonSerializable] = {'hello': 'world'} - key_name = 'state' +async def test_get_value_nonexistent(kvs: KeyValueStore) -> None: + """Test getting a nonexistent value returns None.""" + result = await kvs.get_value('nonexistent-key') + assert result is None - value = await key_value_store.get_auto_saved_value(key_name, default_value) - value['hello'] = 'new_world' - value_one = await key_value_store.get_auto_saved_value(key_name) - assert value_one == {'hello': 'new_world'} - value_one['hello'] = ['new_world'] - value_two = await key_value_store.get_auto_saved_value(key_name) - assert value_two == {'hello': ['new_world']} +async def test_get_value_with_default(kvs: KeyValueStore) -> None: + """Test getting a nonexistent value with a default value.""" + default_value = {'default': True} + result = await kvs.get_value('nonexistent-key', default_value=default_value) + assert result == default_value -async def test_get_auto_saved_value_auto_save(key_value_store: KeyValueStore, mock_event_manager: EventManager) -> None: # noqa: ARG001 - # This is not a realtime system and timing constrains can be hard to enforce. - # For the test to avoid flakiness it needs some time tolerance. 
- autosave_deadline_time = 1 - autosave_check_period = 0.01 +async def test_set_value_with_content_type(kvs: KeyValueStore) -> None: + """Test setting a value with a specific content type.""" + test_key = 'test-json' + test_value = {'data': 'value', 'items': [1, 2, 3]} + await kvs.set_value(test_key, test_value, content_type='application/json') - async def autosaved_within_deadline(key: str, expected_value: dict[str, str]) -> bool: - """Check if the `key_value_store` of `key` has expected value within `autosave_deadline_time` seconds.""" - deadline = datetime.now(tz=timezone.utc) + timedelta(seconds=autosave_deadline_time) - while datetime.now(tz=timezone.utc) < deadline: - await asyncio.sleep(autosave_check_period) - if await key_value_store.get_value(key) == expected_value: - return True - return False + # Verify the value is retrievable + result = await kvs.get_value(test_key) + assert result == test_value - default_value: dict[str, JsonSerializable] = {'hello': 'world'} - key_name = 'state' - value = await key_value_store.get_auto_saved_value(key_name, default_value) - assert await autosaved_within_deadline(key=key_name, expected_value={'hello': 'world'}) - value['hello'] = 'new_world' - assert await autosaved_within_deadline(key=key_name, expected_value={'hello': 'new_world'}) +async def test_delete_value(kvs: KeyValueStore) -> None: + """Test deleting a value from the key-value store.""" + # Set a value first + test_key = 'delete-me' + test_value = 'value to delete' + await kvs.set_value(test_key, test_value) + # Verify value exists + assert await kvs.get_value(test_key) == test_value -async def test_get_auto_saved_value_auto_save_race_conditions(key_value_store: KeyValueStore) -> None: - """Two parallel functions increment global variable obtained by `get_auto_saved_value`. + # Delete the value + await kvs.delete_value(test_key) - Result should be incremented by 2. - Method `get_auto_saved_value` must be implemented in a way that prevents race conditions in such scenario. - Test creates situation where first `get_auto_saved_value` call to kvs gets delayed. 
Such situation can happen - and unless handled, it can cause race condition in getting the state value.""" - await key_value_store.set_value('state', {'counter': 0}) + # Verify value is gone + assert await kvs.get_value(test_key) is None - sleep_time_iterator = chain(iter([0.5]), repeat(0)) - async def delayed_get_value(key: str, default_value: None) -> None: - await asyncio.sleep(next(sleep_time_iterator)) - return await KeyValueStore.get_value(key_value_store, key=key, default_value=default_value) +async def test_list_keys_empty_kvs(kvs: KeyValueStore) -> None: + """Test listing keys from an empty key-value store.""" + keys = await kvs.list_keys() + assert len(keys) == 0 - async def increment_counter() -> None: - state = cast('dict[str, int]', await key_value_store.get_auto_saved_value('state')) - state['counter'] += 1 - with patch.object(key_value_store, 'get_value', delayed_get_value): - tasks = [asyncio.create_task(increment_counter()), asyncio.create_task(increment_counter())] - await asyncio.gather(*tasks) +async def test_list_keys(kvs: KeyValueStore) -> None: + """Test listing keys from a key-value store with items.""" + # Add some items + await kvs.set_value('key1', 'value1') + await kvs.set_value('key2', 'value2') + await kvs.set_value('key3', 'value3') + + # List keys + keys = await kvs.list_keys() + + # Verify keys + assert len(keys) == 3 + key_names = [k.key for k in keys] + assert 'key1' in key_names + assert 'key2' in key_names + assert 'key3' in key_names + + +async def test_list_keys_with_limit(kvs: KeyValueStore) -> None: + """Test listing keys with a limit parameter.""" + # Add some items + for i in range(10): + await kvs.set_value(f'key{i}', f'value{i}') + + # List with limit + keys = await kvs.list_keys(limit=5) + assert len(keys) == 5 + + +async def test_list_keys_with_exclusive_start_key(kvs: KeyValueStore) -> None: + """Test listing keys with an exclusive start key.""" + # Add some items in a known order + await kvs.set_value('key1', 'value1') + await kvs.set_value('key2', 'value2') + await kvs.set_value('key3', 'value3') + await kvs.set_value('key4', 'value4') + await kvs.set_value('key5', 'value5') + + # Get all keys first to determine their order + all_keys = await kvs.list_keys() + all_key_names = [k.key for k in all_keys] + + if len(all_key_names) >= 3: + # Start from the second key + start_key = all_key_names[1] + keys = await kvs.list_keys(exclusive_start_key=start_key) + + # We should get all keys after the start key + expected_count = len(all_key_names) - all_key_names.index(start_key) - 1 + assert len(keys) == expected_count + + # First key should be the one after start_key + first_returned_key = keys[0].key + assert first_returned_key != start_key + assert all_key_names.index(first_returned_key) > all_key_names.index(start_key) + + +async def test_iterate_keys(kvs: KeyValueStore) -> None: + """Test iterating over keys in the key-value store.""" + # Add some items + await kvs.set_value('key1', 'value1') + await kvs.set_value('key2', 'value2') + await kvs.set_value('key3', 'value3') + + collected_keys = [key async for key in kvs.iterate_keys()] + + # Verify iteration result + assert len(collected_keys) == 3 + key_names = [k.key for k in collected_keys] + assert 'key1' in key_names + assert 'key2' in key_names + assert 'key3' in key_names + + +async def test_iterate_keys_with_limit(kvs: KeyValueStore) -> None: + """Test iterating over keys with a limit parameter.""" + # Add some items + for i in range(10): + await kvs.set_value(f'key{i}', f'value{i}') + + 
collected_keys = [key async for key in kvs.iterate_keys(limit=5)] + + # Verify iteration result + assert len(collected_keys) == 5 + + +async def test_drop( + storage_client: StorageClient, + configuration: Configuration, + tmp_path: Path, +) -> None: + """Test dropping a key-value store removes it from cache and clears its data.""" + kvs = await KeyValueStore.open( + name='drop_test', + storage_dir=tmp_path, + storage_client=storage_client, + configuration=configuration, + ) - assert (await key_value_store.get_auto_saved_value('state'))['counter'] == 2 + # Add some data + await kvs.set_value('test', 'data') + # Verify key-value store exists in cache + assert kvs.id in KeyValueStore._cache_by_id + if kvs.name: + assert kvs.name in KeyValueStore._cache_by_name -async def test_from_storage_object() -> None: - storage_client = service_locator.get_storage_client() + # Drop the key-value store + await kvs.drop() - storage_object = StorageMetadata( - id='dummy-id', - name='dummy-name', - accessed_at=datetime.now(timezone.utc), - created_at=datetime.now(timezone.utc), - modified_at=datetime.now(timezone.utc), - extra_attribute='extra', + # Verify key-value store was removed from cache + assert kvs.id not in KeyValueStore._cache_by_id + if kvs.name: + assert kvs.name not in KeyValueStore._cache_by_name + + # Verify key-value store is empty (by creating a new one with the same name) + new_kvs = await KeyValueStore.open( + name='drop_test', + storage_dir=tmp_path, + storage_client=storage_client, + configuration=configuration, ) - key_value_store = KeyValueStore.from_storage_object(storage_client, storage_object) - - assert key_value_store.id == storage_object.id - assert key_value_store.name == storage_object.name - assert key_value_store.storage_object == storage_object - assert storage_object.model_extra.get('extra_attribute') == 'extra' # type: ignore[union-attr] + # Attempt to get a previously stored value + result = await new_kvs.get_value('test') + assert result is None + await new_kvs.drop() + + +async def test_complex_data_types(kvs: KeyValueStore) -> None: + """Test storing and retrieving complex data types.""" + # Test nested dictionaries + nested_dict = { + 'level1': { + 'level2': { + 'level3': 'deep value', + 'numbers': [1, 2, 3], + }, + }, + 'array': [{'a': 1}, {'b': 2}], + } + await kvs.set_value('nested', nested_dict) + result = await kvs.get_value('nested') + assert result == nested_dict + + # Test lists + test_list = [1, 'string', True, None, {'key': 'value'}] + await kvs.set_value('list', test_list) + result = await kvs.get_value('list') + assert result == test_list + + +async def test_string_data(kvs: KeyValueStore) -> None: + """Test storing and retrieving string data.""" + # Plain string + await kvs.set_value('string', 'simple string') + result = await kvs.get_value('string') + assert result == 'simple string' + + # JSON string + json_string = json.dumps({'key': 'value'}) + await kvs.set_value('json_string', json_string) + result = await kvs.get_value('json_string') + assert result == json_string From 55abd88c815d478e6be459954aa8e9b3ec82cd88 Mon Sep 17 00:00:00 2001 From: Vlada Dusek Date: Mon, 14 Apr 2025 19:25:18 +0200 Subject: [PATCH 12/22] Init of request queue and its clients --- .../storage_clients/_base/_dataset_client.py | 32 +-- .../_base/_key_value_store_client.py | 27 +-- .../_base/_request_queue_client.py | 108 +++------- .../storage_clients/_base/_storage_client.py | 6 +- .../storage_clients/_file_system/__init__.py | 10 +- .../_file_system/_dataset_client.py | 44 
+--- .../_file_system/_key_value_store_client.py | 37 +--- .../_file_system/_request_queue_client.py | 202 +++++++++++++++++- .../_file_system/_storage_client.py | 8 +- .../storage_clients/_memory/__init__.py | 10 +- .../_memory/_dataset_client.py | 35 +-- .../_memory/_key_value_store_client.py | 44 ++-- .../_memory/_request_queue_client.py | 134 +++++++++++- src/crawlee/storage_clients/models.py | 13 +- src/crawlee/storages/_dataset.py | 16 +- src/crawlee/storages/_key_value_store.py | 14 +- src/crawlee/storages/_request_queue.py | 64 ++---- .../_file_system/test_fs_dataset_client.py | 52 ++--- .../_file_system/test_fs_kvs_client.py | 44 ++-- .../_file_system/test_fs_storage_client.py | 12 +- .../_memory/test_memory_dataset_client.py | 56 ++--- .../_memory/test_memory_kvs_client.py | 46 ++-- 22 files changed, 564 insertions(+), 450 deletions(-) diff --git a/src/crawlee/storage_clients/_base/_dataset_client.py b/src/crawlee/storage_clients/_base/_dataset_client.py index f9086f4d9c..dc20ce89f3 100644 --- a/src/crawlee/storage_clients/_base/_dataset_client.py +++ b/src/crawlee/storage_clients/_base/_dataset_client.py @@ -7,11 +7,10 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator - from datetime import datetime from pathlib import Path from typing import Any - from crawlee.storage_clients.models import DatasetItemsListPage + from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata @docs_group('Abstract classes') @@ -30,33 +29,8 @@ class DatasetClient(ABC): @property @abstractmethod - def id(self) -> str: - """The ID of the dataet, a unique identifier, typically a UUID or similar value.""" - - @property - @abstractmethod - def name(self) -> str | None: - """The optional human-readable name of the dataset.""" - - @property - @abstractmethod - def created_at(self) -> datetime: - """Timestamp when the dataset was first created, remains unchanged.""" - - @property - @abstractmethod - def accessed_at(self) -> datetime: - """Timestamp of last access to the dataset, updated on read or write operations.""" - - @property - @abstractmethod - def modified_at(self) -> datetime: - """Timestamp of last modification of the dataset, updated when new data are added.""" - - @property - @abstractmethod - def item_count(self) -> int: - """Total count of data items stored in the dataset.""" + def metadata(self) -> DatasetMetadata: + """The metadata of the dataset.""" @classmethod @abstractmethod diff --git a/src/crawlee/storage_clients/_base/_key_value_store_client.py b/src/crawlee/storage_clients/_base/_key_value_store_client.py index 50b7175745..7eb9160d16 100644 --- a/src/crawlee/storage_clients/_base/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_base/_key_value_store_client.py @@ -7,10 +7,9 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator - from datetime import datetime from pathlib import Path - from crawlee.storage_clients.models import KeyValueStoreRecord, KeyValueStoreRecordMetadata + from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecord, KeyValueStoreRecordMetadata @docs_group('Abstract classes') @@ -29,28 +28,8 @@ class KeyValueStoreClient(ABC): @property @abstractmethod - def id(self) -> str: - """The unique identifier of the key-value store (typically a UUID).""" - - @property - @abstractmethod - def name(self) -> str | None: - """The optional human-readable name for the KVS.""" - - @property - @abstractmethod - def created_at(self) -> datetime: - """Timestamp when the KVS was first created, remains 
unchanged.""" - - @property - @abstractmethod - def accessed_at(self) -> datetime: - """Timestamp of last access to the KVS, updated on read or write operations.""" - - @property - @abstractmethod - def modified_at(self) -> datetime: - """Timestamp of last modification of the KVS, updated when new data are added, updated or deleted.""" + def metadata(self) -> KeyValueStoreMetadata: + """The metadata of the key-value store.""" @classmethod @abstractmethod diff --git a/src/crawlee/storage_clients/_base/_request_queue_client.py b/src/crawlee/storage_clients/_base/_request_queue_client.py index 0d7c2ddb45..8b0b11ef0a 100644 --- a/src/crawlee/storage_clients/_base/_request_queue_client.py +++ b/src/crawlee/storage_clients/_base/_request_queue_client.py @@ -1,22 +1,21 @@ from __future__ import annotations from abc import ABC, abstractmethod -from datetime import datetime from typing import TYPE_CHECKING from crawlee._utils.docs import docs_group if TYPE_CHECKING: from collections.abc import Sequence - from datetime import datetime + from pathlib import Path from crawlee.storage_clients.models import ( BatchRequestsOperationResponse, ProcessedRequest, ProlongRequestLockResponse, Request, - RequestQueueHead, RequestQueueHeadWithLocks, + RequestQueueMetadata, ) @@ -30,58 +29,29 @@ class RequestQueueClient(ABC): @property @abstractmethod - def id(self) -> str: - """The ID of the dataset.""" + def metadata(self) -> RequestQueueMetadata: + """The metadata of the request queue.""" - @property - @abstractmethod - def name(self) -> str | None: - """The name of the dataset.""" - - @property - @abstractmethod - def created_at(self) -> datetime: - """The time at which the dataset was created.""" - - @property - @abstractmethod - def accessed_at(self) -> datetime: - """The time at which the dataset was last accessed.""" - - @property - @abstractmethod - def modified_at(self) -> datetime: - """The time at which the dataset was last modified.""" - - @property + @classmethod @abstractmethod - def had_multiple_clients(self) -> bool: - """TODO.""" - - @property - @abstractmethod - def handled_request_count(self) -> int: - """TODO.""" - - @property - @abstractmethod - def pending_request_count(self) -> int: - """TODO.""" - - @property - @abstractmethod - def stats(self) -> dict: - """TODO.""" + async def open( + cls, + *, + id: str | None = None, + name: str | None = None, + storage_dir: Path | None = None, + ) -> RequestQueueClient: + """Open a request queue client. - @property - @abstractmethod - def total_request_count(self) -> int: - """TODO.""" + Args: + id: ID of the queue to open. If not provided, a new queue will be created with a random ID. + name: Name of the queue to open. If not provided, the queue will be unnamed. + purge_on_start: If True, the queue will be purged before opening. + storage_dir: Directory to store the queue data in. If not provided, uses the default storage directory. - @property - @abstractmethod - def resource_directory(self) -> str: - """TODO.""" + Returns: + A request queue client. + """ @abstractmethod async def drop(self) -> None: @@ -90,17 +60,6 @@ async def drop(self) -> None: The backend method for the `RequestQueue.drop` call. """ - @abstractmethod - async def list_head(self, *, limit: int | None = None) -> RequestQueueHead: - """Retrieve a given number of requests from the beginning of the queue. - - Args: - limit: How many requests to retrieve. - - Returns: - The desired number of requests from the beginning of the queue. 
- """ - @abstractmethod async def list_and_lock_head(self, *, lock_secs: int, limit: int | None = None) -> RequestQueueHeadWithLocks: """Fetch and lock a specified number of requests from the start of the queue. @@ -117,33 +76,16 @@ async def list_and_lock_head(self, *, lock_secs: int, limit: int | None = None) """ @abstractmethod - async def add_request( - self, - request: Request, - *, - forefront: bool = False, - ) -> ProcessedRequest: - """Add a request to the queue. - - Args: - request: The request to add to the queue. - forefront: Whether to add the request to the head or the end of the queue. - - Returns: - Request queue operation information. - """ - - @abstractmethod - async def batch_add_requests( + async def add_requests_batch( self, requests: Sequence[Request], *, forefront: bool = False, ) -> BatchRequestsOperationResponse: - """Add a batch of requests to the queue. + """Add a requests to the queue in batches. Args: - requests: The requests to add to the queue. + requests: The batch of requests to add to the queue. forefront: Whether to add the requests to the head or the end of the queue. Returns: @@ -187,7 +129,7 @@ async def delete_request(self, request_id: str) -> None: """ @abstractmethod - async def batch_delete_requests(self, requests: list[Request]) -> BatchRequestsOperationResponse: + async def delete_requests_batch(self, requests: list[Request]) -> BatchRequestsOperationResponse: """Delete given requests from the queue. Args: diff --git a/src/crawlee/storage_clients/_base/_storage_client.py b/src/crawlee/storage_clients/_base/_storage_client.py index 0d5a67c1fc..f85d5a5bb5 100644 --- a/src/crawlee/storage_clients/_base/_storage_client.py +++ b/src/crawlee/storage_clients/_base/_storage_client.py @@ -23,7 +23,7 @@ async def open_dataset_client( purge_on_start: bool = True, storage_dir: Path | None = None, ) -> DatasetClient: - """Open the dataset client.""" + """Open a dataset client.""" @abstractmethod async def open_key_value_store_client( @@ -34,7 +34,7 @@ async def open_key_value_store_client( purge_on_start: bool = True, storage_dir: Path | None = None, ) -> KeyValueStoreClient: - """Open the key-value store client.""" + """Open a key-value store client.""" @abstractmethod async def open_request_queue_client( @@ -45,4 +45,4 @@ async def open_request_queue_client( purge_on_start: bool = True, storage_dir: Path | None = None, ) -> RequestQueueClient: - """Open the request queue client.""" + """Open a request queue client.""" diff --git a/src/crawlee/storage_clients/_file_system/__init__.py b/src/crawlee/storage_clients/_file_system/__init__.py index bac1291176..2169896d86 100644 --- a/src/crawlee/storage_clients/_file_system/__init__.py +++ b/src/crawlee/storage_clients/_file_system/__init__.py @@ -1,3 +1,11 @@ +from ._dataset_client import FileSystemDatasetClient +from ._key_value_store_client import FileSystemKeyValueStoreClient +from ._request_queue_client import FileSystemRequestQueueClient from ._storage_client import FileSystemStorageClient -__all__ = ['FileSystemStorageClient'] +__all__ = [ + 'FileSystemDatasetClient', + 'FileSystemKeyValueStoreClient', + 'FileSystemRequestQueueClient', + 'FileSystemStorageClient', +] diff --git a/src/crawlee/storage_clients/_file_system/_dataset_client.py b/src/crawlee/storage_clients/_file_system/_dataset_client.py index 2381f96cae..1bd27beef6 100644 --- a/src/crawlee/storage_clients/_file_system/_dataset_client.py +++ b/src/crawlee/storage_clients/_file_system/_dataset_client.py @@ -76,38 +76,13 @@ def __init__( @override 
@property - def id(self) -> str: - return self._metadata.id - - @override - @property - def name(self) -> str: - return self._metadata.name - - @override - @property - def created_at(self) -> datetime: - return self._metadata.created_at - - @override - @property - def accessed_at(self) -> datetime: - return self._metadata.accessed_at - - @override - @property - def modified_at(self) -> datetime: - return self._metadata.modified_at - - @override - @property - def item_count(self) -> int: - return self._metadata.item_count + def metadata(self) -> DatasetMetadata: + return self._metadata @property def path_to_dataset(self) -> Path: """The full path to the dataset directory.""" - return self._storage_dir / self._STORAGE_SUBDIR / self.name + return self._storage_dir / self._STORAGE_SUBDIR / self.metadata.name @property def path_to_metadata(self) -> Path: @@ -170,12 +145,13 @@ async def open( # Otherwise, create a new dataset client. else: + now = datetime.now(timezone.utc) client = cls( id=crypto_random_object_id(), name=name, - created_at=datetime.now(timezone.utc), - accessed_at=datetime.now(timezone.utc), - modified_at=datetime.now(timezone.utc), + created_at=now, + accessed_at=now, + modified_at=now, item_count=0, storage_dir=storage_dir, ) @@ -194,12 +170,12 @@ async def drop(self) -> None: await asyncio.to_thread(shutil.rmtree, self.path_to_dataset) # Remove the client from the cache. - if self.name in self.__class__._cache_by_name: # noqa: SLF001 - del self.__class__._cache_by_name[self.name] # noqa: SLF001 + if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 + del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 @override async def push_data(self, data: list[Any] | dict[str, Any]) -> None: - new_item_count = self.item_count + new_item_count = self.metadata.item_count # If data is a list, push each item individually. if isinstance(data, list): diff --git a/src/crawlee/storage_clients/_file_system/_key_value_store_client.py b/src/crawlee/storage_clients/_file_system/_key_value_store_client.py index 2c7dc61651..79c3d7102d 100644 --- a/src/crawlee/storage_clients/_file_system/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_file_system/_key_value_store_client.py @@ -72,33 +72,13 @@ def __init__( @override @property - def id(self) -> str: - return self._metadata.id - - @override - @property - def name(self) -> str: - return self._metadata.name - - @override - @property - def created_at(self) -> datetime: - return self._metadata.created_at - - @override - @property - def accessed_at(self) -> datetime: - return self._metadata.accessed_at - - @override - @property - def modified_at(self) -> datetime: - return self._metadata.modified_at + def metadata(self) -> KeyValueStoreMetadata: + return self._metadata @property def path_to_kvs(self) -> Path: """The full path to the key-value store directory.""" - return self._storage_dir / self._STORAGE_SUBDIR / self.name + return self._storage_dir / self._STORAGE_SUBDIR / self.metadata.name @property def path_to_metadata(self) -> Path: @@ -158,12 +138,13 @@ async def open( # Otherwise, create a new key-value store client. 
else: + now = datetime.now(timezone.utc) client = cls( id=crypto_random_object_id(), name=name, - created_at=datetime.now(timezone.utc), - accessed_at=datetime.now(timezone.utc), - modified_at=datetime.now(timezone.utc), + created_at=now, + accessed_at=now, + modified_at=now, storage_dir=storage_dir, ) await client._update_metadata() @@ -181,8 +162,8 @@ async def drop(self) -> None: await asyncio.to_thread(shutil.rmtree, self.path_to_kvs) # Remove the client from the cache. - if self.name in self.__class__._cache_by_name: # noqa: SLF001 - del self.__class__._cache_by_name[self.name] # noqa: SLF001 + if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 + del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 @override async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: diff --git a/src/crawlee/storage_clients/_file_system/_request_queue_client.py b/src/crawlee/storage_clients/_file_system/_request_queue_client.py index f8a6bfe88e..fd9866e17f 100644 --- a/src/crawlee/storage_clients/_file_system/_request_queue_client.py +++ b/src/crawlee/storage_clients/_file_system/_request_queue_client.py @@ -1,11 +1,211 @@ from __future__ import annotations +import asyncio +import json +import shutil +from datetime import datetime, timezone from logging import getLogger +from pathlib import Path +from typing import ClassVar +from pydantic import ValidationError +from typing_extensions import override + +from crawlee._utils.crypto import crypto_random_object_id from crawlee.storage_clients._base import RequestQueueClient +from crawlee.storage_clients.models import RequestQueueMetadata + +from ._utils import METADATA_FILENAME, json_dumps logger = getLogger(__name__) class FileSystemRequestQueueClient(RequestQueueClient): - pass + """A file system implementation of the request queue client. + + This client persists requests to the file system, making it suitable for scenarios where data needs + to survive process restarts. Each request is stored as a separate file, allowing for proper request + handling and tracking across crawler runs. + """ + + _DEFAULT_NAME = 'default' + """The default name for the unnamed request queue.""" + + _STORAGE_SUBDIR = 'request_queues' + """The name of the subdirectory where request queues are stored.""" + + _cache_by_name: ClassVar[dict[str, FileSystemRequestQueueClient]] = {} + """A dictionary to cache clients by their names.""" + + def __init__( + self, + *, + id: str, + name: str, + created_at: datetime, + accessed_at: datetime, + modified_at: datetime, + had_multiple_clients: bool, + handled_request_count: int, + pending_request_count: int, + stats: dict, + total_request_count: int, + storage_dir: Path, + ) -> None: + """Initialize a new instance. + + Preferably use the `FileSystemRequestQueueClient.open` class method to create a new instance. 
+ """ + self._metadata = RequestQueueMetadata( + id=id, + name=name, + created_at=created_at, + accessed_at=accessed_at, + modified_at=modified_at, + had_multiple_clients=had_multiple_clients, + handled_request_count=handled_request_count, + pending_request_count=pending_request_count, + stats=stats, + total_request_count=total_request_count, + ) + + self._storage_dir = storage_dir + + # Internal attributes + self._lock = asyncio.Lock() + """A lock to ensure that only one file operation is performed at a time.""" + + @override + @property + def metadata(self) -> RequestQueueMetadata: + return self._metadata + + @property + def path_to_rq(self) -> Path: + """The full path to the request queue directory.""" + return self._storage_dir / self._STORAGE_SUBDIR / self.metadata.name + + @property + def path_to_metadata(self) -> Path: + """The full path to the request queue metadata file.""" + return self.path_to_rq / METADATA_FILENAME + + @override + @classmethod + async def open( + cls, + *, + id: str | None = None, + name: str | None = None, + storage_dir: Path | None = None, + ) -> FileSystemRequestQueueClient: + if id: + raise ValueError( + 'Opening a dataset by "id" is not supported for file system storage client, use "name" instead.' + ) + + name = name or cls._DEFAULT_NAME + + # Check if the client is already cached by name. + if name in cls._cache_by_name: + client = cls._cache_by_name[name] + await client._update_metadata(update_accessed_at=True) # noqa: SLF001 + return client + + storage_dir = storage_dir or Path.cwd() + rq_path = storage_dir / cls._STORAGE_SUBDIR / name + metadata_path = rq_path / METADATA_FILENAME + + # If the RQ directory exists, reconstruct the client from the metadata file. + if rq_path.exists(): + # If metadata file is missing, raise an error. + if not metadata_path.exists(): + raise ValueError(f'Metadata file not found for RQ "{name}"') + + file = await asyncio.to_thread(open, metadata_path) + try: + file_content = json.load(file) + finally: + await asyncio.to_thread(file.close) + try: + metadata = RequestQueueMetadata(**file_content) + except ValidationError as exc: + raise ValueError(f'Invalid metadata file for RQ "{name}"') from exc + + client = cls( + id=metadata.id, + name=name, + created_at=metadata.created_at, + accessed_at=metadata.accessed_at, + modified_at=metadata.modified_at, + had_multiple_clients=metadata.had_multiple_clients, + handled_request_count=metadata.handled_request_count, + pending_request_count=metadata.pending_request_count, + stats=metadata.stats, + total_request_count=metadata.total_request_count, + storage_dir=storage_dir, + ) + + await client._update_metadata(update_accessed_at=True) + + # Otherwise, create a new dataset client. + else: + now = datetime.now(timezone.utc) + client = cls( + id=crypto_random_object_id(), + name=name, + created_at=now, + accessed_at=now, + modified_at=now, + had_multiple_clients=False, + handled_request_count=0, + pending_request_count=0, + stats={}, + total_request_count=0, + storage_dir=storage_dir, + ) + await client._update_metadata() + + # Cache the client by name. + cls._cache_by_name[name] = client + + return client + + @override + async def drop(self) -> None: + # If the client directory exists, remove it recursively. + if self.path_to_rq.exists(): + async with self._lock: + await asyncio.to_thread(shutil.rmtree, self.path_to_rq) + + # Remove the client from the cache. 
+ if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 + del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 + + # TODO: other methods + + async def _update_metadata( + self, + *, + update_accessed_at: bool = False, + update_modified_at: bool = False, + ) -> None: + """Update the dataset metadata file with current information. + + Args: + update_accessed_at: If True, update the `accessed_at` timestamp to the current time. + update_modified_at: If True, update the `modified_at` timestamp to the current time. + """ + now = datetime.now(timezone.utc) + + if update_accessed_at: + self._metadata.accessed_at = now + if update_modified_at: + self._metadata.modified_at = now + + # Ensure the parent directory for the metadata file exists. + await asyncio.to_thread(self.path_to_metadata.parent.mkdir, parents=True, exist_ok=True) + + # Dump the serialized metadata to the file. + data = await json_dumps(self._metadata.model_dump()) + await asyncio.to_thread(self.path_to_metadata.write_text, data, encoding='utf-8') diff --git a/src/crawlee/storage_clients/_file_system/_storage_client.py b/src/crawlee/storage_clients/_file_system/_storage_client.py index 5f5dc95abf..8679e67c2f 100644 --- a/src/crawlee/storage_clients/_file_system/_storage_client.py +++ b/src/crawlee/storage_clients/_file_system/_storage_client.py @@ -60,4 +60,10 @@ async def open_request_queue_client( purge_on_start: bool = True, storage_dir: Path | None = None, ) -> FileSystemRequestQueueClient: - pass + client = await FileSystemRequestQueueClient.open(id=id, name=name, storage_dir=storage_dir) + + if purge_on_start: + await client.drop() + client = await FileSystemRequestQueueClient.open(id=id, name=name, storage_dir=storage_dir) + + return client diff --git a/src/crawlee/storage_clients/_memory/__init__.py b/src/crawlee/storage_clients/_memory/__init__.py index 0d117a8a6c..3746907b4f 100644 --- a/src/crawlee/storage_clients/_memory/__init__.py +++ b/src/crawlee/storage_clients/_memory/__init__.py @@ -1,3 +1,11 @@ +from ._dataset_client import MemoryDatasetClient +from ._key_value_store_client import MemoryKeyValueStoreClient +from ._request_queue_client import MemoryRequestQueueClient from ._storage_client import MemoryStorageClient -__all__ = ['MemoryStorageClient'] +__all__ = [ + 'MemoryDatasetClient', + 'MemoryKeyValueStoreClient', + 'MemoryRequestQueueClient', + 'MemoryStorageClient', +] diff --git a/src/crawlee/storage_clients/_memory/_dataset_client.py b/src/crawlee/storage_clients/_memory/_dataset_client.py index 684c384bd5..f067d6dcd0 100644 --- a/src/crawlee/storage_clients/_memory/_dataset_client.py +++ b/src/crawlee/storage_clients/_memory/_dataset_client.py @@ -59,33 +59,8 @@ def __init__( @override @property - def id(self) -> str: - return self._metadata.id - - @override - @property - def name(self) -> str: - return self._metadata.name - - @override - @property - def created_at(self) -> datetime: - return self._metadata.created_at - - @override - @property - def accessed_at(self) -> datetime: - return self._metadata.accessed_at - - @override - @property - def modified_at(self) -> datetime: - return self._metadata.modified_at - - @override - @property - def item_count(self) -> int: - return self._metadata.item_count + def metadata(self) -> DatasetMetadata: + return self._metadata @override @classmethod @@ -130,12 +105,12 @@ async def drop(self) -> None: self._metadata.item_count = 0 # Remove the client from the cache - if self.name in self.__class__._cache_by_name: - del 
self.__class__._cache_by_name[self.name] + if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 + del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 @override async def push_data(self, data: list[Any] | dict[str, Any]) -> None: - new_item_count = self.item_count + new_item_count = self.metadata.item_count if isinstance(data, list): for item in data: diff --git a/src/crawlee/storage_clients/_memory/_key_value_store_client.py b/src/crawlee/storage_clients/_memory/_key_value_store_client.py index 0bb9651ee9..2240734930 100644 --- a/src/crawlee/storage_clients/_memory/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_memory/_key_value_store_client.py @@ -55,32 +55,12 @@ def __init__( ) # Dictionary to hold key-value records with metadata - self._store = dict[str, KeyValueStoreRecord]() + self._records = dict[str, KeyValueStoreRecord]() @override @property - def id(self) -> str: - return self._metadata.id - - @override - @property - def name(self) -> str: - return self._metadata.name - - @override - @property - def created_at(self) -> datetime: - return self._metadata.created_at - - @override - @property - def accessed_at(self) -> datetime: - return self._metadata.accessed_at - - @override - @property - def modified_at(self) -> datetime: - return self._metadata.modified_at + def metadata(self) -> KeyValueStoreMetadata: + return self._metadata @override @classmethod @@ -122,18 +102,18 @@ async def open( @override async def drop(self) -> None: # Clear all data - self._store.clear() + self._records.clear() # Remove from cache - if self.name in self.__class__._cache_by_name: - del self.__class__._cache_by_name[self.name] + if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 + del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 @override async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: await self._update_metadata(update_accessed_at=True) # Return None if key doesn't exist - return self._store.get(key, None) + return self._records.get(key, None) @override async def set_value(self, *, key: str, value: Any, content_type: str | None = None) -> None: @@ -148,14 +128,14 @@ async def set_value(self, *, key: str, value: Any, content_type: str | None = No size=size, ) - self._store[key] = record + self._records[key] = record await self._update_metadata(update_accessed_at=True, update_modified_at=True) @override async def delete_value(self, *, key: str) -> None: - if key in self._store: - del self._store[key] + if key in self._records: + del self._records[key] await self._update_metadata(update_accessed_at=True, update_modified_at=True) @override @@ -168,7 +148,7 @@ async def iterate_keys( await self._update_metadata(update_accessed_at=True) # Get all keys, sorted alphabetically - keys = sorted(self._store.keys()) + keys = sorted(self._records.keys()) # Apply exclusive_start_key filter if provided if exclusive_start_key is not None: @@ -180,7 +160,7 @@ async def iterate_keys( # Yield metadata for each key for key in keys: - record = self._store[key] + record = self._records[key] yield KeyValueStoreRecordMetadata( key=key, content_type=record.content_type, diff --git a/src/crawlee/storage_clients/_memory/_request_queue_client.py b/src/crawlee/storage_clients/_memory/_request_queue_client.py index d31c0602a0..293fa5b88c 100644 --- a/src/crawlee/storage_clients/_memory/_request_queue_client.py +++ b/src/crawlee/storage_clients/_memory/_request_queue_client.py @@ -1,11 +1,143 @@ from __future__ import 
annotations
 
+from datetime import datetime, timezone
 from logging import getLogger
+from typing import TYPE_CHECKING, ClassVar
 
+from typing_extensions import override
+
+from crawlee import Request
+from crawlee._utils.crypto import crypto_random_object_id
 from crawlee.storage_clients._base import RequestQueueClient
+from crawlee.storage_clients.models import RequestQueueMetadata
+
+if TYPE_CHECKING:
+    from pathlib import Path
 
 logger = getLogger(__name__)
 
 
 class MemoryRequestQueueClient(RequestQueueClient):
-    pass
+    """A memory implementation of the request queue client.
+
+    This client stores requests in memory using a list. No data is persisted, which means
+    all requests are lost when the process terminates. This implementation is mainly useful
+    for testing and development purposes where persistence is not required.
+    """
+
+    _DEFAULT_NAME = 'default'
+    """The default name for the request queue when no name is provided."""
+
+    _cache_by_name: ClassVar[dict[str, MemoryRequestQueueClient]] = {}
+    """A dictionary to cache clients by their names."""
+
+    def __init__(
+        self,
+        *,
+        id: str,
+        name: str,
+        created_at: datetime,
+        accessed_at: datetime,
+        modified_at: datetime,
+        had_multiple_clients: bool,
+        handled_request_count: int,
+        pending_request_count: int,
+        stats: dict,
+        total_request_count: int,
+    ) -> None:
+        """Initialize a new instance.
+
+        Preferably use the `MemoryRequestQueueClient.open` class method to create a new instance.
+        """
+        self._metadata = RequestQueueMetadata(
+            id=id,
+            name=name,
+            created_at=created_at,
+            accessed_at=accessed_at,
+            modified_at=modified_at,
+            had_multiple_clients=had_multiple_clients,
+            handled_request_count=handled_request_count,
+            pending_request_count=pending_request_count,
+            stats=stats,
+            total_request_count=total_request_count,
+        )
+
+        # List to hold RQ items
+        self._records = list[Request]()
+
+    @override
+    @property
+    def metadata(self) -> RequestQueueMetadata:
+        return self._metadata
+
+    @override
+    @classmethod
+    async def open(
+        cls,
+        *,
+        id: str | None = None,
+        name: str | None = None,
+        storage_dir: Path | None = None,
+    ) -> MemoryRequestQueueClient:
+        if storage_dir is not None:
+            logger.warning('The `storage_dir` argument is not used in the memory request queue client.')
+
+        name = name or cls._DEFAULT_NAME
+
+        # Check if the client is already cached by name
+        if name in cls._cache_by_name:
+            client = cls._cache_by_name[name]
+            await client._update_metadata(update_accessed_at=True)  # noqa: SLF001
+            return client
+
+        # If specific id is provided, use it; otherwise, generate a new one
+        id = id or crypto_random_object_id()
+        now = datetime.now(timezone.utc)
+
+        client = cls(
+            id=id,
+            name=name,
+            created_at=now,
+            accessed_at=now,
+            modified_at=now,
+            had_multiple_clients=False,
+            handled_request_count=0,
+            pending_request_count=0,
+            stats={},
+            total_request_count=0,
+        )
+
+        # Cache the client by name
+        cls._cache_by_name[name] = client
+
+        return client
+
+    @override
+    async def drop(self) -> None:
+        # Clear all data
+        self._records.clear()
+
+        # Remove from cache
+        if self.metadata.name in self.__class__._cache_by_name:  # noqa: SLF001
+            del self.__class__._cache_by_name[self.metadata.name]  # noqa: SLF001
+
+    # TODO: other methods
+
+    async def _update_metadata(
+        self,
+        *,
+        update_accessed_at: bool = False,
+        update_modified_at: bool = False,
+    ) -> None:
+        """Update the request queue metadata with current information.
+ + Args: + update_accessed_at: If True, update the `accessed_at` timestamp to the current time. + update_modified_at: If True, update the `modified_at` timestamp to the current time. + """ + now = datetime.now(timezone.utc) + + if update_accessed_at: + self._metadata.accessed_at = now + if update_modified_at: + self._metadata.modified_at = now diff --git a/src/crawlee/storage_clients/models.py b/src/crawlee/storage_clients/models.py index 2887492885..c470028fd4 100644 --- a/src/crawlee/storage_clients/models.py +++ b/src/crawlee/storage_clients/models.py @@ -59,7 +59,6 @@ class RequestQueueMetadata(StorageMetadata): pending_request_count: Annotated[int, Field(alias='pendingRequestCount')] stats: Annotated[dict, Field(alias='stats')] total_request_count: Annotated[int, Field(alias='totalRequestCount')] - resource_directory: Annotated[str, Field(alias='resourceDirectory')] @docs_group('Data structures') @@ -122,23 +121,17 @@ class RequestQueueHeadState(BaseModel): @docs_group('Data structures') -class RequestQueueHead(BaseModel): - """Model for the request queue head.""" +class RequestQueueHeadWithLocks(BaseModel): + """Model for request queue head with locks.""" model_config = ConfigDict(populate_by_name=True) limit: Annotated[int | None, Field(alias='limit', default=None)] had_multiple_clients: Annotated[bool, Field(alias='hadMultipleClients')] queue_modified_at: Annotated[datetime, Field(alias='queueModifiedAt')] - items: Annotated[list[Request], Field(alias='items', default_factory=list)] - - -@docs_group('Data structures') -class RequestQueueHeadWithLocks(RequestQueueHead): - """Model for request queue head with locks.""" - lock_secs: Annotated[int, Field(alias='lockSecs')] queue_has_locked_requests: Annotated[bool | None, Field(alias='queueHasLockedRequests')] = None + items: Annotated[list[Request], Field(alias='items', default_factory=list)] class _ListPage(BaseModel): diff --git a/src/crawlee/storages/_dataset.py b/src/crawlee/storages/_dataset.py index 099b8ffcf4..ea386c082d 100644 --- a/src/crawlee/storages/_dataset.py +++ b/src/crawlee/storages/_dataset.py @@ -10,7 +10,6 @@ from crawlee import service_locator from crawlee._utils.docs import docs_group from crawlee._utils.file import export_csv_to_stream, export_json_to_stream -from crawlee.storage_clients.models import DatasetMetadata from ._base import Storage from ._key_value_store import KeyValueStore @@ -24,7 +23,7 @@ from crawlee.configuration import Configuration from crawlee.storage_clients import StorageClient from crawlee.storage_clients._base import DatasetClient - from crawlee.storage_clients.models import DatasetItemsListPage + from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata from ._types import ExportDataCsvKwargs, ExportDataJsonKwargs @@ -83,24 +82,17 @@ def __init__(self, client: DatasetClient) -> None: @override @property def id(self) -> str: - return self._client.id + return self._client.metadata.id @override @property def name(self) -> str | None: - return self._client.name + return self._client.metadata.name @override @property def metadata(self) -> DatasetMetadata: - return DatasetMetadata( - id=self._client.id, - name=self._client.id, - accessed_at=self._client.accessed_at, - created_at=self._client.created_at, - modified_at=self._client.modified_at, - item_count=self._client.item_count, - ) + return self._client.metadata @override @classmethod diff --git a/src/crawlee/storages/_key_value_store.py b/src/crawlee/storages/_key_value_store.py index 5bbe15e08d..54264b754e 100644 
--- a/src/crawlee/storages/_key_value_store.py +++ b/src/crawlee/storages/_key_value_store.py @@ -7,7 +7,6 @@ from crawlee import service_locator from crawlee._utils.docs import docs_group -from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecordMetadata from ._base import Storage @@ -17,6 +16,7 @@ from crawlee.configuration import Configuration from crawlee.storage_clients import StorageClient from crawlee.storage_clients._base import KeyValueStoreClient + from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecordMetadata T = TypeVar('T') @@ -75,23 +75,17 @@ def __init__(self, client: KeyValueStoreClient) -> None: @override @property def id(self) -> str: - return self._client.id + return self._client.metadata.id @override @property def name(self) -> str | None: - return self._client.name + return self._client.metadata.name @override @property def metadata(self) -> KeyValueStoreMetadata: - return KeyValueStoreMetadata( - id=self._client.id, - name=self._client.id, - accessed_at=self._client.accessed_at, - created_at=self._client.created_at, - modified_at=self._client.modified_at, - ) + return self._client.metadata @override @classmethod diff --git a/src/crawlee/storages/_request_queue.py b/src/crawlee/storages/_request_queue.py index fd5dd017b2..3e3a65b2d3 100644 --- a/src/crawlee/storages/_request_queue.py +++ b/src/crawlee/storages/_request_queue.py @@ -10,7 +10,7 @@ from crawlee import service_locator from crawlee._utils.docs import docs_group from crawlee.request_loaders import RequestManager -from crawlee.storage_clients.models import Request, RequestQueueMetadata +from crawlee.storage_clients.models import Request from ._base import Storage @@ -21,39 +21,12 @@ from crawlee.configuration import Configuration from crawlee.storage_clients import StorageClient from crawlee.storage_clients._base import RequestQueueClient - from crawlee.storage_clients.models import ProcessedRequest + from crawlee.storage_clients.models import ProcessedRequest, RequestQueueMetadata logger = getLogger(__name__) T = TypeVar('T') -# TODO: implement: -# - caching / memoization of both KVS & KVS clients - -# Properties: -# - id -# - name -# - metadata - -# Methods -# - open -# - drop -# - add_request -# - add_requests_batched -# - get_handled_count -# - get_total_count -# - get_request -# - fetch_next_request -# - mark_request_as_handled -# - reclaim_request -# - is_empty -# - is_finished - -# Breaking changes: -# - from_storage_object method has been removed - Use the open method with name and/or id instead. 
-# - get_info -> metadata property -# - storage_object -> metadata property - @docs_group('Classes') class RequestQueue(Storage, RequestManager): @@ -104,29 +77,17 @@ def __init__(self, client: RequestQueueClient) -> None: @override @property def id(self) -> str: - return self._client.id + return self._client.metadata.id @override @property def name(self) -> str | None: - return self._client.name + return self._client.metadata.name @override @property def metadata(self) -> RequestQueueMetadata: - return RequestQueueMetadata( - id=self._client.id, - name=self._client.id, - accessed_at=self._client.accessed_at, - created_at=self._client.created_at, - modified_at=self._client.modified_at, - had_multiple_clients=self._client.had_multiple_clients, - handled_request_count=self._client.handled_request_count, - pending_request_count=self._client.pending_request_count, - stats=self._client.stats, - total_request_count=self._client.total_request_count, - resource_directory=self._client.resource_directory, - ) + return self._client.metadata @override @classmethod @@ -158,7 +119,7 @@ async def open( return cls(client) @override - async def drop(self, *, timeout: timedelta | None = None) -> None: + async def drop(self) -> None: await self._client.drop() @override @@ -168,7 +129,7 @@ async def add_request( *, forefront: bool = False, ) -> ProcessedRequest: - return await self._client.add_request(request, forefront=forefront) + return await self._client.add_requests_batch([request], forefront=forefront) @override async def add_requests_batched( @@ -263,6 +224,17 @@ async def fetch_next_request(self) -> Request | None: """ # TODO: implement + async def get_request(self, request_id: str) -> Request | None: + """Retrieve a request by its ID. + + Args: + request_id: The ID of the request to retrieve. + + Returns: + The request if found, otherwise `None`. + """ + return await self._client.get_request(request_id) + async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: """Mark a request as handled after successful processing. 
diff --git a/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py b/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py index 0570b4db70..0368b517f4 100644 --- a/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py +++ b/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py @@ -9,7 +9,7 @@ import pytest from crawlee._consts import METADATA_FILENAME -from crawlee.storage_clients._file_system._dataset_client import FileSystemDatasetClient +from crawlee.storage_clients._file_system import FileSystemDatasetClient from crawlee.storage_clients.models import DatasetItemsListPage if TYPE_CHECKING: @@ -34,12 +34,12 @@ async def test_open_creates_new_dataset(tmp_path: Path) -> None: client = await FileSystemDatasetClient.open(name='new_dataset', storage_dir=tmp_path) # Verify client properties - assert client.id is not None - assert client.name == 'new_dataset' - assert client.item_count == 0 - assert isinstance(client.created_at, datetime) - assert isinstance(client.accessed_at, datetime) - assert isinstance(client.modified_at, datetime) + assert client.metadata.id is not None + assert client.metadata.name == 'new_dataset' + assert client.metadata.item_count == 0 + assert isinstance(client.metadata.created_at, datetime) + assert isinstance(client.metadata.accessed_at, datetime) + assert isinstance(client.metadata.modified_at, datetime) # Verify files were created assert client.path_to_dataset.exists() @@ -48,7 +48,7 @@ async def test_open_creates_new_dataset(tmp_path: Path) -> None: # Verify metadata content with client.path_to_metadata.open() as f: metadata = json.load(f) - assert metadata['id'] == client.id + assert metadata['id'] == client.metadata.id assert metadata['name'] == 'new_dataset' assert metadata['item_count'] == 0 @@ -56,12 +56,12 @@ async def test_open_creates_new_dataset(tmp_path: Path) -> None: async def test_open_existing_dataset(dataset_client: FileSystemDatasetClient, tmp_path: Path) -> None: """Test that open() loads an existing dataset correctly.""" # Open the same dataset again - reopened_client = await FileSystemDatasetClient.open(name=dataset_client.name, storage_dir=tmp_path) + reopened_client = await FileSystemDatasetClient.open(name=dataset_client.metadata.name, storage_dir=tmp_path) # Verify client properties - assert dataset_client.id == reopened_client.id - assert dataset_client.name == reopened_client.name - assert dataset_client.item_count == reopened_client.item_count + assert dataset_client.metadata.id == reopened_client.metadata.id + assert dataset_client.metadata.name == reopened_client.metadata.name + assert dataset_client.metadata.item_count == reopened_client.metadata.item_count # Verify clients (python) ids assert id(dataset_client) == id(reopened_client) @@ -79,7 +79,7 @@ async def test_push_data_single_item(dataset_client: FileSystemDatasetClient) -> await dataset_client.push_data(item) # Verify item count was updated - assert dataset_client.item_count == 1 + assert dataset_client.metadata.item_count == 1 all_files = list(dataset_client.path_to_dataset.glob('*.json')) assert len(all_files) == 2 # 1 data file + 1 metadata file @@ -100,7 +100,7 @@ async def test_push_data_multiple_items(dataset_client: FileSystemDatasetClient) await dataset_client.push_data(items) # Verify item count was updated - assert dataset_client.item_count == 3 + assert dataset_client.metadata.item_count == 3 all_files = list(dataset_client.path_to_dataset.glob('*.json')) assert len(all_files) == 4 # 3 data files + 1 metadata file 
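For orientation, the file-count assertions above follow from the on-disk layout used by the file-system clients: each storage gets its own directory under a per-type subdirectory of `storage_dir`, holding one metadata file plus one JSON file per dataset item. A rough sketch of the resulting paths (the 'datasets' and 'key_value_stores' subdirectory names and the '__metadata__.json' file name are assumptions inferred from the `_STORAGE_SUBDIR` and `METADATA_FILENAME` constants referenced in these diffs; only 'request_queues' is spelled out explicitly above):

from pathlib import Path

storage_dir = Path('./storage')

# FileSystemDatasetClient.path_to_dataset == storage_dir / _STORAGE_SUBDIR / name
dataset_dir = storage_dir / 'datasets' / 'new_dataset'      # subdirectory name assumed
dataset_metadata = dataset_dir / '__metadata__.json'        # METADATA_FILENAME assumed
# ...plus one JSON file per item pushed via push_data(), hence "1 data file + 1 metadata file".

# FileSystemKeyValueStoreClient.path_to_kvs == storage_dir / _STORAGE_SUBDIR / name
kvs_dir = storage_dir / 'key_value_stores' / 'new_kvs'      # subdirectory name assumed

# FileSystemRequestQueueClient.path_to_rq; 'request_queues' is taken from the diff above.
rq_dir = storage_dir / 'request_queues' / 'default'
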
@@ -239,22 +239,22 @@ async def test_drop(dataset_client: FileSystemDatasetClient) -> None: """Test dropping a dataset removes the entire dataset directory from disk.""" await dataset_client.push_data({'test': 'data'}) - assert dataset_client.name in FileSystemDatasetClient._cache_by_name + assert dataset_client.metadata.name in FileSystemDatasetClient._cache_by_name assert dataset_client.path_to_dataset.exists() # Drop the dataset await dataset_client.drop() - assert dataset_client.name not in FileSystemDatasetClient._cache_by_name + assert dataset_client.metadata.name not in FileSystemDatasetClient._cache_by_name assert not dataset_client.path_to_dataset.exists() async def test_metadata_updates(dataset_client: FileSystemDatasetClient) -> None: """Test that metadata timestamps are updated correctly after read and write operations.""" # Record initial timestamps - initial_created = dataset_client.created_at - initial_accessed = dataset_client.accessed_at - initial_modified = dataset_client.modified_at + initial_created = dataset_client.metadata.created_at + initial_accessed = dataset_client.metadata.accessed_at + initial_modified = dataset_client.metadata.modified_at # Wait a moment to ensure timestamps can change await asyncio.sleep(0.01) @@ -263,11 +263,11 @@ async def test_metadata_updates(dataset_client: FileSystemDatasetClient) -> None await dataset_client.get_data() # Verify timestamps - assert dataset_client.created_at == initial_created - assert dataset_client.accessed_at > initial_accessed - assert dataset_client.modified_at == initial_modified + assert dataset_client.metadata.created_at == initial_created + assert dataset_client.metadata.accessed_at > initial_accessed + assert dataset_client.metadata.modified_at == initial_modified - accessed_after_get = dataset_client.accessed_at + accessed_after_get = dataset_client.metadata.accessed_at # Wait a moment to ensure timestamps can change await asyncio.sleep(0.01) @@ -276,6 +276,6 @@ async def test_metadata_updates(dataset_client: FileSystemDatasetClient) -> None await dataset_client.push_data({'new': 'item'}) # Verify timestamps again - assert dataset_client.created_at == initial_created - assert dataset_client.modified_at > initial_modified - assert dataset_client.accessed_at > accessed_after_get + assert dataset_client.metadata.created_at == initial_created + assert dataset_client.metadata.modified_at > initial_modified + assert dataset_client.metadata.accessed_at > accessed_after_get diff --git a/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py b/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py index 38e65a16a2..dc0e8e721a 100644 --- a/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py +++ b/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py @@ -8,7 +8,7 @@ import pytest from crawlee._consts import METADATA_FILENAME -from crawlee.storage_clients._file_system._key_value_store_client import FileSystemKeyValueStoreClient +from crawlee.storage_clients._file_system import FileSystemKeyValueStoreClient if TYPE_CHECKING: from collections.abc import AsyncGenerator @@ -33,11 +33,11 @@ async def test_open_creates_new_kvs(tmp_path: Path) -> None: client = await FileSystemKeyValueStoreClient.open(name='new_kvs', storage_dir=tmp_path) # Verify client properties - assert client.id is not None - assert client.name == 'new_kvs' - assert isinstance(client.created_at, datetime) - assert isinstance(client.accessed_at, datetime) - assert isinstance(client.modified_at, datetime) + assert 
client.metadata.id is not None + assert client.metadata.name == 'new_kvs' + assert isinstance(client.metadata.created_at, datetime) + assert isinstance(client.metadata.accessed_at, datetime) + assert isinstance(client.metadata.modified_at, datetime) # Verify files were created assert client.path_to_kvs.exists() @@ -46,18 +46,18 @@ async def test_open_creates_new_kvs(tmp_path: Path) -> None: # Verify metadata content with client.path_to_metadata.open() as f: metadata = json.load(f) - assert metadata['id'] == client.id + assert metadata['id'] == client.metadata.id assert metadata['name'] == 'new_kvs' async def test_open_existing_kvs(kvs_client: FileSystemKeyValueStoreClient, tmp_path: Path) -> None: """Test that open() loads an existing key-value store with matching properties.""" # Open the same key-value store again - reopened_client = await FileSystemKeyValueStoreClient.open(name=kvs_client.name, storage_dir=tmp_path) + reopened_client = await FileSystemKeyValueStoreClient.open(name=kvs_client.metadata.name, storage_dir=tmp_path) # Verify client properties - assert kvs_client.id == reopened_client.id - assert kvs_client.name == reopened_client.name + assert kvs_client.metadata.id == reopened_client.metadata.id + assert kvs_client.metadata.name == reopened_client.metadata.name # Verify clients (python) ids - should be the same object due to caching assert id(kvs_client) == id(reopened_client) @@ -254,22 +254,22 @@ async def test_drop(kvs_client: FileSystemKeyValueStoreClient) -> None: """Test that drop removes the entire store directory from disk.""" await kvs_client.set_value(key='test', value='test-value') - assert kvs_client.name in FileSystemKeyValueStoreClient._cache_by_name + assert kvs_client.metadata.name in FileSystemKeyValueStoreClient._cache_by_name assert kvs_client.path_to_kvs.exists() # Drop the store await kvs_client.drop() - assert kvs_client.name not in FileSystemKeyValueStoreClient._cache_by_name + assert kvs_client.metadata.name not in FileSystemKeyValueStoreClient._cache_by_name assert not kvs_client.path_to_kvs.exists() async def test_metadata_updates(kvs_client: FileSystemKeyValueStoreClient) -> None: """Test that read/write operations properly update accessed_at and modified_at timestamps.""" # Record initial timestamps - initial_created = kvs_client.created_at - initial_accessed = kvs_client.accessed_at - initial_modified = kvs_client.modified_at + initial_created = kvs_client.metadata.created_at + initial_accessed = kvs_client.metadata.accessed_at + initial_modified = kvs_client.metadata.modified_at # Wait a moment to ensure timestamps can change await asyncio.sleep(0.01) @@ -278,11 +278,11 @@ async def test_metadata_updates(kvs_client: FileSystemKeyValueStoreClient) -> No await kvs_client.get_value(key='nonexistent') # Verify timestamps - assert kvs_client.created_at == initial_created - assert kvs_client.accessed_at > initial_accessed - assert kvs_client.modified_at == initial_modified + assert kvs_client.metadata.created_at == initial_created + assert kvs_client.metadata.accessed_at > initial_accessed + assert kvs_client.metadata.modified_at == initial_modified - accessed_after_get = kvs_client.accessed_at + accessed_after_get = kvs_client.metadata.accessed_at # Wait a moment to ensure timestamps can change await asyncio.sleep(0.01) @@ -291,9 +291,9 @@ async def test_metadata_updates(kvs_client: FileSystemKeyValueStoreClient) -> No await kvs_client.set_value(key='new-key', value='new-value') # Verify timestamps again - assert kvs_client.created_at == 
initial_created - assert kvs_client.modified_at > initial_modified - assert kvs_client.accessed_at > accessed_after_get + assert kvs_client.metadata.created_at == initial_created + assert kvs_client.metadata.modified_at > initial_modified + assert kvs_client.metadata.accessed_at > accessed_after_get async def test_get_public_url_not_supported(kvs_client: FileSystemKeyValueStoreClient) -> None: diff --git a/tests/unit/storage_clients/_file_system/test_fs_storage_client.py b/tests/unit/storage_clients/_file_system/test_fs_storage_client.py index 843911de97..d5eefeffc1 100644 --- a/tests/unit/storage_clients/_file_system/test_fs_storage_client.py +++ b/tests/unit/storage_clients/_file_system/test_fs_storage_client.py @@ -4,9 +4,11 @@ import pytest -from crawlee.storage_clients._file_system._dataset_client import FileSystemDatasetClient -from crawlee.storage_clients._file_system._key_value_store_client import FileSystemKeyValueStoreClient -from crawlee.storage_clients._file_system._storage_client import FileSystemStorageClient +from crawlee.storage_clients._file_system import ( + FileSystemDatasetClient, + FileSystemKeyValueStoreClient, + FileSystemStorageClient, +) if TYPE_CHECKING: from pathlib import Path @@ -25,7 +27,7 @@ async def test_open_dataset_client(client: FileSystemStorageClient, tmp_path: Pa # Verify correct client type and properties assert isinstance(dataset_client, FileSystemDatasetClient) - assert dataset_client.name == 'test-dataset' + assert dataset_client.metadata.name == 'test-dataset' # Verify directory structure was created assert dataset_client.path_to_dataset.exists() @@ -86,7 +88,7 @@ async def test_open_kvs_client(client: FileSystemStorageClient, tmp_path: Path) # Verify correct client type and properties assert isinstance(kvs_client, FileSystemKeyValueStoreClient) - assert kvs_client.name == 'test-kvs' + assert kvs_client.metadata.name == 'test-kvs' # Verify directory structure was created assert kvs_client.path_to_kvs.exists() diff --git a/tests/unit/storage_clients/_memory/test_memory_dataset_client.py b/tests/unit/storage_clients/_memory/test_memory_dataset_client.py index 6f18c5b6b7..7f349daf23 100644 --- a/tests/unit/storage_clients/_memory/test_memory_dataset_client.py +++ b/tests/unit/storage_clients/_memory/test_memory_dataset_client.py @@ -6,7 +6,7 @@ import pytest -from crawlee.storage_clients._memory._dataset_client import MemoryDatasetClient +from crawlee.storage_clients._memory import MemoryDatasetClient from crawlee.storage_clients.models import DatasetItemsListPage if TYPE_CHECKING: @@ -31,12 +31,12 @@ async def test_open_creates_new_dataset() -> None: client = await MemoryDatasetClient.open(name='new_dataset') # Verify client properties - assert client.id is not None - assert client.name == 'new_dataset' - assert client.item_count == 0 - assert isinstance(client.created_at, datetime) - assert isinstance(client.accessed_at, datetime) - assert isinstance(client.modified_at, datetime) + assert client.metadata.id is not None + assert client.metadata.name == 'new_dataset' + assert client.metadata.item_count == 0 + assert isinstance(client.metadata.created_at, datetime) + assert isinstance(client.metadata.accessed_at, datetime) + assert isinstance(client.metadata.modified_at, datetime) # Verify the client was cached assert 'new_dataset' in MemoryDatasetClient._cache_by_name @@ -45,12 +45,12 @@ async def test_open_creates_new_dataset() -> None: async def test_open_existing_dataset(dataset_client: MemoryDatasetClient) -> None: """Test that open() loads an 
existing dataset with matching properties.""" # Open the same dataset again - reopened_client = await MemoryDatasetClient.open(name=dataset_client.name) + reopened_client = await MemoryDatasetClient.open(name=dataset_client.metadata.name) # Verify client properties - assert dataset_client.id == reopened_client.id - assert dataset_client.name == reopened_client.name - assert dataset_client.item_count == reopened_client.item_count + assert dataset_client.metadata.id == reopened_client.metadata.id + assert dataset_client.metadata.name == reopened_client.metadata.name + assert dataset_client.metadata.item_count == reopened_client.metadata.item_count # Verify clients (python) ids assert id(dataset_client) == id(reopened_client) @@ -59,8 +59,8 @@ async def test_open_existing_dataset(dataset_client: MemoryDatasetClient) -> Non async def test_open_with_id_and_name() -> None: """Test that open() can be used with both id and name parameters.""" client = await MemoryDatasetClient.open(id='some-id', name='some-name') - assert client.id == 'some-id' - assert client.name == 'some-name' + assert client.metadata.id == 'some-id' + assert client.metadata.name == 'some-name' async def test_push_data_single_item(dataset_client: MemoryDatasetClient) -> None: @@ -69,7 +69,7 @@ async def test_push_data_single_item(dataset_client: MemoryDatasetClient) -> Non await dataset_client.push_data(item) # Verify item count was updated - assert dataset_client.item_count == 1 + assert dataset_client.metadata.item_count == 1 # Verify item was stored result = await dataset_client.get_data() @@ -87,7 +87,7 @@ async def test_push_data_multiple_items(dataset_client: MemoryDatasetClient) -> await dataset_client.push_data(items) # Verify item count was updated - assert dataset_client.item_count == 3 + assert dataset_client.metadata.item_count == 3 # Verify items were stored result = await dataset_client.get_data() @@ -229,16 +229,16 @@ async def test_drop(dataset_client: MemoryDatasetClient) -> None: await dataset_client.push_data({'test': 'data'}) # Verify the dataset exists in the cache - assert dataset_client.name in MemoryDatasetClient._cache_by_name + assert dataset_client.metadata.name in MemoryDatasetClient._cache_by_name # Drop the dataset await dataset_client.drop() # Verify the dataset was removed from the cache - assert dataset_client.name not in MemoryDatasetClient._cache_by_name + assert dataset_client.metadata.name not in MemoryDatasetClient._cache_by_name # Verify the dataset is empty - assert dataset_client.item_count == 0 + assert dataset_client.metadata.item_count == 0 result = await dataset_client.get_data() assert result.count == 0 @@ -246,9 +246,9 @@ async def test_drop(dataset_client: MemoryDatasetClient) -> None: async def test_metadata_updates(dataset_client: MemoryDatasetClient) -> None: """Test that read/write operations properly update accessed_at and modified_at timestamps.""" # Record initial timestamps - initial_created = dataset_client.created_at - initial_accessed = dataset_client.accessed_at - initial_modified = dataset_client.modified_at + initial_created = dataset_client.metadata.created_at + initial_accessed = dataset_client.metadata.accessed_at + initial_modified = dataset_client.metadata.modified_at # Wait a moment to ensure timestamps can change await asyncio.sleep(0.01) @@ -257,11 +257,11 @@ async def test_metadata_updates(dataset_client: MemoryDatasetClient) -> None: await dataset_client.get_data() # Verify timestamps - assert dataset_client.created_at == initial_created - assert 
dataset_client.accessed_at > initial_accessed - assert dataset_client.modified_at == initial_modified + assert dataset_client.metadata.created_at == initial_created + assert dataset_client.metadata.accessed_at > initial_accessed + assert dataset_client.metadata.modified_at == initial_modified - accessed_after_get = dataset_client.accessed_at + accessed_after_get = dataset_client.metadata.accessed_at # Wait a moment to ensure timestamps can change await asyncio.sleep(0.01) @@ -270,6 +270,6 @@ async def test_metadata_updates(dataset_client: MemoryDatasetClient) -> None: await dataset_client.push_data({'new': 'item'}) # Verify timestamps again - assert dataset_client.created_at == initial_created - assert dataset_client.modified_at > initial_modified - assert dataset_client.accessed_at > accessed_after_get + assert dataset_client.metadata.created_at == initial_created + assert dataset_client.metadata.modified_at > initial_modified + assert dataset_client.metadata.accessed_at > accessed_after_get diff --git a/tests/unit/storage_clients/_memory/test_memory_kvs_client.py b/tests/unit/storage_clients/_memory/test_memory_kvs_client.py index 8940764839..3b3b4806a7 100644 --- a/tests/unit/storage_clients/_memory/test_memory_kvs_client.py +++ b/tests/unit/storage_clients/_memory/test_memory_kvs_client.py @@ -6,7 +6,7 @@ import pytest -from crawlee.storage_clients._memory._key_value_store_client import MemoryKeyValueStoreClient +from crawlee.storage_clients._memory import MemoryKeyValueStoreClient from crawlee.storage_clients.models import KeyValueStoreRecordMetadata if TYPE_CHECKING: @@ -30,11 +30,11 @@ async def test_open_creates_new_store() -> None: client = await MemoryKeyValueStoreClient.open(name='new_kvs') # Verify client properties - assert client.id is not None - assert client.name == 'new_kvs' - assert isinstance(client.created_at, datetime) - assert isinstance(client.accessed_at, datetime) - assert isinstance(client.modified_at, datetime) + assert client.metadata.id is not None + assert client.metadata.name == 'new_kvs' + assert isinstance(client.metadata.created_at, datetime) + assert isinstance(client.metadata.accessed_at, datetime) + assert isinstance(client.metadata.modified_at, datetime) # Verify the client was cached assert 'new_kvs' in MemoryKeyValueStoreClient._cache_by_name @@ -43,11 +43,11 @@ async def test_open_creates_new_store() -> None: async def test_open_existing_store(kvs_client: MemoryKeyValueStoreClient) -> None: """Test that open() loads an existing key-value store with matching properties.""" # Open the same key-value store again - reopened_client = await MemoryKeyValueStoreClient.open(name=kvs_client.name) + reopened_client = await MemoryKeyValueStoreClient.open(name=kvs_client.metadata.name) # Verify client properties - assert kvs_client.id == reopened_client.id - assert kvs_client.name == reopened_client.name + assert kvs_client.metadata.id == reopened_client.metadata.id + assert kvs_client.metadata.name == reopened_client.metadata.name # Verify clients (python) ids assert id(kvs_client) == id(reopened_client) @@ -56,8 +56,8 @@ async def test_open_existing_store(kvs_client: MemoryKeyValueStoreClient) -> Non async def test_open_with_id_and_name() -> None: """Test that open() can be used with both id and name parameters.""" client = await MemoryKeyValueStoreClient.open(id='some-id', name='some-name') - assert client.id == 'some-id' - assert client.name == 'some-name' + assert client.metadata.id == 'some-id' + assert client.metadata.name == 'some-name' 
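Taken together, these tests sketch the intended client API: storages are opened with the `open` classmethod, cached by name, expose their state through the `metadata` property, and are cleaned up with `drop`. A minimal usage sketch built only from the calls exercised above (the memory clients need no `storage_dir`; the private import path mirrors the one used by the tests and may change in later commits):

import asyncio

from crawlee.storage_clients._memory import MemoryDatasetClient, MemoryKeyValueStoreClient


async def main() -> None:
    # Opening the same name twice returns the cached instance.
    dataset = await MemoryDatasetClient.open(name='example')
    await dataset.push_data([{'url': 'https://crawlee.dev'}, {'url': 'https://apify.com'}])
    print(dataset.metadata.item_count)  # -> 2

    kvs = await MemoryKeyValueStoreClient.open(name='example')
    await kvs.set_value(key='state', value={'counter': 0})
    record = await kvs.get_value(key='state')
    print(record.value if record else None)  # -> {'counter': 0}

    # Dropping clears the stored data and evicts the client from the cache.
    await dataset.drop()
    await kvs.drop()


asyncio.run(main())
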
@pytest.mark.parametrize( @@ -185,13 +185,13 @@ async def test_drop(kvs_client: MemoryKeyValueStoreClient) -> None: await kvs_client.set_value(key='test', value='data') # Verify the store exists in the cache - assert kvs_client.name in MemoryKeyValueStoreClient._cache_by_name + assert kvs_client.metadata.name in MemoryKeyValueStoreClient._cache_by_name # Drop the store await kvs_client.drop() # Verify the store was removed from the cache - assert kvs_client.name not in MemoryKeyValueStoreClient._cache_by_name + assert kvs_client.metadata.name not in MemoryKeyValueStoreClient._cache_by_name # Verify the store is empty record = await kvs_client.get_value(key='test') @@ -207,9 +207,9 @@ async def test_get_public_url(kvs_client: MemoryKeyValueStoreClient) -> None: async def test_metadata_updates(kvs_client: MemoryKeyValueStoreClient) -> None: """Test that read/write operations properly update accessed_at and modified_at timestamps.""" # Record initial timestamps - initial_created = kvs_client.created_at - initial_accessed = kvs_client.accessed_at - initial_modified = kvs_client.modified_at + initial_created = kvs_client.metadata.created_at + initial_accessed = kvs_client.metadata.accessed_at + initial_modified = kvs_client.metadata.modified_at # Wait a moment to ensure timestamps can change await asyncio.sleep(0.01) @@ -218,11 +218,11 @@ async def test_metadata_updates(kvs_client: MemoryKeyValueStoreClient) -> None: await kvs_client.get_value(key='nonexistent') # Verify timestamps - assert kvs_client.created_at == initial_created - assert kvs_client.accessed_at > initial_accessed - assert kvs_client.modified_at == initial_modified + assert kvs_client.metadata.created_at == initial_created + assert kvs_client.metadata.accessed_at > initial_accessed + assert kvs_client.metadata.modified_at == initial_modified - accessed_after_get = kvs_client.accessed_at + accessed_after_get = kvs_client.metadata.accessed_at # Wait a moment to ensure timestamps can change await asyncio.sleep(0.01) @@ -231,6 +231,6 @@ async def test_metadata_updates(kvs_client: MemoryKeyValueStoreClient) -> None: await kvs_client.set_value(key='new_key', value='new value') # Verify timestamps again - assert kvs_client.created_at == initial_created - assert kvs_client.modified_at > initial_modified - assert kvs_client.accessed_at > accessed_after_get + assert kvs_client.metadata.created_at == initial_created + assert kvs_client.metadata.modified_at > initial_modified + assert kvs_client.metadata.accessed_at > accessed_after_get From b833f915e4cf7d9ae36ca359fd3931d62595ff2f Mon Sep 17 00:00:00 2001 From: Vlada Dusek Date: Tue, 15 Apr 2025 14:05:54 +0200 Subject: [PATCH 13/22] Utilize pathlib and use Config in constructors --- src/crawlee/crawlers/_basic/_basic_crawler.py | 24 ++- .../hooks/post_gen_project.py | 6 +- .../storage_clients/_base/_dataset_client.py | 11 +- .../_base/_key_value_store_client.py | 11 +- .../_base/_request_queue_client.py | 11 +- .../storage_clients/_base/_storage_client.py | 11 +- .../_file_system/_dataset_client.py | 23 ++- .../_file_system/_key_value_store_client.py | 21 +-- .../_file_system/_request_queue_client.py | 20 +-- .../_file_system/_storage_client.py | 36 ++--- .../_memory/_dataset_client.py | 17 +-- .../_memory/_key_value_store_client.py | 17 +-- .../_memory/_request_queue_client.py | 16 +- .../_memory/_storage_client.py | 38 ++--- src/crawlee/storages/_dataset.py | 4 +- src/crawlee/storages/_key_value_store.py | 4 +- src/crawlee/storages/_request_queue.py | 4 +- 
tests/e2e/project_template/utils.py | 20 ++- tests/unit/_utils/test_file.py | 6 +- tests/unit/conftest.py | 2 - .../_file_system/test_fs_dataset_client.py | 94 ++++++++++-- .../_file_system/test_fs_kvs_client.py | 95 ++++++++++-- .../_file_system/test_fs_storage_client.py | 144 ------------------ .../_memory/test_memory_dataset_client.py | 73 ++++++++- .../_memory/test_memory_kvs_client.py | 78 ++++++++-- 25 files changed, 437 insertions(+), 349 deletions(-) delete mode 100644 tests/unit/storage_clients/_file_system/test_fs_storage_client.py diff --git a/src/crawlee/crawlers/_basic/_basic_crawler.py b/src/crawlee/crawlers/_basic/_basic_crawler.py index f87234abad..f800f43dc3 100644 --- a/src/crawlee/crawlers/_basic/_basic_crawler.py +++ b/src/crawlee/crawlers/_basic/_basic_crawler.py @@ -660,14 +660,12 @@ async def _use_state( self, default_value: dict[str, JsonSerializable] | None = None, ) -> dict[str, JsonSerializable]: - kvs = await self.get_key_value_store() - # TODO: - # return some kvs value + # TODO: implement + return {} async def _save_crawler_state(self) -> None: - kvs = await self.get_key_value_store() - # TODO: - # some kvs call + pass + # TODO: implement async def get_data( self, @@ -697,16 +695,16 @@ async def export_data( dataset_id: str | None = None, dataset_name: str | None = None, ) -> None: - """Export data from a `Dataset`. + """Export all items from a Dataset to a JSON or CSV file. - This helper method simplifies the process of exporting data from a `Dataset`. It opens the specified - one and then exports the data based on the provided parameters. If you need to pass options - specific to the output format, use the `export_data_csv` or `export_data_json` method instead. + This method simplifies the process of exporting data collected during crawling. It automatically + determines the export format based on the file extension (`.json` or `.csv`) and handles + the conversion of `Dataset` items to the appropriate format. Args: - path: The destination path. - dataset_id: The ID of the `Dataset`. - dataset_name: The name of the `Dataset`. + path: The destination file path. Must end with '.json' or '.csv'. + dataset_id: The ID of the Dataset to export from. If None, uses `name` parameter instead. + dataset_name: The name of the Dataset to export from. If None, uses `id` parameter instead. 
""" dataset = await self.get_dataset(id=dataset_id, name=dataset_name) diff --git a/src/crawlee/project_template/hooks/post_gen_project.py b/src/crawlee/project_template/hooks/post_gen_project.py index e076ff9308..c0495a724d 100644 --- a/src/crawlee/project_template/hooks/post_gen_project.py +++ b/src/crawlee/project_template/hooks/post_gen_project.py @@ -2,7 +2,6 @@ import subprocess from pathlib import Path - # % if cookiecutter.package_manager in ['poetry', 'uv'] Path('requirements.txt').unlink() @@ -32,8 +31,9 @@ # Install requirements and generate requirements.txt as an impromptu lockfile subprocess.check_call([str(path / 'pip'), 'install', '-r', 'requirements.txt']) -with open('requirements.txt', 'w') as requirements_txt: - subprocess.check_call([str(path / 'pip'), 'freeze'], stdout=requirements_txt) +Path('requirements.txt').write_text( + subprocess.check_output([str(path / 'pip'), 'freeze']).decode() +) # % if cookiecutter.crawler_type == 'playwright' subprocess.check_call([str(path / 'playwright'), 'install']) diff --git a/src/crawlee/storage_clients/_base/_dataset_client.py b/src/crawlee/storage_clients/_base/_dataset_client.py index dc20ce89f3..c73eb6f51f 100644 --- a/src/crawlee/storage_clients/_base/_dataset_client.py +++ b/src/crawlee/storage_clients/_base/_dataset_client.py @@ -7,9 +7,9 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator - from pathlib import Path from typing import Any + from crawlee.configuration import Configuration from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata @@ -37,9 +37,9 @@ def metadata(self) -> DatasetMetadata: async def open( cls, *, - id: str | None = None, - name: str | None = None, - storage_dir: Path | None = None, + id: str | None, + name: str | None, + configuration: Configuration, ) -> DatasetClient: """Open existing or create a new dataset client. @@ -51,8 +51,7 @@ async def open( Args: id: The ID of the dataset. If not provided, an ID may be generated. name: The name of the dataset. If not provided a default name may be used. - storage_dir: The path to the storage directory. If the client persists data, - it should use this directory. May be ignored by non-persistent implementations. + configuration: The configuration object. Returns: A dataset client instance. diff --git a/src/crawlee/storage_clients/_base/_key_value_store_client.py b/src/crawlee/storage_clients/_base/_key_value_store_client.py index 7eb9160d16..957b53db0e 100644 --- a/src/crawlee/storage_clients/_base/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_base/_key_value_store_client.py @@ -7,8 +7,8 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator - from pathlib import Path + from crawlee.configuration import Configuration from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecord, KeyValueStoreRecordMetadata @@ -36,9 +36,9 @@ def metadata(self) -> KeyValueStoreMetadata: async def open( cls, *, - id: str | None = None, - name: str | None = None, - storage_dir: Path | None = None, + id: str | None, + name: str | None, + configuration: Configuration, ) -> KeyValueStoreClient: """Open existing or create a new key-value store client. @@ -51,8 +51,7 @@ async def open( Args: id: The ID of the key-value store. If not provided, an ID may be generated. name: The name of the key-value store. If not provided a default name may be used. - storage_dir: The path to the storage directory. If the client persists data, - it should use this directory. 
May be ignored by non-persistent implementations. + configuration: The configuration object. Returns: A key-value store client instance. diff --git a/src/crawlee/storage_clients/_base/_request_queue_client.py b/src/crawlee/storage_clients/_base/_request_queue_client.py index 8b0b11ef0a..184d2ca97c 100644 --- a/src/crawlee/storage_clients/_base/_request_queue_client.py +++ b/src/crawlee/storage_clients/_base/_request_queue_client.py @@ -7,8 +7,8 @@ if TYPE_CHECKING: from collections.abc import Sequence - from pathlib import Path + from crawlee.configuration import Configuration from crawlee.storage_clients.models import ( BatchRequestsOperationResponse, ProcessedRequest, @@ -37,17 +37,16 @@ def metadata(self) -> RequestQueueMetadata: async def open( cls, *, - id: str | None = None, - name: str | None = None, - storage_dir: Path | None = None, + id: str | None, + name: str | None, + configuration: Configuration, ) -> RequestQueueClient: """Open a request queue client. Args: id: ID of the queue to open. If not provided, a new queue will be created with a random ID. name: Name of the queue to open. If not provided, the queue will be unnamed. - purge_on_start: If True, the queue will be purged before opening. - storage_dir: Directory to store the queue data in. If not provided, uses the default storage directory. + configuration: The configuration object. Returns: A request queue client. diff --git a/src/crawlee/storage_clients/_base/_storage_client.py b/src/crawlee/storage_clients/_base/_storage_client.py index f85d5a5bb5..fefa7ea5cb 100644 --- a/src/crawlee/storage_clients/_base/_storage_client.py +++ b/src/crawlee/storage_clients/_base/_storage_client.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from pathlib import Path + from crawlee.configuration import Configuration from ._dataset_client import DatasetClient from ._key_value_store_client import KeyValueStoreClient @@ -20,8 +20,7 @@ async def open_dataset_client( *, id: str | None = None, name: str | None = None, - purge_on_start: bool = True, - storage_dir: Path | None = None, + configuration: Configuration | None = None, ) -> DatasetClient: """Open a dataset client.""" @@ -31,8 +30,7 @@ async def open_key_value_store_client( *, id: str | None = None, name: str | None = None, - purge_on_start: bool = True, - storage_dir: Path | None = None, + configuration: Configuration | None = None, ) -> KeyValueStoreClient: """Open a key-value store client.""" @@ -42,7 +40,6 @@ async def open_request_queue_client( *, id: str | None = None, name: str | None = None, - purge_on_start: bool = True, - storage_dir: Path | None = None, + configuration: Configuration | None = None, ) -> RequestQueueClient: """Open a request queue client.""" diff --git a/src/crawlee/storage_clients/_file_system/_dataset_client.py b/src/crawlee/storage_clients/_file_system/_dataset_client.py index 1bd27beef6..958566d925 100644 --- a/src/crawlee/storage_clients/_file_system/_dataset_client.py +++ b/src/crawlee/storage_clients/_file_system/_dataset_client.py @@ -21,6 +21,8 @@ from collections.abc import AsyncIterator from typing import Any + from crawlee.configuration import Configuration + logger = getLogger(__name__) @@ -32,14 +34,11 @@ class FileSystemDatasetClient(DatasetClient): filename, allowing for easy ordering and pagination. 
""" - _DEFAULT_NAME = 'default' - """The default name for the dataset when no name is provided.""" - _STORAGE_SUBDIR = 'datasets' """The name of the subdirectory where datasets are stored.""" - _LOCAL_ENTRY_NAME_DIGITS = 9 - """Number of digits used for the file names (e.g., 000000019.json).""" + _ITEM_FILENAME_DIGITS = 9 + """Number of digits used for the dataset item file names (e.g., 000000019.json).""" _cache_by_name: ClassVar[dict[str, FileSystemDatasetClient]] = {} """A dictionary to cache clients by their names.""" @@ -72,7 +71,7 @@ def __init__( # Internal attributes self._lock = asyncio.Lock() - """A lock to ensure that only one file operation is performed at a time.""" + """A lock to ensure that only one operation is performed at a time.""" @override @property @@ -94,16 +93,16 @@ def path_to_metadata(self) -> Path: async def open( cls, *, - id: str | None = None, - name: str | None = None, - storage_dir: Path | None = None, + id: str | None, + name: str | None, + configuration: Configuration, ) -> FileSystemDatasetClient: if id: raise ValueError( 'Opening a dataset by "id" is not supported for file system storage client, use "name" instead.' ) - name = name or cls._DEFAULT_NAME + name = name or configuration.default_dataset_id # Check if the client is already cached by name. if name in cls._cache_by_name: @@ -111,7 +110,7 @@ async def open( await client._update_metadata(update_accessed_at=True) # noqa: SLF001 return client - storage_dir = storage_dir or Path.cwd() + storage_dir = Path(configuration.storage_dir) dataset_path = storage_dir / cls._STORAGE_SUBDIR / name metadata_path = dataset_path / METADATA_FILENAME @@ -386,7 +385,7 @@ async def _push_item(self, item: dict[str, Any], item_id: int) -> None: # Acquire the lock to perform file operations safely. async with self._lock: # Generate the filename for the new item using zero-padded numbering. - filename = f'{str(item_id).zfill(self._LOCAL_ENTRY_NAME_DIGITS)}.json' + filename = f'{str(item_id).zfill(self._ITEM_FILENAME_DIGITS)}.json' file_path = self.path_to_dataset / filename # Ensure the dataset directory exists. diff --git a/src/crawlee/storage_clients/_file_system/_key_value_store_client.py b/src/crawlee/storage_clients/_file_system/_key_value_store_client.py index 79c3d7102d..7799b71583 100644 --- a/src/crawlee/storage_clients/_file_system/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_file_system/_key_value_store_client.py @@ -21,6 +21,8 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator + from crawlee.configuration import Configuration + logger = getLogger(__name__) @@ -33,9 +35,6 @@ class FileSystemKeyValueStoreClient(KeyValueStoreClient): in an accompanying file. 
""" - _DEFAULT_NAME = 'default' - """The default name for the unnamed key-value store.""" - _STORAGE_SUBDIR = 'key_value_stores' """The name of the subdirectory where key-value stores are stored.""" @@ -68,7 +67,7 @@ def __init__( # Internal attributes self._lock = asyncio.Lock() - """A lock to ensure that only one file operation is performed at a time.""" + """A lock to ensure that only one operation is performed at a time.""" @override @property @@ -90,22 +89,24 @@ def path_to_metadata(self) -> Path: async def open( cls, *, - id: str | None = None, - name: str | None = None, - storage_dir: Path | None = None, + id: str | None, + name: str | None, + configuration: Configuration, ) -> FileSystemKeyValueStoreClient: if id: raise ValueError( 'Opening a key-value store by "id" is not supported for file system storage client, use "name" instead.' ) - name = name or cls._DEFAULT_NAME + name = name or configuration.default_dataset_id # Check if the client is already cached by name. if name in cls._cache_by_name: - return cls._cache_by_name[name] + client = cls._cache_by_name[name] + await client._update_metadata(update_accessed_at=True) # noqa: SLF001 + return client - storage_dir = storage_dir or Path.cwd() + storage_dir = Path(configuration.storage_dir) kvs_path = storage_dir / cls._STORAGE_SUBDIR / name metadata_path = kvs_path / METADATA_FILENAME diff --git a/src/crawlee/storage_clients/_file_system/_request_queue_client.py b/src/crawlee/storage_clients/_file_system/_request_queue_client.py index fd9866e17f..e5c32d860a 100644 --- a/src/crawlee/storage_clients/_file_system/_request_queue_client.py +++ b/src/crawlee/storage_clients/_file_system/_request_queue_client.py @@ -6,7 +6,7 @@ from datetime import datetime, timezone from logging import getLogger from pathlib import Path -from typing import ClassVar +from typing import TYPE_CHECKING, ClassVar from pydantic import ValidationError from typing_extensions import override @@ -17,6 +17,9 @@ from ._utils import METADATA_FILENAME, json_dumps +if TYPE_CHECKING: + from crawlee.configuration import Configuration + logger = getLogger(__name__) @@ -28,9 +31,6 @@ class FileSystemRequestQueueClient(RequestQueueClient): handling and tracking across crawler runs. """ - _DEFAULT_NAME = 'default' - """The default name for the unnamed request queue.""" - _STORAGE_SUBDIR = 'request_queues' """The name of the subdirectory where request queues are stored.""" @@ -73,7 +73,7 @@ def __init__( # Internal attributes self._lock = asyncio.Lock() - """A lock to ensure that only one file operation is performed at a time.""" + """A lock to ensure that only one operation is performed at a time.""" @override @property @@ -95,16 +95,16 @@ def path_to_metadata(self) -> Path: async def open( cls, *, - id: str | None = None, - name: str | None = None, - storage_dir: Path | None = None, + id: str | None, + name: str | None, + configuration: Configuration, ) -> FileSystemRequestQueueClient: if id: raise ValueError( 'Opening a dataset by "id" is not supported for file system storage client, use "name" instead.' ) - name = name or cls._DEFAULT_NAME + name = name or configuration.default_dataset_id # Check if the client is already cached by name. 
if name in cls._cache_by_name: @@ -112,7 +112,7 @@ async def open( await client._update_metadata(update_accessed_at=True) # noqa: SLF001 return client - storage_dir = storage_dir or Path.cwd() + storage_dir = Path(configuration.storage_dir) rq_path = storage_dir / cls._STORAGE_SUBDIR / name metadata_path = rq_path / METADATA_FILENAME diff --git a/src/crawlee/storage_clients/_file_system/_storage_client.py b/src/crawlee/storage_clients/_file_system/_storage_client.py index 8679e67c2f..2765d15536 100644 --- a/src/crawlee/storage_clients/_file_system/_storage_client.py +++ b/src/crawlee/storage_clients/_file_system/_storage_client.py @@ -1,18 +1,14 @@ from __future__ import annotations -from typing import TYPE_CHECKING - from typing_extensions import override +from crawlee.configuration import Configuration from crawlee.storage_clients._base import StorageClient from ._dataset_client import FileSystemDatasetClient from ._key_value_store_client import FileSystemKeyValueStoreClient from ._request_queue_client import FileSystemRequestQueueClient -if TYPE_CHECKING: - from pathlib import Path - class FileSystemStorageClient(StorageClient): """File system storage client.""" @@ -23,14 +19,14 @@ async def open_dataset_client( *, id: str | None = None, name: str | None = None, - purge_on_start: bool = True, - storage_dir: Path | None = None, + configuration: Configuration | None = None, ) -> FileSystemDatasetClient: - client = await FileSystemDatasetClient.open(id=id, name=name, storage_dir=storage_dir) + configuration = configuration or Configuration.get_global_configuration() + client = await FileSystemDatasetClient.open(id=id, name=name, configuration=configuration) - if purge_on_start: + if configuration.purge_on_start: await client.drop() - client = await FileSystemDatasetClient.open(id=id, name=name, storage_dir=storage_dir) + client = await FileSystemDatasetClient.open(id=id, name=name, configuration=configuration) return client @@ -40,14 +36,14 @@ async def open_key_value_store_client( *, id: str | None = None, name: str | None = None, - purge_on_start: bool = True, - storage_dir: Path | None = None, + configuration: Configuration | None = None, ) -> FileSystemKeyValueStoreClient: - client = await FileSystemKeyValueStoreClient.open(id=id, name=name, storage_dir=storage_dir) + configuration = configuration or Configuration.get_global_configuration() + client = await FileSystemKeyValueStoreClient.open(id=id, name=name, configuration=configuration) - if purge_on_start: + if configuration.purge_on_start: await client.drop() - client = await FileSystemKeyValueStoreClient.open(id=id, name=name, storage_dir=storage_dir) + client = await FileSystemKeyValueStoreClient.open(id=id, name=name, configuration=configuration) return client @@ -57,13 +53,13 @@ async def open_request_queue_client( *, id: str | None = None, name: str | None = None, - purge_on_start: bool = True, - storage_dir: Path | None = None, + configuration: Configuration | None = None, ) -> FileSystemRequestQueueClient: - client = await FileSystemRequestQueueClient.open(id=id, name=name, storage_dir=storage_dir) + configuration = configuration or Configuration.get_global_configuration() + client = await FileSystemRequestQueueClient.open(id=id, name=name, configuration=configuration) - if purge_on_start: + if configuration.purge_on_start: await client.drop() - client = await FileSystemRequestQueueClient.open(id=id, name=name, storage_dir=storage_dir) + client = await FileSystemRequestQueueClient.open(id=id, name=name, 
configuration=configuration) return client diff --git a/src/crawlee/storage_clients/_memory/_dataset_client.py b/src/crawlee/storage_clients/_memory/_dataset_client.py index f067d6dcd0..3a0e486330 100644 --- a/src/crawlee/storage_clients/_memory/_dataset_client.py +++ b/src/crawlee/storage_clients/_memory/_dataset_client.py @@ -12,7 +12,8 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator - from pathlib import Path + + from crawlee.configuration import Configuration logger = getLogger(__name__) @@ -25,9 +26,6 @@ class MemoryDatasetClient(DatasetClient): and development purposes where persistence is not required. """ - _DEFAULT_NAME = 'default' - """The default name for the dataset when no name is provided.""" - _cache_by_name: ClassVar[dict[str, MemoryDatasetClient]] = {} """A dictionary to cache clients by their names.""" @@ -67,14 +65,11 @@ def metadata(self) -> DatasetMetadata: async def open( cls, *, - id: str | None = None, - name: str | None = None, - storage_dir: Path | None = None, + id: str | None, + name: str | None, + configuration: Configuration, ) -> MemoryDatasetClient: - if storage_dir is not None: - logger.warning('The `storage_dir` argument is not used in the memory dataset client.') - - name = name or cls._DEFAULT_NAME + name = name or configuration.default_dataset_id # Check if the client is already cached by name. if name in cls._cache_by_name: diff --git a/src/crawlee/storage_clients/_memory/_key_value_store_client.py b/src/crawlee/storage_clients/_memory/_key_value_store_client.py index 2240734930..76bcd5761e 100644 --- a/src/crawlee/storage_clients/_memory/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_memory/_key_value_store_client.py @@ -14,7 +14,8 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator - from pathlib import Path + + from crawlee.configuration import Configuration logger = getLogger(__name__) @@ -27,9 +28,6 @@ class MemoryKeyValueStoreClient(KeyValueStoreClient): for testing and development purposes where persistence is not required. """ - _DEFAULT_NAME = 'default' - """The default name for the key-value store when no name is provided.""" - _cache_by_name: ClassVar[dict[str, MemoryKeyValueStoreClient]] = {} """A dictionary to cache clients by their names.""" @@ -67,14 +65,11 @@ def metadata(self) -> KeyValueStoreMetadata: async def open( cls, *, - id: str | None = None, - name: str | None = None, - storage_dir: Path | None = None, + id: str | None, + name: str | None, + configuration: Configuration, ) -> MemoryKeyValueStoreClient: - if storage_dir is not None: - logger.warning('The `storage_dir` argument is not used in the memory key-value store client.') - - name = name or cls._DEFAULT_NAME + name = name or configuration.default_key_value_store_id # Check if the client is already cached by name if name in cls._cache_by_name: diff --git a/src/crawlee/storage_clients/_memory/_request_queue_client.py b/src/crawlee/storage_clients/_memory/_request_queue_client.py index 293fa5b88c..95775bfa83 100644 --- a/src/crawlee/storage_clients/_memory/_request_queue_client.py +++ b/src/crawlee/storage_clients/_memory/_request_queue_client.py @@ -12,7 +12,7 @@ from crawlee.storage_clients.models import RequestQueueMetadata if TYPE_CHECKING: - from pathlib import Path + from crawlee.configuration import Configuration logger = getLogger(__name__) @@ -25,9 +25,6 @@ class MemoryRequestQueueClient(RequestQueueClient): for testing and development purposes where persistence is not required. 
""" - _DEFAULT_NAME = 'default' - """The default name for the dataset when no name is provided.""" - _cache_by_name: ClassVar[dict[str, MemoryRequestQueueClient]] = {} """A dictionary to cache clients by their names.""" @@ -75,14 +72,11 @@ def metadata(self) -> RequestQueueMetadata: async def open( cls, *, - id: str | None = None, - name: str | None = None, - storage_dir: Path | None = None, + id: str | None, + name: str | None, + configuration: Configuration, ) -> MemoryRequestQueueClient: - if storage_dir is not None: - logger.warning('The `storage_dir` argument is not used in the memory request queue client.') - - name = name or cls._DEFAULT_NAME + name = name or configuration.default_request_queue_id # Check if the client is already cached by name if name in cls._cache_by_name: diff --git a/src/crawlee/storage_clients/_memory/_storage_client.py b/src/crawlee/storage_clients/_memory/_storage_client.py index 5ce9b16dd1..6123a6ca53 100644 --- a/src/crawlee/storage_clients/_memory/_storage_client.py +++ b/src/crawlee/storage_clients/_memory/_storage_client.py @@ -1,18 +1,14 @@ from __future__ import annotations -from typing import TYPE_CHECKING - from typing_extensions import override +from crawlee.configuration import Configuration from crawlee.storage_clients._base import StorageClient from ._dataset_client import MemoryDatasetClient from ._key_value_store_client import MemoryKeyValueStoreClient from ._request_queue_client import MemoryRequestQueueClient -if TYPE_CHECKING: - from pathlib import Path - class MemoryStorageClient(StorageClient): """Memory storage client.""" @@ -23,14 +19,14 @@ async def open_dataset_client( *, id: str | None = None, name: str | None = None, - purge_on_start: bool = True, - storage_dir: Path | None = None + configuration: Configuration | None = None, ) -> MemoryDatasetClient: - client = await MemoryDatasetClient.open(id=id, name=name, storage_dir=storage_dir) + configuration = configuration or Configuration.get_global_configuration() + client = await MemoryDatasetClient.open(id=id, name=name, configuration=configuration) - if purge_on_start: + if configuration.purge_on_start: await client.drop() - client = await MemoryDatasetClient.open(id=id, name=name, storage_dir=storage_dir) + client = await MemoryDatasetClient.open(id=id, name=name, configuration=configuration) return client @@ -40,14 +36,14 @@ async def open_key_value_store_client( *, id: str | None = None, name: str | None = None, - purge_on_start: bool = True, - storage_dir: Path | None = None + configuration: Configuration | None = None, ) -> MemoryKeyValueStoreClient: - client = await MemoryKeyValueStoreClient.open(id=id, name=name, storage_dir=storage_dir) + configuration = configuration or Configuration.get_global_configuration() + client = await MemoryKeyValueStoreClient.open(id=id, name=name, configuration=configuration) - if purge_on_start: + if configuration.purge_on_start: await client.drop() - client = await MemoryKeyValueStoreClient.open(id=id, name=name, storage_dir=storage_dir) + client = await MemoryKeyValueStoreClient.open(id=id, name=name, configuration=configuration) return client @@ -57,7 +53,13 @@ async def open_request_queue_client( *, id: str | None = None, name: str | None = None, - purge_on_start: bool = True, - storage_dir: Path | None = None + configuration: Configuration | None = None, ) -> MemoryRequestQueueClient: - pass + configuration = configuration or Configuration.get_global_configuration() + client = await MemoryRequestQueueClient.open(id=id, name=name, 
configuration=configuration) + + if configuration.purge_on_start: + await client.drop() + client = await MemoryRequestQueueClient.open(id=id, name=name, configuration=configuration) + + return client diff --git a/src/crawlee/storages/_dataset.py b/src/crawlee/storages/_dataset.py index ea386c082d..4b88ebbef0 100644 --- a/src/crawlee/storages/_dataset.py +++ b/src/crawlee/storages/_dataset.py @@ -120,11 +120,11 @@ async def open( purge_on_start = configuration.purge_on_start if purge_on_start is None else purge_on_start storage_dir = Path(configuration.storage_dir) if storage_dir is None else storage_dir + # TODO client = await storage_client.open_dataset_client( id=id, name=name, - purge_on_start=purge_on_start, - storage_dir=storage_dir, + configuration=configuration, ) dataset = cls(client) diff --git a/src/crawlee/storages/_key_value_store.py b/src/crawlee/storages/_key_value_store.py index 54264b754e..4ea439878a 100644 --- a/src/crawlee/storages/_key_value_store.py +++ b/src/crawlee/storages/_key_value_store.py @@ -113,11 +113,11 @@ async def open( purge_on_start = configuration.purge_on_start if purge_on_start is None else purge_on_start storage_dir = Path(configuration.storage_dir) if storage_dir is None else storage_dir + # TODO client = await storage_client.open_key_value_store_client( id=id, name=name, - purge_on_start=purge_on_start, - storage_dir=storage_dir, + configuration=configuration, ) kvs = cls(client) diff --git a/src/crawlee/storages/_request_queue.py b/src/crawlee/storages/_request_queue.py index 3e3a65b2d3..2330e906f7 100644 --- a/src/crawlee/storages/_request_queue.py +++ b/src/crawlee/storages/_request_queue.py @@ -109,11 +109,11 @@ async def open( purge_on_start = configuration.purge_on_start if purge_on_start is None else purge_on_start storage_dir = Path(configuration.storage_dir) if storage_dir is None else storage_dir + # TODO client = await storage_client.open_request_queue_client( id=id, name=name, - purge_on_start=purge_on_start, - storage_dir=storage_dir, + configuration=configuration, ) return cls(client) diff --git a/tests/e2e/project_template/utils.py b/tests/e2e/project_template/utils.py index 3bc5be4ea6..685e8c45e8 100644 --- a/tests/e2e/project_template/utils.py +++ b/tests/e2e/project_template/utils.py @@ -20,23 +20,25 @@ def patch_crawlee_version_in_project( def _patch_crawlee_version_in_requirements_txt_based_project(project_path: Path, wheel_path: Path) -> None: # Get any extras - with open(project_path / 'requirements.txt') as f: + requirements_path = project_path / 'requirements.txt' + with requirements_path.open() as f: requirements = f.read() crawlee_extras = re.findall(r'crawlee(\[.*\])', requirements)[0] or '' # Modify requirements.txt to use crawlee from wheel file instead of from Pypi - with open(project_path / 'requirements.txt') as f: + with requirements_path.open() as f: modified_lines = [] for line in f: if 'crawlee' in line: modified_lines.append(f'./{wheel_path.name}{crawlee_extras}\n') else: modified_lines.append(line) - with open(project_path / 'requirements.txt', 'w') as f: + with requirements_path.open('w') as f: f.write(''.join(modified_lines)) # Patch the dockerfile to have wheel file available - with open(project_path / 'Dockerfile') as f: + dockerfile_path = project_path / 'Dockerfile' + with dockerfile_path.open() as f: modified_lines = [] for line in f: modified_lines.append(line) @@ -49,19 +51,21 @@ def _patch_crawlee_version_in_requirements_txt_based_project(project_path: Path, f'RUN pip install 
./{wheel_path.name}{crawlee_extras} --force-reinstall\n', ] ) - with open(project_path / 'Dockerfile', 'w') as f: + with dockerfile_path.open('w') as f: f.write(''.join(modified_lines)) def _patch_crawlee_version_in_pyproject_toml_based_project(project_path: Path, wheel_path: Path) -> None: """Ensure that the test is using current version of the crawlee from the source and not from Pypi.""" # Get any extras - with open(project_path / 'pyproject.toml') as f: + pyproject_path = project_path / 'pyproject.toml' + with pyproject_path.open() as f: pyproject = f.read() crawlee_extras = re.findall(r'crawlee(\[.*\])', pyproject)[0] or '' # Inject crawlee wheel file to the docker image and update project to depend on it.""" - with open(project_path / 'Dockerfile') as f: + dockerfile_path = project_path / 'Dockerfile' + with dockerfile_path.open() as f: modified_lines = [] for line in f: modified_lines.append(line) @@ -94,5 +98,5 @@ def _patch_crawlee_version_in_pyproject_toml_based_project(project_path: Path, w f'RUN {package_manager} lock\n', ] ) - with open(project_path / 'Dockerfile', 'w') as f: + with dockerfile_path.open('w') as f: f.write(''.join(modified_lines)) diff --git a/tests/unit/_utils/test_file.py b/tests/unit/_utils/test_file.py index b05d44723e..0762e1d966 100644 --- a/tests/unit/_utils/test_file.py +++ b/tests/unit/_utils/test_file.py @@ -104,7 +104,7 @@ async def test_force_remove(tmp_path: Path) -> None: assert test_file_path.exists() is False # Remove the file if it exists - with open(test_file_path, 'a', encoding='utf-8'): + with test_file_path.open('a', encoding='utf-8'): pass assert test_file_path.exists() is True await force_remove(test_file_path) @@ -123,11 +123,11 @@ async def test_force_rename(tmp_path: Path) -> None: # Will remove dst_dir if it exists (also covers normal case) # Create the src_dir with a file in it src_dir.mkdir() - with open(src_file, 'a', encoding='utf-8'): + with src_file.open('a', encoding='utf-8'): pass # Create the dst_dir with a file in it dst_dir.mkdir() - with open(dst_file, 'a', encoding='utf-8'): + with dst_file.open('a', encoding='utf-8'): pass assert src_file.exists() is True assert dst_file.exists() is True diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index a749d43f2e..1b73df5743 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -93,8 +93,6 @@ def _set_crawler_log_level(pytestconfig: pytest.Config, monkeypatch: pytest.Monk monkeypatch.setattr(_log_config, 'get_configured_log_level', lambda: getattr(logging, loglevel.upper())) - - @pytest.fixture async def proxy_info(unused_tcp_port: int) -> ProxyInfo: username = 'user' diff --git a/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py b/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py index 0368b517f4..fbea0baac1 100644 --- a/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py +++ b/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py @@ -9,6 +9,8 @@ import pytest from crawlee._consts import METADATA_FILENAME +from crawlee.configuration import Configuration +from crawlee.storage_clients import FileSystemStorageClient from crawlee.storage_clients._file_system import FileSystemDatasetClient from crawlee.storage_clients.models import DatasetItemsListPage @@ -19,21 +21,32 @@ @pytest.fixture -async def dataset_client(tmp_path: Path) -> AsyncGenerator[FileSystemDatasetClient, None]: - """A fixture for a file system dataset client.""" - # Clear any existing dataset clients in the cache - 
FileSystemDatasetClient._cache_by_name.clear() +def configuration(tmp_path: Path) -> Configuration: + return Configuration( + crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] + ) + - client = await FileSystemDatasetClient.open(name='test_dataset', storage_dir=tmp_path) +@pytest.fixture +async def dataset_client(configuration: Configuration) -> AsyncGenerator[FileSystemDatasetClient, None]: + """A fixture for a file system dataset client.""" + client = await FileSystemStorageClient().open_dataset_client( + name='test_dataset', + configuration=configuration, + ) yield client await client.drop() -async def test_open_creates_new_dataset(tmp_path: Path) -> None: +async def test_open_creates_new_dataset(configuration: Configuration) -> None: """Test that open() creates a new dataset with proper metadata when it doesn't exist.""" - client = await FileSystemDatasetClient.open(name='new_dataset', storage_dir=tmp_path) + client = await FileSystemStorageClient().open_dataset_client( + name='new_dataset', + configuration=configuration, + ) - # Verify client properties + # Verify correct client type and properties + assert isinstance(client, FileSystemDatasetClient) assert client.metadata.id is not None assert client.metadata.name == 'new_dataset' assert client.metadata.item_count == 0 @@ -53,10 +66,18 @@ async def test_open_creates_new_dataset(tmp_path: Path) -> None: assert metadata['item_count'] == 0 -async def test_open_existing_dataset(dataset_client: FileSystemDatasetClient, tmp_path: Path) -> None: +async def test_open_existing_dataset( + dataset_client: FileSystemDatasetClient, + configuration: Configuration, +) -> None: """Test that open() loads an existing dataset correctly.""" + configuration.purge_on_start = False + # Open the same dataset again - reopened_client = await FileSystemDatasetClient.open(name=dataset_client.metadata.name, storage_dir=tmp_path) + reopened_client = await FileSystemStorageClient().open_dataset_client( + name=dataset_client.metadata.name, + configuration=configuration, + ) # Verify client properties assert dataset_client.metadata.id == reopened_client.metadata.id @@ -67,10 +88,59 @@ async def test_open_existing_dataset(dataset_client: FileSystemDatasetClient, tm assert id(dataset_client) == id(reopened_client) -async def test_open_with_id_raises_error(tmp_path: Path) -> None: +async def test_dataset_client_purge_on_start(configuration: Configuration) -> None: + """Test that purge_on_start=True clears existing data in the dataset.""" + configuration.purge_on_start = True + + # Create dataset and add data + dataset_client1 = await FileSystemStorageClient().open_dataset_client( + name='test-purge-dataset', + configuration=configuration, + ) + await dataset_client1.push_data({'item': 'initial data'}) + + # Verify data was added + items = await dataset_client1.get_data() + assert len(items.items) == 1 + + # Reopen + dataset_client2 = await FileSystemStorageClient().open_dataset_client( + name='test-purge-dataset', + configuration=configuration, + ) + + # Verify data was purged + items = await dataset_client2.get_data() + assert len(items.items) == 0 + + +async def test_dataset_client_no_purge_on_start(configuration: Configuration) -> None: + """Test that purge_on_start=False keeps existing data in the dataset.""" + configuration.purge_on_start = False + + # Create dataset and add data + dataset_client1 = await FileSystemStorageClient().open_dataset_client( + name='test-no-purge-dataset', + configuration=configuration, + ) + await 
dataset_client1.push_data({'item': 'preserved data'}) + + # Reopen + dataset_client2 = await FileSystemStorageClient().open_dataset_client( + name='test-no-purge-dataset', + configuration=configuration, + ) + + # Verify data was preserved + items = await dataset_client2.get_data() + assert len(items.items) == 1 + assert items.items[0]['item'] == 'preserved data' + + +async def test_open_with_id_raises_error(configuration: Configuration) -> None: """Test that open() raises an error when an ID is provided.""" with pytest.raises(ValueError, match='not supported for file system storage client'): - await FileSystemDatasetClient.open(id='some-id', storage_dir=tmp_path) + await FileSystemStorageClient().open_dataset_client(id='some-id', configuration=configuration) async def test_push_data_single_item(dataset_client: FileSystemDatasetClient) -> None: diff --git a/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py b/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py index dc0e8e721a..cf8128ede4 100644 --- a/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py +++ b/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py @@ -8,6 +8,8 @@ import pytest from crawlee._consts import METADATA_FILENAME +from crawlee.configuration import Configuration +from crawlee.storage_clients import FileSystemStorageClient from crawlee.storage_clients._file_system import FileSystemKeyValueStoreClient if TYPE_CHECKING: @@ -18,21 +20,32 @@ @pytest.fixture -async def kvs_client(tmp_path: Path) -> AsyncGenerator[FileSystemKeyValueStoreClient, None]: - """Fixture that provides a fresh file system key-value store client using a temporary directory.""" - # Clear any existing dataset clients in the cache - FileSystemKeyValueStoreClient._cache_by_name.clear() +def configuration(tmp_path: Path) -> Configuration: + return Configuration( + crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] + ) - client = await FileSystemKeyValueStoreClient.open(name='test_kvs', storage_dir=tmp_path) + +@pytest.fixture +async def kvs_client(configuration: Configuration) -> AsyncGenerator[FileSystemKeyValueStoreClient, None]: + """A fixture for a file system key-value store client.""" + client = await FileSystemStorageClient().open_key_value_store_client( + name='test_kvs', + configuration=configuration, + ) yield client await client.drop() -async def test_open_creates_new_kvs(tmp_path: Path) -> None: +async def test_open_creates_new_kvs(configuration: Configuration) -> None: """Test that open() creates a new key-value store with proper metadata and files on disk.""" - client = await FileSystemKeyValueStoreClient.open(name='new_kvs', storage_dir=tmp_path) + client = await FileSystemStorageClient().open_key_value_store_client( + name='new_kvs', + configuration=configuration, + ) - # Verify client properties + # Verify correct client type and properties + assert isinstance(client, FileSystemKeyValueStoreClient) assert client.metadata.id is not None assert client.metadata.name == 'new_kvs' assert isinstance(client.metadata.created_at, datetime) @@ -50,10 +63,18 @@ async def test_open_creates_new_kvs(tmp_path: Path) -> None: assert metadata['name'] == 'new_kvs' -async def test_open_existing_kvs(kvs_client: FileSystemKeyValueStoreClient, tmp_path: Path) -> None: +async def test_open_existing_kvs( + kvs_client: FileSystemKeyValueStoreClient, + configuration: Configuration, +) -> None: """Test that open() loads an existing key-value store with matching properties.""" + configuration.purge_on_start = False + # 
Open the same key-value store again - reopened_client = await FileSystemKeyValueStoreClient.open(name=kvs_client.metadata.name, storage_dir=tmp_path) + reopened_client = await FileSystemStorageClient().open_key_value_store_client( + name=kvs_client.metadata.name, + configuration=configuration, + ) # Verify client properties assert kvs_client.metadata.id == reopened_client.metadata.id @@ -63,10 +84,60 @@ async def test_open_existing_kvs(kvs_client: FileSystemKeyValueStoreClient, tmp_ assert id(kvs_client) == id(reopened_client) -async def test_open_with_id_raises_error(tmp_path: Path) -> None: +async def test_kvs_client_purge_on_start(configuration: Configuration) -> None: + """Test that purge_on_start=True clears existing data in the key-value store.""" + configuration.purge_on_start = True + + # Create KVS and add data + kvs_client1 = await FileSystemStorageClient().open_key_value_store_client( + name='test-purge-kvs', + configuration=configuration, + ) + await kvs_client1.set_value(key='test-key', value='initial value') + + # Verify value was set + record = await kvs_client1.get_value(key='test-key') + assert record is not None + assert record.value == 'initial value' + + # Reopen + kvs_client2 = await FileSystemStorageClient().open_key_value_store_client( + name='test-purge-kvs', + configuration=configuration, + ) + + # Verify value was purged + record = await kvs_client2.get_value(key='test-key') + assert record is None + + +async def test_kvs_client_no_purge_on_start(configuration: Configuration) -> None: + """Test that purge_on_start=False keeps existing data in the key-value store.""" + configuration.purge_on_start = False + + # Create KVS and add data + kvs_client1 = await FileSystemStorageClient().open_key_value_store_client( + name='test-no-purge-kvs', + configuration=configuration, + ) + await kvs_client1.set_value(key='test-key', value='preserved value') + + # Reopen + kvs_client2 = await FileSystemStorageClient().open_key_value_store_client( + name='test-no-purge-kvs', + configuration=configuration, + ) + + # Verify value was preserved + record = await kvs_client2.get_value(key='test-key') + assert record is not None + assert record.value == 'preserved value' + + +async def test_open_with_id_raises_error(configuration: Configuration) -> None: """Test that open() raises an error when an ID is provided (unsupported for file system client).""" with pytest.raises(ValueError, match='not supported for file system storage client'): - await FileSystemKeyValueStoreClient.open(id='some-id', storage_dir=tmp_path) + await FileSystemStorageClient().open_key_value_store_client(id='some-id', configuration=configuration) async def test_set_get_value_string(kvs_client: FileSystemKeyValueStoreClient) -> None: diff --git a/tests/unit/storage_clients/_file_system/test_fs_storage_client.py b/tests/unit/storage_clients/_file_system/test_fs_storage_client.py deleted file mode 100644 index d5eefeffc1..0000000000 --- a/tests/unit/storage_clients/_file_system/test_fs_storage_client.py +++ /dev/null @@ -1,144 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -from crawlee.storage_clients._file_system import ( - FileSystemDatasetClient, - FileSystemKeyValueStoreClient, - FileSystemStorageClient, -) - -if TYPE_CHECKING: - from pathlib import Path - -pytestmark = pytest.mark.only - - -@pytest.fixture -async def client() -> FileSystemStorageClient: - return FileSystemStorageClient() - - -async def test_open_dataset_client(client: FileSystemStorageClient, 
tmp_path: Path) -> None: - """Test that open_dataset_client creates a dataset client with correct type and properties.""" - dataset_client = await client.open_dataset_client(name='test-dataset', storage_dir=tmp_path) - - # Verify correct client type and properties - assert isinstance(dataset_client, FileSystemDatasetClient) - assert dataset_client.metadata.name == 'test-dataset' - - # Verify directory structure was created - assert dataset_client.path_to_dataset.exists() - - -async def test_dataset_client_purge_on_start(client: FileSystemStorageClient, tmp_path: Path) -> None: - """Test that purge_on_start=True clears existing data in the dataset.""" - # Create dataset and add data - dataset_client1 = await client.open_dataset_client( - name='test-purge-dataset', - storage_dir=tmp_path, - purge_on_start=True, - ) - await dataset_client1.push_data({'item': 'initial data'}) - - # Verify data was added - items = await dataset_client1.get_data() - assert len(items.items) == 1 - - # Reopen - dataset_client2 = await client.open_dataset_client( - name='test-purge-dataset', - storage_dir=tmp_path, - purge_on_start=True, - ) - - # Verify data was purged - items = await dataset_client2.get_data() - assert len(items.items) == 0 - - -async def test_dataset_client_no_purge_on_start(client: FileSystemStorageClient, tmp_path: Path) -> None: - """Test that purge_on_start=False keeps existing data in the dataset.""" - # Create dataset and add data - dataset_client1 = await client.open_dataset_client( - name='test-no-purge-dataset', - storage_dir=tmp_path, - purge_on_start=False, - ) - await dataset_client1.push_data({'item': 'preserved data'}) - - # Reopen - dataset_client2 = await client.open_dataset_client( - name='test-no-purge-dataset', - storage_dir=tmp_path, - purge_on_start=False, - ) - - # Verify data was preserved - items = await dataset_client2.get_data() - assert len(items.items) == 1 - assert items.items[0]['item'] == 'preserved data' - - -async def test_open_kvs_client(client: FileSystemStorageClient, tmp_path: Path) -> None: - """Test that open_key_value_store_client creates a KVS client with correct type and properties.""" - kvs_client = await client.open_key_value_store_client(name='test-kvs', storage_dir=tmp_path) - - # Verify correct client type and properties - assert isinstance(kvs_client, FileSystemKeyValueStoreClient) - assert kvs_client.metadata.name == 'test-kvs' - - # Verify directory structure was created - assert kvs_client.path_to_kvs.exists() - - -async def test_kvs_client_purge_on_start(client: FileSystemStorageClient, tmp_path: Path) -> None: - """Test that purge_on_start=True clears existing data in the key-value store.""" - # Create KVS and add data - kvs_client1 = await client.open_key_value_store_client( - name='test-purge-kvs', - storage_dir=tmp_path, - purge_on_start=True, - ) - await kvs_client1.set_value(key='test-key', value='initial value') - - # Verify value was set - record = await kvs_client1.get_value(key='test-key') - assert record is not None - assert record.value == 'initial value' - - # Reopen - kvs_client2 = await client.open_key_value_store_client( - name='test-purge-kvs', - storage_dir=tmp_path, - purge_on_start=True, - ) - - # Verify value was purged - record = await kvs_client2.get_value(key='test-key') - assert record is None - - -async def test_kvs_client_no_purge_on_start(client: FileSystemStorageClient, tmp_path: Path) -> None: - """Test that purge_on_start=False keeps existing data in the key-value store.""" - # Create KVS and add data - 
kvs_client1 = await client.open_key_value_store_client( - name='test-no-purge-kvs', - storage_dir=tmp_path, - purge_on_start=False, - ) - await kvs_client1.set_value(key='test-key', value='preserved value') - - # Reopen - kvs_client2 = await client.open_key_value_store_client( - name='test-no-purge-kvs', - storage_dir=tmp_path, - purge_on_start=False, - ) - - # Verify value was preserved - record = await kvs_client2.get_value(key='test-key') - assert record is not None - assert record.value == 'preserved value' diff --git a/tests/unit/storage_clients/_memory/test_memory_dataset_client.py b/tests/unit/storage_clients/_memory/test_memory_dataset_client.py index 7f349daf23..4f915ff67b 100644 --- a/tests/unit/storage_clients/_memory/test_memory_dataset_client.py +++ b/tests/unit/storage_clients/_memory/test_memory_dataset_client.py @@ -6,6 +6,8 @@ import pytest +from crawlee.configuration import Configuration +from crawlee.storage_clients import MemoryStorageClient from crawlee.storage_clients._memory import MemoryDatasetClient from crawlee.storage_clients.models import DatasetItemsListPage @@ -18,19 +20,17 @@ @pytest.fixture async def dataset_client() -> AsyncGenerator[MemoryDatasetClient, None]: """Fixture that provides a fresh memory dataset client for each test.""" - # Clear any existing dataset clients in the cache - MemoryDatasetClient._cache_by_name.clear() - - client = await MemoryDatasetClient.open(name='test_dataset') + client = await MemoryStorageClient().open_dataset_client(name='test_dataset') yield client await client.drop() async def test_open_creates_new_dataset() -> None: """Test that open() creates a new dataset with proper metadata and adds it to the cache.""" - client = await MemoryDatasetClient.open(name='new_dataset') + client = await MemoryStorageClient().open_dataset_client(name='new_dataset') - # Verify client properties + # Verify correct client type and properties + assert isinstance(client, MemoryDatasetClient) assert client.metadata.id is not None assert client.metadata.name == 'new_dataset' assert client.metadata.item_count == 0 @@ -44,8 +44,13 @@ async def test_open_creates_new_dataset() -> None: async def test_open_existing_dataset(dataset_client: MemoryDatasetClient) -> None: """Test that open() loads an existing dataset with matching properties.""" + configuration = Configuration(purge_on_start=False) + # Open the same dataset again - reopened_client = await MemoryDatasetClient.open(name=dataset_client.metadata.name) + reopened_client = await MemoryStorageClient().open_dataset_client( + name=dataset_client.metadata.name, + configuration=configuration, + ) # Verify client properties assert dataset_client.metadata.id == reopened_client.metadata.id @@ -56,9 +61,61 @@ async def test_open_existing_dataset(dataset_client: MemoryDatasetClient) -> Non assert id(dataset_client) == id(reopened_client) +async def test_dataset_client_purge_on_start() -> None: + """Test that purge_on_start=True clears existing data in the dataset.""" + configuration = Configuration(purge_on_start=True) + + # Create dataset and add data + dataset_client1 = await MemoryStorageClient().open_dataset_client( + name='test_purge_dataset', + configuration=configuration, + ) + await dataset_client1.push_data({'item': 'initial data'}) + + # Verify data was added + items = await dataset_client1.get_data() + assert len(items.items) == 1 + + # Reopen + dataset_client2 = await MemoryStorageClient().open_dataset_client( + name='test_purge_dataset', + configuration=configuration, + ) + + # Verify data was 
purged + items = await dataset_client2.get_data() + assert len(items.items) == 0 + + +async def test_dataset_client_no_purge_on_start() -> None: + """Test that purge_on_start=False keeps existing data in the dataset.""" + configuration = Configuration(purge_on_start=False) + + # Create dataset and add data + dataset_client1 = await MemoryStorageClient().open_dataset_client( + name='test_no_purge_dataset', + configuration=configuration, + ) + await dataset_client1.push_data({'item': 'preserved data'}) + + # Reopen + dataset_client2 = await MemoryStorageClient().open_dataset_client( + name='test_no_purge_dataset', + configuration=configuration, + ) + + # Verify data was preserved + items = await dataset_client2.get_data() + assert len(items.items) == 1 + assert items.items[0]['item'] == 'preserved data' + + async def test_open_with_id_and_name() -> None: """Test that open() can be used with both id and name parameters.""" - client = await MemoryDatasetClient.open(id='some-id', name='some-name') + client = await MemoryStorageClient().open_dataset_client( + id='some-id', + name='some-name', + ) assert client.metadata.id == 'some-id' assert client.metadata.name == 'some-name' diff --git a/tests/unit/storage_clients/_memory/test_memory_kvs_client.py b/tests/unit/storage_clients/_memory/test_memory_kvs_client.py index 3b3b4806a7..bb98e9fe7c 100644 --- a/tests/unit/storage_clients/_memory/test_memory_kvs_client.py +++ b/tests/unit/storage_clients/_memory/test_memory_kvs_client.py @@ -6,6 +6,8 @@ import pytest +from crawlee.configuration import Configuration +from crawlee.storage_clients import MemoryStorageClient from crawlee.storage_clients._memory import MemoryKeyValueStoreClient from crawlee.storage_clients.models import KeyValueStoreRecordMetadata @@ -18,18 +20,17 @@ @pytest.fixture async def kvs_client() -> AsyncGenerator[MemoryKeyValueStoreClient, None]: """Fixture that provides a fresh memory key-value store client for each test.""" - # Clear any existing key-value store clients in the cache - MemoryKeyValueStoreClient._cache_by_name.clear() - - client = await MemoryKeyValueStoreClient.open(name='test_kvs') + client = await MemoryStorageClient().open_key_value_store_client(name='test_kvs') yield client await client.drop() -async def test_open_creates_new_store() -> None: + +async def test_open_creates_new_kvs() -> None: """Test that open() creates a new key-value store with proper metadata and adds it to the cache.""" - client = await MemoryKeyValueStoreClient.open(name='new_kvs') + client = await MemoryStorageClient().open_key_value_store_client(name='new_kvs') - # Verify client properties + # Verify correct client type and properties + assert isinstance(client, MemoryKeyValueStoreClient) assert client.metadata.id is not None assert client.metadata.name == 'new_kvs' assert isinstance(client.metadata.created_at, datetime) @@ -40,10 +41,14 @@ async def test_open_creates_new_store() -> None: assert 'new_kvs' in MemoryKeyValueStoreClient._cache_by_name -async def test_open_existing_store(kvs_client: MemoryKeyValueStoreClient) -> None: +async def test_open_existing_kvs(kvs_client: MemoryKeyValueStoreClient) -> None: """Test that open() loads an existing key-value store with matching properties.""" + configuration = Configuration(purge_on_start=False) # Open the same key-value store again - reopened_client = await MemoryKeyValueStoreClient.open(name=kvs_client.metadata.name) + reopened_client = await MemoryStorageClient().open_key_value_store_client( + name=kvs_client.metadata.name, + 
configuration=configuration, + ) # Verify client properties assert kvs_client.metadata.id == reopened_client.metadata.id @@ -53,9 +58,62 @@ async def test_open_existing_store(kvs_client: MemoryKeyValueStoreClient) -> Non assert id(kvs_client) == id(reopened_client) +async def test_kvs_client_purge_on_start() -> None: + """Test that purge_on_start=True clears existing data in the KVS.""" + configuration = Configuration(purge_on_start=True) + + # Create KVS and add data + kvs_client1 = await MemoryStorageClient().open_key_value_store_client( + name='test_purge_kvs', + configuration=configuration, + ) + await kvs_client1.set_value(key='test-key', value='initial value') + + # Verify value was set + record = await kvs_client1.get_value(key='test-key') + assert record is not None + assert record.value == 'initial value' + + # Reopen + kvs_client2 = await MemoryStorageClient().open_key_value_store_client( + name='test_purge_kvs', + configuration=configuration, + ) + + # Verify value was purged + record = await kvs_client2.get_value(key='test-key') + assert record is None + + +async def test_kvs_client_no_purge_on_start() -> None: + """Test that purge_on_start=False keeps existing data in the KVS.""" + configuration = Configuration(purge_on_start=False) + + # Create KVS and add data + kvs_client1 = await MemoryStorageClient().open_key_value_store_client( + name='test_no_purge_kvs', + configuration=configuration, + ) + await kvs_client1.set_value(key='test-key', value='preserved value') + + # Reopen + kvs_client2 = await MemoryStorageClient().open_key_value_store_client( + name='test_no_purge_kvs', + configuration=configuration, + ) + + # Verify value was preserved + record = await kvs_client2.get_value(key='test-key') + assert record is not None + assert record.value == 'preserved value' + + async def test_open_with_id_and_name() -> None: """Test that open() can be used with both id and name parameters.""" - client = await MemoryKeyValueStoreClient.open(id='some-id', name='some-name') + client = await MemoryStorageClient().open_key_value_store_client( + id='some-id', + name='some-name', + ) assert client.metadata.id == 'some-id' assert client.metadata.name == 'some-name' From 834713f3d09466be97184abd17a7a50aa8076bee Mon Sep 17 00:00:00 2001 From: Vlada Dusek Date: Thu, 17 Apr 2025 11:49:14 +0200 Subject: [PATCH 14/22] RQ and Apify clients (will be moved to SDK later) --- .../storages/rq_basic_example.py | 2 +- .../rq_with_crawler_explicit_example.py | 2 +- pyproject.toml | 6 +- src/crawlee/_utils/file.py | 6 +- src/crawlee/crawlers/_basic/_basic_crawler.py | 4 +- .../request_loaders/_request_loader.py | 8 - .../request_loaders/_request_manager.py | 10 +- .../_request_manager_tandem.py | 4 +- .../storage_clients/_apify/__init__.py | 11 + .../storage_clients/_apify/_dataset_client.py | 198 ++++++++++++++ .../_apify/_key_value_store_client.py | 210 +++++++++++++++ .../_apify/_request_queue_client.py | 251 ++++++++++++++++++ .../storage_clients/_apify/_storage_client.py | 65 +++++ src/crawlee/storage_clients/_apify/py.typed | 0 .../_base/_request_queue_client.py | 100 +++---- src/crawlee/storage_clients/models.py | 72 ++++- src/crawlee/storages/_dataset.py | 6 - src/crawlee/storages/_key_value_store.py | 6 - src/crawlee/storages/_request_queue.py | 62 ++++- .../crawlers/_basic/test_basic_crawler.py | 4 +- .../storages/test_request_manager_tandem.py | 2 +- tests/unit/storages/test_request_queue.py | 2 +- 22 files changed, 912 insertions(+), 119 deletions(-) create mode 100644 
src/crawlee/storage_clients/_apify/__init__.py create mode 100644 src/crawlee/storage_clients/_apify/_dataset_client.py create mode 100644 src/crawlee/storage_clients/_apify/_key_value_store_client.py create mode 100644 src/crawlee/storage_clients/_apify/_request_queue_client.py create mode 100644 src/crawlee/storage_clients/_apify/_storage_client.py create mode 100644 src/crawlee/storage_clients/_apify/py.typed diff --git a/docs/guides/code_examples/storages/rq_basic_example.py b/docs/guides/code_examples/storages/rq_basic_example.py index 9e983bb9fe..388c184fc6 100644 --- a/docs/guides/code_examples/storages/rq_basic_example.py +++ b/docs/guides/code_examples/storages/rq_basic_example.py @@ -12,7 +12,7 @@ async def main() -> None: await request_queue.add_request('https://apify.com/') # Add multiple requests as a batch. - await request_queue.add_requests_batched( + await request_queue.add_requests( ['https://crawlee.dev/', 'https://crawlee.dev/python/'] ) diff --git a/docs/guides/code_examples/storages/rq_with_crawler_explicit_example.py b/docs/guides/code_examples/storages/rq_with_crawler_explicit_example.py index 21bedad0b9..4ef61efc82 100644 --- a/docs/guides/code_examples/storages/rq_with_crawler_explicit_example.py +++ b/docs/guides/code_examples/storages/rq_with_crawler_explicit_example.py @@ -10,7 +10,7 @@ async def main() -> None: request_queue = await RequestQueue.open(name='my-request-queue') # Interact with the request queue directly, e.g. add a batch of requests. - await request_queue.add_requests_batched( + await request_queue.add_requests( ['https://apify.com/', 'https://crawlee.dev/'] ) diff --git a/pyproject.toml b/pyproject.toml index fdc89eed71..0ad4baed98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,8 +92,8 @@ crawlee = "crawlee._cli:cli" [dependency-groups] dev = [ - "apify_client", # For e2e tests. - "build~=1.2.2", # For e2e tests. + "apify-client", # For e2e tests. + "build~=1.2.2", # For e2e tests. 
"mypy~=1.15.0", "pre-commit~=4.2.0", "proxy-py~=2.4.0", @@ -104,7 +104,7 @@ dev = [ "pytest-xdist~=3.6.0", "pytest~=8.3.0", "ruff~=0.11.0", - "setuptools~=79.0.0", # setuptools are used by pytest, but not explicitly required + "setuptools", # setuptools are used by pytest, but not explicitly required "sortedcontainers-stubs~=2.4.0", "types-beautifulsoup4~=4.12.0.20240229", "types-cachetools~=5.5.0.20240820", diff --git a/src/crawlee/_utils/file.py b/src/crawlee/_utils/file.py index 6a2100dd87..4de6804490 100644 --- a/src/crawlee/_utils/file.py +++ b/src/crawlee/_utils/file.py @@ -134,11 +134,7 @@ async def export_json_to_stream( **kwargs: Unpack[ExportDataJsonKwargs], ) -> None: items = [item async for item in iterator] - - if items: - json.dump(items, dst, **kwargs) - else: - logger.warning('Attempting to export an empty dataset - no file will be created') + json.dump(items, dst, **kwargs) async def export_csv_to_stream( diff --git a/src/crawlee/crawlers/_basic/_basic_crawler.py b/src/crawlee/crawlers/_basic/_basic_crawler.py index f800f43dc3..890f2f235c 100644 --- a/src/crawlee/crawlers/_basic/_basic_crawler.py +++ b/src/crawlee/crawlers/_basic/_basic_crawler.py @@ -648,7 +648,7 @@ async def add_requests( """ request_manager = await self.get_request_manager() - await request_manager.add_requests_batched( + await request_manager.add_requests( requests=requests, batch_size=batch_size, wait_time_between_batches=wait_time_between_batches, @@ -976,7 +976,7 @@ async def _commit_request_handler_result(self, context: BasicCrawlingContext) -> ): requests.append(dst_request) - await request_manager.add_requests_batched(requests) + await request_manager.add_requests(requests) for push_data_call in result.push_data_calls: await self._push_data(**push_data_call) diff --git a/src/crawlee/request_loaders/_request_loader.py b/src/crawlee/request_loaders/_request_loader.py index e358306a45..2e3c8a3b73 100644 --- a/src/crawlee/request_loaders/_request_loader.py +++ b/src/crawlee/request_loaders/_request_loader.py @@ -25,10 +25,6 @@ class RequestLoader(ABC): - Managing state information such as the total and handled request counts. """ - @abstractmethod - async def get_total_count(self) -> int: - """Return an offline approximation of the total number of requests in the source (i.e. pending + handled).""" - @abstractmethod async def is_empty(self) -> bool: """Return True if there are no more requests in the source (there might still be unfinished requests).""" @@ -45,10 +41,6 @@ async def fetch_next_request(self) -> Request | None: async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: """Mark a request as handled after a successful processing (or after giving up retrying).""" - @abstractmethod - async def get_handled_count(self) -> int: - """Return the number of handled requests.""" - async def to_tandem(self, request_manager: RequestManager | None = None) -> RequestManagerTandem: """Combine the loader with a request manager to support adding and reclaiming requests. 
diff --git a/src/crawlee/request_loaders/_request_manager.py b/src/crawlee/request_loaders/_request_manager.py index f63f962cb9..5a8427c2cb 100644 --- a/src/crawlee/request_loaders/_request_manager.py +++ b/src/crawlee/request_loaders/_request_manager.py @@ -6,12 +6,12 @@ from crawlee._utils.docs import docs_group from crawlee.request_loaders._request_loader import RequestLoader +from crawlee.storage_clients.models import ProcessedRequest if TYPE_CHECKING: from collections.abc import Sequence from crawlee._request import Request - from crawlee.storage_clients.models import ProcessedRequest @docs_group('Abstract classes') @@ -40,10 +40,11 @@ async def add_request( Information about the request addition to the manager. """ - async def add_requests_batched( + async def add_requests( self, requests: Sequence[str | Request], *, + forefront: bool = False, batch_size: int = 1000, # noqa: ARG002 wait_time_between_batches: timedelta = timedelta(seconds=1), # noqa: ARG002 wait_for_all_requests_to_be_added: bool = False, # noqa: ARG002 @@ -53,14 +54,17 @@ async def add_requests_batched( Args: requests: Requests to enqueue. + forefront: If True, add requests to the beginning of the queue. batch_size: The number of requests to add in one batch. wait_time_between_batches: Time to wait between adding batches. wait_for_all_requests_to_be_added: If True, wait for all requests to be added before returning. wait_for_all_requests_to_be_added_timeout: Timeout for waiting for all requests to be added. """ # Default and dumb implementation. + processed_requests = list[ProcessedRequest]() for request in requests: - await self.add_request(request) + processed_request = await self.add_request(request, forefront=forefront) + processed_requests.append(processed_request) @abstractmethod async def reclaim_request(self, request: Request, *, forefront: bool = False) -> ProcessedRequest | None: diff --git a/src/crawlee/request_loaders/_request_manager_tandem.py b/src/crawlee/request_loaders/_request_manager_tandem.py index 9f0b8cefe8..5debdb7135 100644 --- a/src/crawlee/request_loaders/_request_manager_tandem.py +++ b/src/crawlee/request_loaders/_request_manager_tandem.py @@ -49,7 +49,7 @@ async def add_request(self, request: str | Request, *, forefront: bool = False) return await self._read_write_manager.add_request(request, forefront=forefront) @override - async def add_requests_batched( + async def add_requests( self, requests: Sequence[str | Request], *, @@ -58,7 +58,7 @@ async def add_requests_batched( wait_for_all_requests_to_be_added: bool = False, wait_for_all_requests_to_be_added_timeout: timedelta | None = None, ) -> None: - return await self._read_write_manager.add_requests_batched( + return await self._read_write_manager.add_requests( requests, batch_size=batch_size, wait_time_between_batches=wait_time_between_batches, diff --git a/src/crawlee/storage_clients/_apify/__init__.py b/src/crawlee/storage_clients/_apify/__init__.py new file mode 100644 index 0000000000..4af7c8ee23 --- /dev/null +++ b/src/crawlee/storage_clients/_apify/__init__.py @@ -0,0 +1,11 @@ +from ._dataset_client import ApifyDatasetClient +from ._key_value_store_client import ApifyKeyValueStoreClient +from ._request_queue_client import ApifyRequestQueueClient +from ._storage_client import ApifyStorageClient + +__all__ = [ + 'ApifyDatasetClient', + 'ApifyKeyValueStoreClient', + 'ApifyRequestQueueClient', + 'ApifyStorageClient', +] diff --git a/src/crawlee/storage_clients/_apify/_dataset_client.py 
b/src/crawlee/storage_clients/_apify/_dataset_client.py new file mode 100644 index 0000000000..10cb47f028 --- /dev/null +++ b/src/crawlee/storage_clients/_apify/_dataset_client.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +import asyncio +from logging import getLogger +from typing import TYPE_CHECKING, Any, ClassVar + +from apify_client import ApifyClientAsync +from typing_extensions import override + +from crawlee.storage_clients._base import DatasetClient +from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + from datetime import datetime + + from apify_client.clients import DatasetClientAsync + + from crawlee.configuration import Configuration + +logger = getLogger(__name__) + + +class ApifyDatasetClient(DatasetClient): + """An Apify platform implementation of the dataset client.""" + + _cache_by_name: ClassVar[dict[str, ApifyDatasetClient]] = {} + """A dictionary to cache clients by their names.""" + + def __init__( + self, + *, + id: str, + name: str, + created_at: datetime, + accessed_at: datetime, + modified_at: datetime, + item_count: int, + api_client: DatasetClientAsync, + ) -> None: + """Initialize a new instance. + + Preferably use the `ApifyDatasetClient.open` class method to create a new instance. + """ + self._metadata = DatasetMetadata( + id=id, + name=name, + created_at=created_at, + accessed_at=accessed_at, + modified_at=modified_at, + item_count=item_count, + ) + + self._api_client = api_client + """The Apify dataset client for API operations.""" + + self._lock = asyncio.Lock() + """A lock to ensure that only one operation is performed at a time.""" + + @override + @property + def metadata(self) -> DatasetMetadata: + return self._metadata + + @override + @classmethod + async def open( + cls, + *, + id: str | None, + name: str | None, + configuration: Configuration, + ) -> ApifyDatasetClient: + default_name = configuration.default_dataset_id + token = 'configuration.apify_token' # TODO: use the real value + api_url = 'configuration.apify_api_url' # TODO: use the real value + + name = name or default_name + + # Check if the client is already cached by name. + if name in cls._cache_by_name: + client = cls._cache_by_name[name] + await client._update_metadata() # noqa: SLF001 + return client + + # Otherwise, create a new one. + apify_client_async = ApifyClientAsync( + token=token, + api_url=api_url, + max_retries=8, + min_delay_between_retries_millis=500, + timeout_secs=360, + ) + + apify_datasets_client = apify_client_async.datasets() + + metadata = DatasetMetadata.model_validate( + await apify_datasets_client.get_or_create(name=id if id is not None else name), + ) + + apify_dataset_client = apify_client_async.dataset(dataset_id=metadata.id) + + client = cls( + id=metadata.id, + name=metadata.name, + created_at=metadata.created_at, + accessed_at=metadata.accessed_at, + modified_at=metadata.modified_at, + item_count=metadata.item_count, + api_client=apify_dataset_client, + ) + + # Cache the client by name. + cls._cache_by_name[name] = client + + return client + + @override + async def drop(self) -> None: + async with self._lock: + await self._api_client.delete() + + # Remove the client from the cache. 
+ if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 + del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 + + @override + async def push_data(self, data: list[Any] | dict[str, Any]) -> None: + async with self._lock: + await self._api_client.push_items(items=data) + await self._update_metadata() + + @override + async def get_data( + self, + *, + offset: int = 0, + limit: int | None = 999_999_999_999, + clean: bool = False, + desc: bool = False, + fields: list[str] | None = None, + omit: list[str] | None = None, + unwind: str | None = None, + skip_empty: bool = False, + skip_hidden: bool = False, + flatten: list[str] | None = None, + view: str | None = None, + ) -> DatasetItemsListPage: + response = await self._api_client.list_items( + offset=offset, + limit=limit, + clean=clean, + desc=desc, + fields=fields, + omit=omit, + unwind=unwind, + skip_empty=skip_empty, + skip_hidden=skip_hidden, + flatten=flatten, + view=view, + ) + result = DatasetItemsListPage.model_validate(vars(response)) + await self._update_metadata() + return result + + @override + async def iterate_items( + self, + *, + offset: int = 0, + limit: int | None = None, + clean: bool = False, + desc: bool = False, + fields: list[str] | None = None, + omit: list[str] | None = None, + unwind: str | None = None, + skip_empty: bool = False, + skip_hidden: bool = False, + ) -> AsyncIterator[dict]: + async for item in self._api_client.iterate_items( + offset=offset, + limit=limit, + clean=clean, + desc=desc, + fields=fields, + omit=omit, + unwind=unwind, + skip_empty=skip_empty, + skip_hidden=skip_hidden, + ): + yield item + + await self._update_metadata() + + async def _update_metadata(self) -> None: + """Update the dataset metadata file with current information.""" + metadata = await self._api_client.get() + self._metadata = DatasetMetadata.model_validate(metadata) diff --git a/src/crawlee/storage_clients/_apify/_key_value_store_client.py b/src/crawlee/storage_clients/_apify/_key_value_store_client.py new file mode 100644 index 0000000000..621a9d9fe2 --- /dev/null +++ b/src/crawlee/storage_clients/_apify/_key_value_store_client.py @@ -0,0 +1,210 @@ +from __future__ import annotations + +import asyncio +from logging import getLogger +from typing import TYPE_CHECKING, Any, ClassVar + +from apify_client import ApifyClientAsync +from typing_extensions import override +from yarl import URL + +from crawlee.storage_clients._base import KeyValueStoreClient +from crawlee.storage_clients.models import ( + KeyValueStoreListKeysPage, + KeyValueStoreMetadata, + KeyValueStoreRecord, + KeyValueStoreRecordMetadata, +) + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + from datetime import datetime + + from apify_client.clients import KeyValueStoreClientAsync + + from crawlee.configuration import Configuration + +logger = getLogger(__name__) + + +class ApifyKeyValueStoreClient(KeyValueStoreClient): + """An Apify platform implementation of the key-value store client.""" + + _cache_by_name: ClassVar[dict[str, ApifyKeyValueStoreClient]] = {} + """A dictionary to cache clients by their names.""" + + def __init__( + self, + *, + id: str, + name: str, + created_at: datetime, + accessed_at: datetime, + modified_at: datetime, + api_client: KeyValueStoreClientAsync, + ) -> None: + """Initialize a new instance. + + Preferably use the `ApifyKeyValueStoreClient.open` class method to create a new instance. 
+ """ + self._metadata = KeyValueStoreMetadata( + id=id, + name=name, + created_at=created_at, + accessed_at=accessed_at, + modified_at=modified_at, + ) + + self._api_client = api_client + """The Apify key-value store client for API operations.""" + + self._lock = asyncio.Lock() + """A lock to ensure that only one operation is performed at a time.""" + + @override + @property + def metadata(self) -> KeyValueStoreMetadata: + return self._metadata + + @override + @classmethod + async def open( + cls, + *, + id: str | None, + name: str | None, + configuration: Configuration, + ) -> ApifyKeyValueStoreClient: + default_name = configuration.default_key_value_store_id + token = 'configuration.apify_token' # TODO: use the real value + api_url = 'configuration.apify_api_url' # TODO: use the real value + + name = name or default_name + + # Check if the client is already cached by name. + if name in cls._cache_by_name: + client = cls._cache_by_name[name] + await client._update_metadata() # noqa: SLF001 + return client + + # Otherwise, create a new one. + apify_client_async = ApifyClientAsync( + token=token, + api_url=api_url, + max_retries=8, + min_delay_between_retries_millis=500, + timeout_secs=360, + ) + + apify_kvss_client = apify_client_async.key_value_stores() + + metadata = KeyValueStoreMetadata.model_validate( + await apify_kvss_client.get_or_create(name=id if id is not None else name), + ) + + apify_kvs_client = apify_client_async.key_value_store(key_value_store_id=metadata.id) + + client = cls( + id=metadata.id, + name=metadata.name, + created_at=metadata.created_at, + accessed_at=metadata.accessed_at, + modified_at=metadata.modified_at, + api_client=apify_kvs_client, + ) + + # Cache the client by name. + cls._cache_by_name[name] = client + + return client + + @override + async def drop(self) -> None: + async with self._lock: + await self._api_client.delete() + + # Remove the client from the cache. 
+ if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 + del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 + + @override + async def get_value(self, key: str) -> KeyValueStoreRecord | None: + response = await self._api_client.get_record(key) + record = KeyValueStoreRecord.model_validate(response) if response else None + await self._update_metadata() + return record + + @override + async def set_value(self, key: str, value: Any, content_type: str | None = None) -> None: + async with self._lock: + await self._api_client.set_record( + key=key, + value=value, + content_type=content_type, + ) + await self._update_metadata() + + @override + async def delete_value(self, key: str) -> None: + async with self._lock: + await self._api_client.delete_record(key=key) + await self._update_metadata() + + @override + async def iterate_keys( + self, + *, + exclusive_start_key: str | None = None, + limit: int | None = None, + ) -> AsyncIterator[KeyValueStoreRecordMetadata]: + count = 0 + + while True: + response = await self._api_client.list_keys(exclusive_start_key=exclusive_start_key) + list_key_page = KeyValueStoreListKeysPage.model_validate(response) + + for item in list_key_page.items: + yield item + count += 1 + + # If we've reached the limit, stop yielding + if limit and count >= limit: + break + + # If we've reached the limit or there are no more pages, exit the loop + if (limit and count >= limit) or not list_key_page.is_truncated: + break + + exclusive_start_key = list_key_page.next_exclusive_start_key + + await self._update_metadata() + + async def get_public_url(self, key: str) -> str: + """Get a URL for the given key that may be used to publicly access the value in the remote key-value store. + + Args: + key: The key for which the URL should be generated. 
+ """ + if self._api_client.resource_id is None: + raise ValueError('resource_id cannot be None when generating a public URL') + + public_url = ( + URL(self._api_client.base_url) / 'v2' / 'key-value-stores' / self._api_client.resource_id / 'records' / key + ) + + key_value_store = self.metadata + + if key_value_store and isinstance(getattr(key_value_store, 'model_extra', None), dict): + url_signing_secret_key = key_value_store.model_extra.get('urlSigningSecretKey') + if url_signing_secret_key: + # Note: This would require importing create_hmac_signature from apify._crypto + # public_url = public_url.with_query(signature=create_hmac_signature(url_signing_secret_key, key)) + # For now, I'll leave this part commented as we may need to add the proper import + pass + + return str(public_url) + + async def _update_metadata(self) -> None: + """Update the key-value store metadata with current information.""" + metadata = await self._api_client.get() + self._metadata = KeyValueStoreMetadata.model_validate(metadata) diff --git a/src/crawlee/storage_clients/_apify/_request_queue_client.py b/src/crawlee/storage_clients/_apify/_request_queue_client.py new file mode 100644 index 0000000000..1118239b2d --- /dev/null +++ b/src/crawlee/storage_clients/_apify/_request_queue_client.py @@ -0,0 +1,251 @@ +from __future__ import annotations + +import asyncio +from datetime import timedelta +from logging import getLogger +from typing import TYPE_CHECKING, ClassVar + +from apify_client import ApifyClientAsync +from typing_extensions import override + +from crawlee import Request +from crawlee.storage_clients._base import RequestQueueClient +from crawlee.storage_clients.models import ( + AddRequestsResponse, + ProcessedRequest, + ProlongRequestLockResponse, + RequestQueueHead, + RequestQueueMetadata, +) + +if TYPE_CHECKING: + from collections.abc import Sequence + from datetime import datetime + + from apify_client.clients import RequestQueueClientAsync + + from crawlee.configuration import Configuration + +logger = getLogger(__name__) + + +class ApifyRequestQueueClient(RequestQueueClient): + """An Apify platform implementation of the request queue client.""" + + _cache_by_name: ClassVar[dict[str, ApifyRequestQueueClient]] = {} + """A dictionary to cache clients by their names.""" + + _DEFAULT_LOCK_TIME = timedelta(minutes=3) + """The default lock time for requests in the queue.""" + + def __init__( + self, + *, + id: str, + name: str, + created_at: datetime, + accessed_at: datetime, + modified_at: datetime, + had_multiple_clients: bool, + handled_request_count: int, + pending_request_count: int, + stats: dict, + total_request_count: int, + api_client: RequestQueueClientAsync, + ) -> None: + """Initialize a new instance. + + Preferably use the `ApifyRequestQueueClient.open` class method to create a new instance. 
+ """ + self._metadata = RequestQueueMetadata( + id=id, + name=name, + created_at=created_at, + accessed_at=accessed_at, + modified_at=modified_at, + had_multiple_clients=had_multiple_clients, + handled_request_count=handled_request_count, + pending_request_count=pending_request_count, + stats=stats, + total_request_count=total_request_count, + ) + + self._api_client = api_client + """The Apify key-value store client for API operations.""" + + self._lock = asyncio.Lock() + """A lock to ensure that only one operation is performed at a time.""" + + self._add_requests_tasks = list[asyncio.Task]() + """A list of tasks for adding requests to the queue.""" + + self._assumed_total_count = 0 + """An assumed total count of requests in the queue.""" + + @override + @property + def metadata(self) -> RequestQueueMetadata: + return self._metadata + + @override + @classmethod + async def open( + cls, + *, + id: str | None, + name: str | None, + configuration: Configuration, + ) -> ApifyRequestQueueClient: + default_name = configuration.default_request_queue_id + + # TODO: use the real values + token = 'TOKEN' + api_url = 'https://api.apify.com' + + name = name or default_name + + # Check if the client is already cached by name. + if name in cls._cache_by_name: + client = cls._cache_by_name[name] + await client._update_metadata() # noqa: SLF001 + return client + + # Otherwise, create a new one. + apify_client_async = ApifyClientAsync( + token=token, + api_url=api_url, + max_retries=8, + min_delay_between_retries_millis=500, + timeout_secs=360, + ) + + apify_rqs_client = apify_client_async.request_queues() + + metadata = RequestQueueMetadata.model_validate( + await apify_rqs_client.get_or_create(name=id if id is not None else name), + ) + + apify_rq_client = apify_client_async.request_queue(request_queue_id=metadata.id) + + client = cls( + id=metadata.id, + name=metadata.name, + created_at=metadata.created_at, + accessed_at=metadata.accessed_at, + modified_at=metadata.modified_at, + had_multiple_clients=metadata.had_multiple_clients, + handled_request_count=metadata.handled_request_count, + pending_request_count=metadata.pending_request_count, + stats=metadata.stats, + total_request_count=metadata.total_request_count, + api_client=apify_rq_client, + ) + + # Cache the client by name. + cls._cache_by_name[name] = client + + return client + + @override + async def drop(self) -> None: + async with self._lock: + await self._api_client.delete() + + # Remove the client from the cache. 
+ if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 + del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 + + @override + async def list_head( + self, + *, + lock_time: timedelta | None = None, + limit: int | None = None, + ) -> RequestQueueHead: + lock_time = lock_time or self._DEFAULT_LOCK_TIME + + response = await self._api_client.list_and_lock_head( + lock_secs=int(lock_time.total_seconds()), + limit=limit, + ) + + return RequestQueueHead.model_validate(**response) + + @override + async def add_requests( + self, + requests: Sequence[Request], + *, + forefront: bool = False, + batch_size: int = 1000, + wait_time_between_batches: timedelta = timedelta(seconds=1), + wait_for_all_requests_to_be_added: bool = False, + wait_for_all_requests_to_be_added_timeout: timedelta | None = None, + ) -> AddRequestsResponse: + requests_dict = [request.model_dump(by_alias=True) for request in requests] + response = await self._api_client.batch_add_requests(requests=requests_dict, forefront=forefront) + return AddRequestsResponse.model_validate(response) + + @override + async def get_request(self, request_id: str) -> Request | None: + response = await self._api_client.get_request(request_id) + if response is None: + return None + return Request.model_validate(**response) + + @override + async def update_request( + self, + request: Request, + *, + forefront: bool = False, + ) -> ProcessedRequest: + response = await self._api_client.update_request( + request=request.model_dump(by_alias=True), + forefront=forefront, + ) + + return ProcessedRequest.model_validate( + {'id': request.id, 'uniqueKey': request.unique_key} | response, + ) + + @override + async def is_finished(self) -> bool: + if self._add_requests_tasks: + logger.debug('Background tasks are still in progress') + return False + + # TODO + + async def _prolong_request_lock( + self, + request_id: str, + *, + forefront: bool = False, + lock_secs: int, + ) -> ProlongRequestLockResponse: + """Prolong the lock on a specific request in the queue. + + Args: + request_id: The identifier of the request whose lock is to be prolonged. + forefront: Whether to put the request in the beginning or the end of the queue after lock expires. + lock_secs: The additional amount of time, in seconds, that the request will remain locked. + """ + + async def _delete_request_lock( + self, + request_id: str, + *, + forefront: bool = False, + ) -> None: + """Delete the lock on a specific request in the queue. + + Args: + request_id: ID of the request to delete the lock. + forefront: Whether to put the request in the beginning or the end of the queue after the lock is deleted. 
+ """ + + async def _update_metadata(self) -> None: + """Update the request queue metadata with current information.""" + metadata = await self._api_client.get() + self._metadata = RequestQueueMetadata.model_validate(metadata) diff --git a/src/crawlee/storage_clients/_apify/_storage_client.py b/src/crawlee/storage_clients/_apify/_storage_client.py new file mode 100644 index 0000000000..1d4d66dd6a --- /dev/null +++ b/src/crawlee/storage_clients/_apify/_storage_client.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from typing_extensions import override + +from crawlee.configuration import Configuration +from crawlee.storage_clients._base import StorageClient + +from ._dataset_client import ApifyDatasetClient +from ._key_value_store_client import ApifyKeyValueStoreClient +from ._request_queue_client import ApifyRequestQueueClient + + +class ApifyStorageClient(StorageClient): + """Apify storage client.""" + + @override + async def open_dataset_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> ApifyDatasetClient: + configuration = configuration or Configuration.get_global_configuration() + client = await ApifyDatasetClient.open(id=id, name=name, configuration=configuration) + + if configuration.purge_on_start: + await client.drop() + client = await ApifyDatasetClient.open(id=id, name=name, configuration=configuration) + + return client + + @override + async def open_key_value_store_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> ApifyKeyValueStoreClient: + configuration = configuration or Configuration.get_global_configuration() + client = await ApifyKeyValueStoreClient.open(id=id, name=name, configuration=configuration) + + if configuration.purge_on_start: + await client.drop() + client = await ApifyKeyValueStoreClient.open(id=id, name=name, configuration=configuration) + + return client + + @override + async def open_request_queue_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> ApifyRequestQueueClient: + configuration = configuration or Configuration.get_global_configuration() + client = await ApifyRequestQueueClient.open(id=id, name=name, configuration=configuration) + + if configuration.purge_on_start: + await client.drop() + client = await ApifyRequestQueueClient.open(id=id, name=name, configuration=configuration) + + return client diff --git a/src/crawlee/storage_clients/_apify/py.typed b/src/crawlee/storage_clients/_apify/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/crawlee/storage_clients/_base/_request_queue_client.py b/src/crawlee/storage_clients/_base/_request_queue_client.py index 184d2ca97c..e506e5b763 100644 --- a/src/crawlee/storage_clients/_base/_request_queue_client.py +++ b/src/crawlee/storage_clients/_base/_request_queue_client.py @@ -1,6 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from datetime import timedelta from typing import TYPE_CHECKING from crawlee._utils.docs import docs_group @@ -10,11 +11,10 @@ from crawlee.configuration import Configuration from crawlee.storage_clients.models import ( - BatchRequestsOperationResponse, + AddRequestsResponse, ProcessedRequest, - ProlongRequestLockResponse, Request, - RequestQueueHeadWithLocks, + RequestQueueHead, RequestQueueMetadata, ) @@ -60,35 +60,56 @@ async def drop(self) -> None: """ @abstractmethod - async def 
list_and_lock_head(self, *, lock_secs: int, limit: int | None = None) -> RequestQueueHeadWithLocks: - """Fetch and lock a specified number of requests from the start of the queue. + async def list_head( + self, + *, + lock_time: timedelta | None = None, + limit: int | None = None, + ) -> RequestQueueHead: + """Retrieve requests from the beginning of the queue. - Retrieve and locks the first few requests of a queue for the specified duration. This prevents the requests - from being fetched by another client until the lock expires. + Fetches the first requests in the queue. If `lock_time` is provided, the requests will be locked + for the specified duration, preventing them from being processed by other clients until the lock expires. + This locking functionality may not be supported by all request queue client implementations. Args: - lock_secs: Duration for which the requests are locked, in seconds. - limit: Maximum number of requests to retrieve and lock. + lock_time: Duration for which to lock the retrieved requests, if supported by the client. + If None, requests will not be locked. + limit: Maximum number of requests to retrieve. Returns: - The desired number of locked requests from the beginning of the queue. + A collection of requests from the beginning of the queue, including lock information if applicable. """ @abstractmethod - async def add_requests_batch( + async def add_requests( self, requests: Sequence[Request], *, forefront: bool = False, - ) -> BatchRequestsOperationResponse: - """Add a requests to the queue in batches. + batch_size: int = 1000, + wait_time_between_batches: timedelta = timedelta(seconds=1), + wait_for_all_requests_to_be_added: bool = False, + wait_for_all_requests_to_be_added_timeout: timedelta | None = None, + ) -> AddRequestsResponse: + """Add batch of requests to the queue. + + This method adds a batch of requests to the queue. Each request is processed based on its uniqueness + (determined by `unique_key`). Duplicates will be identified but not re-added to the queue. Args: - requests: The batch of requests to add to the queue. - forefront: Whether to add the requests to the head or the end of the queue. + requests: The collection of requests to add to the queue. + forefront: Whether to put the added requests at the beginning (True) or the end (False) of the queue. + When True, the requests will be processed sooner than previously added requests. + batch_size: The maximum number of requests to add in a single batch. + wait_time_between_batches: The time to wait between adding batches of requests. + wait_for_all_requests_to_be_added: If True, the method will wait until all requests are added + to the queue before returning. + wait_for_all_requests_to_be_added_timeout: The maximum time to wait for all requests to be added. Returns: - Request queue batch operation information. + A response object containing information about which requests were successfully + processed and which failed (if any). """ @abstractmethod @@ -120,47 +141,12 @@ async def update_request( """ @abstractmethod - async def delete_request(self, request_id: str) -> None: - """Delete a request from the queue. + async def is_finished(self) -> bool: + """Check if the request queue is finished. - Args: - request_id: ID of the request to delete. - """ - - @abstractmethod - async def delete_requests_batch(self, requests: list[Request]) -> BatchRequestsOperationResponse: - """Delete given requests from the queue. 
+ Finished means that all requests in the queue have been processed (the queue is empty) and there + are no more tasks that could add additional requests to the queue. - Args: - requests: The requests to delete from the queue. - """ - - @abstractmethod - async def prolong_request_lock( - self, - request_id: str, - *, - forefront: bool = False, - lock_secs: int, - ) -> ProlongRequestLockResponse: - """Prolong the lock on a specific request in the queue. - - Args: - request_id: The identifier of the request whose lock is to be prolonged. - forefront: Whether to put the request in the beginning or the end of the queue after lock expires. - lock_secs: The additional amount of time, in seconds, that the request will remain locked. - """ - - @abstractmethod - async def delete_request_lock( - self, - request_id: str, - *, - forefront: bool = False, - ) -> None: - """Delete the lock on a specific request in the queue. - - Args: - request_id: ID of the request to delete the lock. - forefront: Whether to put the request in the beginning or the end of the queue after the lock is deleted. + Returns: + True if the request queue is finished, False otherwise. """ diff --git a/src/crawlee/storage_clients/models.py b/src/crawlee/storage_clients/models.py index c470028fd4..04d1ff95ed 100644 --- a/src/crawlee/storage_clients/models.py +++ b/src/crawlee/storage_clients/models.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -from datetime import datetime +from datetime import datetime, timedelta from decimal import Decimal from typing import Annotated, Any, Generic @@ -26,10 +26,19 @@ class StorageMetadata(BaseModel): model_config = ConfigDict(populate_by_name=True, extra='allow') id: Annotated[str, Field(alias='id')] + """The unique identifier of the storage.""" + name: Annotated[str, Field(alias='name', default='default')] + """The name of the storage.""" + accessed_at: Annotated[datetime, Field(alias='accessedAt')] + """The timestamp when the storage was last accessed.""" + created_at: Annotated[datetime, Field(alias='createdAt')] + """The timestamp when the storage was created.""" + modified_at: Annotated[datetime, Field(alias='modifiedAt')] + """The timestamp when the storage was last modified.""" @docs_group('Data structures') @@ -39,6 +48,7 @@ class DatasetMetadata(StorageMetadata): model_config = ConfigDict(populate_by_name=True) item_count: Annotated[int, Field(alias='itemCount')] + """The number of items in the dataset.""" @docs_group('Data structures') @@ -55,10 +65,19 @@ class RequestQueueMetadata(StorageMetadata): model_config = ConfigDict(populate_by_name=True) had_multiple_clients: Annotated[bool, Field(alias='hadMultipleClients')] + """Indicates whether the queue has been accessed by multiple clients (consumers).""" + handled_request_count: Annotated[int, Field(alias='handledRequestCount')] + """The number of requests that have been handled from the queue.""" + pending_request_count: Annotated[int, Field(alias='pendingRequestCount')] + """The number of requests that are still pending in the queue.""" + stats: Annotated[dict, Field(alias='stats')] + """Statistics about the request queue, TODO?""" + total_request_count: Annotated[int, Field(alias='totalRequestCount')] + """The total number of requests that have been added to the queue.""" @docs_group('Data structures') @@ -100,11 +119,22 @@ class KeyValueStoreListKeysPage(BaseModel): model_config = ConfigDict(populate_by_name=True) count: Annotated[int, Field(alias='count')] + """The number of keys returned on this page.""" + 
limit: Annotated[int, Field(alias='limit')] + """The maximum number of keys to return.""" + is_truncated: Annotated[bool, Field(alias='isTruncated')] + """Indicates whether there are more keys to retrieve.""" + exclusive_start_key: Annotated[str | None, Field(alias='exclusiveStartKey', default=None)] + """The key from which to start this page of results.""" + next_exclusive_start_key: Annotated[str | None, Field(alias='nextExclusiveStartKey', default=None)] + """The key from which to start the next page of results.""" + items: Annotated[list[KeyValueStoreRecordMetadata], Field(alias='items', default_factory=list)] + """The list of KVS items metadata returned on this page.""" @docs_group('Data structures') @@ -121,17 +151,32 @@ class RequestQueueHeadState(BaseModel): @docs_group('Data structures') -class RequestQueueHeadWithLocks(BaseModel): - """Model for request queue head with locks.""" +class RequestQueueHead(BaseModel): + """Model for request queue head. + + Represents a collection of requests retrieved from the beginning of a queue, + including metadata about the queue's state and lock information for the requests. + """ model_config = ConfigDict(populate_by_name=True) limit: Annotated[int | None, Field(alias='limit', default=None)] - had_multiple_clients: Annotated[bool, Field(alias='hadMultipleClients')] + """The maximum number of requests that were requested from the queue.""" + + had_multiple_clients: Annotated[bool, Field(alias='hadMultipleClients', default=False)] + """Indicates whether the queue has been accessed by multiple clients (consumers).""" + queue_modified_at: Annotated[datetime, Field(alias='queueModifiedAt')] - lock_secs: Annotated[int, Field(alias='lockSecs')] - queue_has_locked_requests: Annotated[bool | None, Field(alias='queueHasLockedRequests')] = None - items: Annotated[list[Request], Field(alias='items', default_factory=list)] + """The timestamp when the queue was last modified.""" + + lock_time: Annotated[timedelta | None, Field(alias='lockSecs', default=None)] + """The duration for which the returned requests are locked and cannot be processed by other clients.""" + + queue_has_locked_requests: Annotated[bool | None, Field(alias='queueHasLockedRequests', default=False)] + """Indicates whether the queue contains any locked requests.""" + + items: Annotated[list[Request], Field(alias='items', default_factory=list[Request])] + """The list of request objects retrieved from the beginning of the queue.""" class _ListPage(BaseModel): @@ -220,13 +265,22 @@ class UnprocessedRequest(BaseModel): @docs_group('Data structures') -class BatchRequestsOperationResponse(BaseModel): - """Response to batch request deletion calls.""" +class AddRequestsResponse(BaseModel): + """Model for a response to add requests to a queue. + + Contains detailed information about the processing results when adding multiple requests + to a queue. This includes which requests were successfully processed and which ones + encountered issues during processing. 
+ """ model_config = ConfigDict(populate_by_name=True) processed_requests: Annotated[list[ProcessedRequest], Field(alias='processedRequests')] + """Successfully processed requests, including information about whether they were + already present in the queue and whether they had been handled previously.""" + unprocessed_requests: Annotated[list[UnprocessedRequest], Field(alias='unprocessedRequests')] + """Requests that could not be processed, typically due to validation errors or other issues.""" class InternalRequest(BaseModel): diff --git a/src/crawlee/storages/_dataset.py b/src/crawlee/storages/_dataset.py index 4b88ebbef0..bd453bc8cc 100644 --- a/src/crawlee/storages/_dataset.py +++ b/src/crawlee/storages/_dataset.py @@ -2,7 +2,6 @@ import logging from io import StringIO -from pathlib import Path from typing import TYPE_CHECKING, overload from typing_extensions import override @@ -101,8 +100,6 @@ async def open( *, id: str | None = None, name: str | None = None, - purge_on_start: bool | None = None, - storage_dir: Path | None = None, configuration: Configuration | None = None, storage_client: StorageClient | None = None, ) -> Dataset: @@ -117,10 +114,7 @@ async def open( configuration = service_locator.get_configuration() if configuration is None else configuration storage_client = service_locator.get_storage_client() if storage_client is None else storage_client - purge_on_start = configuration.purge_on_start if purge_on_start is None else purge_on_start - storage_dir = Path(configuration.storage_dir) if storage_dir is None else storage_dir - # TODO client = await storage_client.open_dataset_client( id=id, name=name, diff --git a/src/crawlee/storages/_key_value_store.py b/src/crawlee/storages/_key_value_store.py index 4ea439878a..10574a3910 100644 --- a/src/crawlee/storages/_key_value_store.py +++ b/src/crawlee/storages/_key_value_store.py @@ -1,6 +1,5 @@ from __future__ import annotations -from pathlib import Path from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, overload from typing_extensions import override @@ -94,8 +93,6 @@ async def open( *, id: str | None = None, name: str | None = None, - purge_on_start: bool | None = None, - storage_dir: Path | None = None, configuration: Configuration | None = None, storage_client: StorageClient | None = None, ) -> KeyValueStore: @@ -110,10 +107,7 @@ async def open( configuration = service_locator.get_configuration() if configuration is None else configuration storage_client = service_locator.get_storage_client() if storage_client is None else storage_client - purge_on_start = configuration.purge_on_start if purge_on_start is None else purge_on_start - storage_dir = Path(configuration.storage_dir) if storage_dir is None else storage_dir - # TODO client = await storage_client.open_key_value_store_client( id=id, name=name, diff --git a/src/crawlee/storages/_request_queue.py b/src/crawlee/storages/_request_queue.py index 2330e906f7..7a7c112ac1 100644 --- a/src/crawlee/storages/_request_queue.py +++ b/src/crawlee/storages/_request_queue.py @@ -1,16 +1,16 @@ from __future__ import annotations +import asyncio from datetime import timedelta from logging import getLogger -from pathlib import Path from typing import TYPE_CHECKING, TypeVar from typing_extensions import override -from crawlee import service_locator +from crawlee import Request, service_locator from crawlee._utils.docs import docs_group +from crawlee._utils.wait import wait_for_all_tasks_for_finish from crawlee.request_loaders import RequestManager -from 
crawlee.storage_clients.models import Request
 
 from ._base import Storage
 
@@ -74,6 +74,13 @@ def __init__(self, client: RequestQueueClient) -> None:
         """
         self._client = client
 
+        # Internal attributes
+        self._add_requests_tasks = list[asyncio.Task]()
+        """A list of tasks for adding requests to the queue."""
+
+        self._assumed_total_count = 0
+        """An assumed total count of requests in the queue."""
+
     @override
     @property
     def id(self) -> str:
@@ -96,8 +103,6 @@ async def open(
         *,
         id: str | None = None,
         name: str | None = None,
-        purge_on_start: bool | None = None,
-        storage_dir: Path | None = None,
         configuration: Configuration | None = None,
         storage_client: StorageClient | None = None,
     ) -> RequestQueue:
@@ -106,10 +111,7 @@ async def open(
         configuration = service_locator.get_configuration() if configuration is None else configuration
         storage_client = service_locator.get_storage_client() if storage_client is None else storage_client
 
-        purge_on_start = configuration.purge_on_start if purge_on_start is None else purge_on_start
-        storage_dir = Path(configuration.storage_dir) if storage_dir is None else storage_dir
-
-        # TODO
         client = await storage_client.open_request_queue_client(
             id=id,
             name=name,
@@ -129,20 +131,56 @@ async def add_request(
         *,
         forefront: bool = False,
     ) -> ProcessedRequest:
-        return await self._client.add_requests_batch([request], forefront=forefront)
+        request = self._transform_request(request)
+        response = await self._client.add_requests([request], forefront=forefront)
+        return response.processed_requests[0]
 
     @override
-    async def add_requests_batched(
+    async def add_requests(
        self,
         requests: Sequence[str | Request],
         *,
+        forefront: bool = False,
         batch_size: int = 1000,
         wait_time_between_batches: timedelta = timedelta(seconds=1),
         wait_for_all_requests_to_be_added: bool = False,
         wait_for_all_requests_to_be_added_timeout: timedelta | None = None,
     ) -> None:
-        # TODO: implement
-        pass
+        transformed_requests = self._transform_requests(requests)
+        wait_time_secs = wait_time_between_batches.total_seconds()
+
+        async def _process_batch(batch: Sequence[Request]) -> None:
+            request_count = len(batch)
+            response = await self._client.add_requests(batch, forefront=forefront)
+            self._assumed_total_count += request_count
+            logger.debug(f'Added {request_count} requests to the queue, response: {response}')
+
+        # Wait for the first batch to be added
+        first_batch = transformed_requests[:batch_size]
+        if first_batch:
+            await _process_batch(first_batch)
+
+        async def _process_remaining_batches() -> None:
+            for i in range(batch_size, len(transformed_requests), batch_size):
+                batch = transformed_requests[i : i + batch_size]
+                await _process_batch(batch)
+                if i + batch_size < len(transformed_requests):
+                    await asyncio.sleep(wait_time_secs)
+
+        # Create and start the task to process remaining batches in the background
+        remaining_batches_task = asyncio.create_task(
+            _process_remaining_batches(), name='request_queue_process_remaining_batches_task'
+        )
+        self._add_requests_tasks.append(remaining_batches_task)
+        remaining_batches_task.add_done_callback(lambda _: self._add_requests_tasks.remove(remaining_batches_task))
+
+        # Wait for all tasks to finish if requested
+        if wait_for_all_requests_to_be_added:
+            await wait_for_all_tasks_for_finish(
+                (remaining_batches_task,),
+                logger=logger,
+                timeout=wait_for_all_requests_to_be_added_timeout,
+            )
 diff --git a/tests/unit/crawlers/_basic/test_basic_crawler.py
b/tests/unit/crawlers/_basic/test_basic_crawler.py index ab7a219ef7..f28531037b 100644 --- a/tests/unit/crawlers/_basic/test_basic_crawler.py +++ b/tests/unit/crawlers/_basic/test_basic_crawler.py @@ -40,7 +40,7 @@ async def test_processes_requests_from_explicit_queue() -> None: queue = await RequestQueue.open() - await queue.add_requests_batched(['http://a.com/', 'http://b.com/', 'http://c.com/']) + await queue.add_requests(['http://a.com/', 'http://b.com/', 'http://c.com/']) crawler = BasicCrawler(request_manager=queue) calls = list[str]() @@ -56,7 +56,7 @@ async def handler(context: BasicCrawlingContext) -> None: async def test_processes_requests_from_request_source_tandem() -> None: request_queue = await RequestQueue.open() - await request_queue.add_requests_batched(['http://a.com/', 'http://b.com/', 'http://c.com/']) + await request_queue.add_requests(['http://a.com/', 'http://b.com/', 'http://c.com/']) request_list = RequestList(['http://a.com/', 'http://d.com', 'http://e.com']) diff --git a/tests/unit/storages/test_request_manager_tandem.py b/tests/unit/storages/test_request_manager_tandem.py index d08ab57dc1..060484a136 100644 --- a/tests/unit/storages/test_request_manager_tandem.py +++ b/tests/unit/storages/test_request_manager_tandem.py @@ -54,7 +54,7 @@ async def test_basic_functionality(test_input: TestInput) -> None: request_queue = await RequestQueue.open() if test_input.request_manager_items: - await request_queue.add_requests_batched(test_input.request_manager_items) + await request_queue.add_requests(test_input.request_manager_items) mock_request_loader = create_autospec(RequestLoader, instance=True, spec_set=True) mock_request_loader.fetch_next_request.side_effect = lambda: test_input.request_loader_items.pop(0) diff --git a/tests/unit/storages/test_request_queue.py b/tests/unit/storages/test_request_queue.py index cddba8ef99..5de86a5544 100644 --- a/tests/unit/storages/test_request_queue.py +++ b/tests/unit/storages/test_request_queue.py @@ -158,7 +158,7 @@ async def test_add_batched_requests( request_count = len(requests) # Add the requests to the RQ in batches - await request_queue.add_requests_batched(requests, wait_for_all_requests_to_be_added=True) + await request_queue.add_requests(requests, wait_for_all_requests_to_be_added=True) # Ensure the batch was processed correctly assert await request_queue.get_total_count() == request_count From 84f9f1276d2def77d43bbd4fafa4e6466e59f98d Mon Sep 17 00:00:00 2001 From: Vlada Dusek Date: Thu, 17 Apr 2025 14:15:55 +0200 Subject: [PATCH 15/22] Add init version of RQ and its clients --- .../_apify/_request_queue_client.py | 483 ++++++++++++++++-- .../_base/_request_queue_client.py | 83 +-- .../_file_system/_request_queue_client.py | 320 +++++++++++- .../_memory/_request_queue_client.py | 224 +++++++- src/crawlee/storage_clients/models.py | 19 + src/crawlee/storages/_request_queue.py | 61 ++- 6 files changed, 1084 insertions(+), 106 deletions(-) diff --git a/src/crawlee/storage_clients/_apify/_request_queue_client.py b/src/crawlee/storage_clients/_apify/_request_queue_client.py index 1118239b2d..d0f86041d2 100644 --- a/src/crawlee/storage_clients/_apify/_request_queue_client.py +++ b/src/crawlee/storage_clients/_apify/_request_queue_client.py @@ -1,17 +1,22 @@ from __future__ import annotations import asyncio -from datetime import timedelta +import os +from collections import deque +from datetime import datetime, timedelta, timezone from logging import getLogger -from typing import TYPE_CHECKING, ClassVar +from typing import 
TYPE_CHECKING, ClassVar, Final from apify_client import ApifyClientAsync +from cachetools import LRUCache from typing_extensions import override from crawlee import Request +from crawlee._utils.requests import unique_key_to_request_id from crawlee.storage_clients._base import RequestQueueClient from crawlee.storage_clients.models import ( AddRequestsResponse, + CachedRequest, ProcessedRequest, ProlongRequestLockResponse, RequestQueueHead, @@ -20,7 +25,6 @@ if TYPE_CHECKING: from collections.abc import Sequence - from datetime import datetime from apify_client.clients import RequestQueueClientAsync @@ -35,9 +39,12 @@ class ApifyRequestQueueClient(RequestQueueClient): _cache_by_name: ClassVar[dict[str, ApifyRequestQueueClient]] = {} """A dictionary to cache clients by their names.""" - _DEFAULT_LOCK_TIME = timedelta(minutes=3) + _DEFAULT_LOCK_TIME: Final[timedelta] = timedelta(minutes=3) """The default lock time for requests in the queue.""" + _MAX_CACHED_REQUESTS: Final[int] = 1_000_000 + """Maximum number of requests that can be cached.""" + def __init__( self, *, @@ -71,16 +78,22 @@ def __init__( ) self._api_client = api_client - """The Apify key-value store client for API operations.""" + """The Apify request queue client for API operations.""" self._lock = asyncio.Lock() """A lock to ensure that only one operation is performed at a time.""" - self._add_requests_tasks = list[asyncio.Task]() - """A list of tasks for adding requests to the queue.""" + self._queue_head = deque[str]() + """A deque to store request IDs in the queue head.""" + + self._requests_cache: LRUCache[str, CachedRequest] = LRUCache(maxsize=self._MAX_CACHED_REQUESTS) + """A cache to store request objects.""" - self._assumed_total_count = 0 - """An assumed total count of requests in the queue.""" + self._queue_has_locked_requests: bool | None = None + """Whether the queue has requests locked by another client.""" + + self._should_check_for_forefront_requests = False + """Whether to check for forefront requests in the next list_head call.""" @override @property @@ -98,8 +111,8 @@ async def open( ) -> ApifyRequestQueueClient: default_name = configuration.default_request_queue_id - # TODO: use the real values - token = 'TOKEN' + # Get API credentials + token = os.environ.get('APIFY_TOKEN') api_url = 'https://api.apify.com' name = name or default_name @@ -110,7 +123,7 @@ async def open( await client._update_metadata() # noqa: SLF001 return client - # Otherwise, create a new one. + # Create a new API client apify_client_async = ApifyClientAsync( token=token, api_url=api_url, @@ -121,12 +134,14 @@ async def open( apify_rqs_client = apify_client_async.request_queues() + # Get or create the request queue metadata = RequestQueueMetadata.model_validate( await apify_rqs_client.get_or_create(name=id if id is not None else name), ) apify_rq_client = apify_client_async.request_queue(request_queue_id=metadata.id) + # Create the client instance client = cls( id=metadata.id, name=metadata.name, @@ -141,7 +156,7 @@ async def open( api_client=apify_rq_client, ) - # Cache the client by name. + # Cache the client by name cls._cache_by_name[name] = client return client @@ -151,55 +166,286 @@ async def drop(self) -> None: async with self._lock: await self._api_client.delete() - # Remove the client from the cache. 
+ # Remove the client from the cache if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 @override - async def list_head( - self, - *, - lock_time: timedelta | None = None, - limit: int | None = None, - ) -> RequestQueueHead: - lock_time = lock_time or self._DEFAULT_LOCK_TIME - - response = await self._api_client.list_and_lock_head( - lock_secs=int(lock_time.total_seconds()), - limit=limit, - ) - - return RequestQueueHead.model_validate(**response) - - @override - async def add_requests( + async def add_batch_of_requests( self, requests: Sequence[Request], *, forefront: bool = False, - batch_size: int = 1000, - wait_time_between_batches: timedelta = timedelta(seconds=1), - wait_for_all_requests_to_be_added: bool = False, - wait_for_all_requests_to_be_added_timeout: timedelta | None = None, ) -> AddRequestsResponse: + """Add a batch of requests to the queue. + + Args: + requests: The requests to add. + forefront: Whether to add the requests to the beginning of the queue. + + Returns: + Response containing information about the added requests. + """ + # Prepare requests for API by converting to dictionaries requests_dict = [request.model_dump(by_alias=True) for request in requests] + + # Remove 'id' fields from requests as the API doesn't accept them + for request_dict in requests_dict: + if 'id' in request_dict: + del request_dict['id'] + + # Send requests to API response = await self._api_client.batch_add_requests(requests=requests_dict, forefront=forefront) + + # Update metadata after adding requests + await self._update_metadata() + return AddRequestsResponse.model_validate(response) @override async def get_request(self, request_id: str) -> Request | None: + """Get a request by ID. + + Args: + request_id: The ID of the request to get. + + Returns: + The request or None if not found. + """ response = await self._api_client.get_request(request_id) + await self._update_metadata() + if response is None: return None + return Request.model_validate(**response) @override - async def update_request( + async def fetch_next_request(self) -> Request | None: + """Return the next request in the queue to be processed. + + Once you successfully finish processing of the request, you need to call `mark_request_as_handled` + to mark the request as handled in the queue. If there was some error in processing the request, call + `reclaim_request` instead, so that the queue will give the request to some other consumer + in another call to the `fetch_next_request` method. + + Returns: + The request or `None` if there are no more pending requests. 
+ """ + # Ensure the queue head has requests if available + await self._ensure_head_is_non_empty() + + # If queue head is empty after ensuring, there are no requests + if not self._queue_head: + return None + + # Get the next request ID from the queue head + next_request_id = self._queue_head.popleft() + request = await self._get_or_hydrate_request(next_request_id) + + # Handle potential inconsistency where request might not be in the main table yet + if request is None: + logger.debug( + 'Cannot find a request from the beginning of queue, will be retried later', + extra={'nextRequestId': next_request_id}, + ) + return None + + # If the request was already handled, skip it + if request.handled_at is not None: + logger.debug( + 'Request fetched from the beginning of queue was already handled', + extra={'nextRequestId': next_request_id}, + ) + return None + + return request + + @override + async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: + """Mark a request as handled after successful processing. + + Handled requests will never again be returned by the `fetch_next_request` method. + + Args: + request: The request to mark as handled. + + Returns: + Information about the queue operation. `None` if the given request was not in progress. + """ + # Set the handled_at timestamp if not already set + if request.handled_at is None: + request.handled_at = datetime.now(tz=timezone.utc) + + try: + # Update the request in the API + processed_request = await self._update_request(request) + processed_request.unique_key = request.unique_key + + # Update the cache with the handled request + cache_key = unique_key_to_request_id(request.unique_key) + self._cache_request( + cache_key, + processed_request, + forefront=False, + hydrated_request=request, + ) + + # Update metadata after marking request as handled + await self._update_metadata() + except Exception as exc: + logger.debug(f'Error marking request {request.id} as handled: {exc!s}') + return None + else: + return processed_request + + @override + async def reclaim_request( + self, + request: Request, + *, + forefront: bool = False, + ) -> ProcessedRequest | None: + """Reclaim a failed request back to the queue. + + The request will be returned for processing later again by another call to `fetch_next_request`. + + Args: + request: The request to return to the queue. + forefront: Whether to add the request to the head or the end of the queue. + + Returns: + Information about the queue operation. `None` if the given request was not in progress. 
+ """ + try: + # Update the request in the API + processed_request = await self._update_request(request, forefront=forefront) + processed_request.unique_key = request.unique_key + + # Update the cache + cache_key = unique_key_to_request_id(request.unique_key) + self._cache_request( + cache_key, + processed_request, + forefront=forefront, + hydrated_request=request, + ) + + # If we're adding to the forefront, we need to check for forefront requests + # in the next list_head call + if forefront: + self._should_check_for_forefront_requests = True + + # Try to release the lock on the request + try: + await self._delete_request_lock(request.id, forefront=forefront) + except Exception as err: + logger.debug(f'Failed to delete request lock for request {request.id}', exc_info=err) + + # Update metadata after reclaiming request + await self._update_metadata() + except Exception as exc: + logger.debug(f'Error reclaiming request {request.id}: {exc!s}') + return None + else: + return processed_request + + @override + async def is_empty(self) -> bool: + """Check if the queue is empty. + + Returns: + True if the queue is empty, False otherwise. + """ + head = await self._list_head(limit=1, lock_time=None) + return len(head.items) == 0 + + async def _ensure_head_is_non_empty(self) -> None: + """Ensure that the queue head has requests if they are available in the queue.""" + # If queue head has adequate requests, skip fetching more + if len(self._queue_head) > 1 and not self._should_check_for_forefront_requests: + return + + # Fetch requests from the API and populate the queue head + await self._list_head(lock_time=self._DEFAULT_LOCK_TIME) + + async def _get_or_hydrate_request(self, request_id: str) -> Request | None: + """Get a request by ID, either from cache or by fetching from API. + + Args: + request_id: The ID of the request to get. + + Returns: + The request if found and valid, otherwise None. 
+ """ + # First check if the request is in our cache + cached_entry = self._requests_cache.get(request_id) + + if cached_entry and cached_entry.hydrated: + # If we have the request hydrated in cache, check if lock is expired + if cached_entry.lock_expires_at and cached_entry.lock_expires_at < datetime.now(tz=timezone.utc): + # Try to prolong the lock if it's expired + try: + lock_secs = int(self._DEFAULT_LOCK_TIME.total_seconds()) + response = await self._prolong_request_lock( + request_id, forefront=cached_entry.forefront, lock_secs=lock_secs + ) + cached_entry.lock_expires_at = response.lock_expires_at + except Exception: + # If prolonging the lock fails, we lost the request + logger.debug(f'Failed to prolong lock for request {request_id}, returning None') + return None + + return cached_entry.hydrated + + # If not in cache or not hydrated, fetch the request + try: + # Try to acquire or prolong the lock + lock_secs = int(self._DEFAULT_LOCK_TIME.total_seconds()) + await self._prolong_request_lock(request_id, forefront=False, lock_secs=lock_secs) + + # Fetch the request data + request = await self.get_request(request_id) + + # If request is not found, release lock and return None + if not request: + await self._delete_request_lock(request_id) + return None + + # Update cache with hydrated request + cache_key = unique_key_to_request_id(request.unique_key) + self._cache_request( + cache_key, + ProcessedRequest( + id=request_id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=request.handled_at is not None, + ), + forefront=False, + hydrated_request=request, + ) + except Exception as exc: + logger.debug(f'Error fetching or locking request {request_id}: {exc!s}') + return None + else: + return request + + async def _update_request( self, request: Request, *, forefront: bool = False, ) -> ProcessedRequest: + """Update a request in the queue. + + Args: + request: The updated request. + forefront: Whether to put the updated request in the beginning or the end of the queue. + + Returns: + The updated request + """ response = await self._api_client.update_request( request=request.model_dump(by_alias=True), forefront=forefront, @@ -209,13 +455,109 @@ async def update_request( {'id': request.id, 'uniqueKey': request.unique_key} | response, ) - @override - async def is_finished(self) -> bool: - if self._add_requests_tasks: - logger.debug('Background tasks are still in progress') - return False + async def _list_head( + self, + *, + lock_time: timedelta | None = None, + limit: int = 25, + ) -> RequestQueueHead: + """Retrieve requests from the beginning of the queue. - # TODO + Args: + lock_time: Duration for which to lock the retrieved requests. + If None, requests will not be locked. + limit: Maximum number of requests to retrieve. + + Returns: + A collection of requests from the beginning of the queue. 
+ """ + # Return from cache if available and we're not checking for new forefront requests + if self._queue_head and not self._should_check_for_forefront_requests: + logger.debug(f'Using cached queue head with {len(self._queue_head)} requests') + + # Create a list of requests from the cached queue head + items = [] + for request_id in list(self._queue_head)[:limit]: + cached_request = self._requests_cache.get(request_id) + if cached_request and cached_request.hydrated: + items.append(cached_request.hydrated) + + return RequestQueueHead( + limit=limit, + had_multiple_clients=self._metadata.had_multiple_clients, + queue_modified_at=self._metadata.modified_at, + items=items, + queue_has_locked_requests=self._queue_has_locked_requests, + lock_time=lock_time, + ) + + # Otherwise fetch from API + lock_time = lock_time or self._DEFAULT_LOCK_TIME + lock_secs = int(lock_time.total_seconds()) + + response = await self._api_client.list_and_lock_head( + lock_secs=lock_secs, + limit=limit, + ) + + # Update the queue head cache + self._queue_has_locked_requests = response.get('queueHasLockedRequests', False) + + # Clear current queue head if we're checking for forefront requests + if self._should_check_for_forefront_requests: + self._queue_head.clear() + self._should_check_for_forefront_requests = False + + # Process and cache the requests + head_id_buffer = list[str]() + forefront_head_id_buffer = list[str]() + + for request_data in response.get('items', []): + request = Request.model_validate(request_data) + + # Skip requests without ID or unique key + if not request.id or not request.unique_key: + logger.debug( + 'Skipping request from queue head, missing ID or unique key', + extra={ + 'id': request.id, + 'unique_key': request.unique_key, + }, + ) + continue + + # Check if this request was already cached and if it was added to forefront + cache_key = unique_key_to_request_id(request.unique_key) + cached_request = self._requests_cache.get(cache_key) + forefront = cached_request.forefront if cached_request else False + + # Add to appropriate buffer based on forefront flag + if forefront: + forefront_head_id_buffer.insert(0, request.id) + else: + head_id_buffer.append(request.id) + + # Cache the request + self._cache_request( + cache_key, + ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=False, + ), + forefront=forefront, + hydrated_request=request, + ) + + # Update the queue head deque + for request_id in head_id_buffer: + self._queue_head.append(request_id) + + for request_id in forefront_head_id_buffer: + self._queue_head.appendleft(request_id) + + return RequestQueueHead.model_validate(response) async def _prolong_request_lock( self, @@ -230,7 +572,27 @@ async def _prolong_request_lock( request_id: The identifier of the request whose lock is to be prolonged. forefront: Whether to put the request in the beginning or the end of the queue after lock expires. lock_secs: The additional amount of time, in seconds, that the request will remain locked. + + Returns: + A response containing the time at which the lock will expire. 
""" + response = await self._api_client.prolong_request_lock( + request_id=request_id, + forefront=forefront, + lock_secs=lock_secs, + ) + + result = ProlongRequestLockResponse( + lock_expires_at=datetime.fromisoformat(response['lockExpiresAt'].replace('Z', '+00:00')) + ) + + # Update the cache with the new lock expiration + for cached_request in self._requests_cache.values(): + if cached_request.id == request_id: + cached_request.lock_expires_at = result.lock_expires_at + break + + return result async def _delete_request_lock( self, @@ -244,6 +606,43 @@ async def _delete_request_lock( request_id: ID of the request to delete the lock. forefront: Whether to put the request in the beginning or the end of the queue after the lock is deleted. """ + try: + await self._api_client.delete_request_lock( + request_id=request_id, + forefront=forefront, + ) + + # Update the cache to remove the lock + for cached_request in self._requests_cache.values(): + if cached_request.id == request_id: + cached_request.lock_expires_at = None + break + except Exception as err: + logger.debug(f'Failed to delete request lock for request {request_id}', exc_info=err) + + def _cache_request( + self, + cache_key: str, + processed_request: ProcessedRequest, + *, + forefront: bool, + hydrated_request: Request | None = None, + ) -> None: + """Cache a request for future use. + + Args: + cache_key: The key to use for caching the request. + processed_request: The processed request information. + forefront: Whether the request was added to the forefront of the queue. + hydrated_request: The hydrated request object, if available. + """ + self._requests_cache[cache_key] = CachedRequest( + id=processed_request.id, + was_already_handled=processed_request.was_already_handled, + hydrated=hydrated_request, + lock_expires_at=None, + forefront=forefront, + ) async def _update_metadata(self) -> None: """Update the request queue metadata with current information.""" diff --git a/src/crawlee/storage_clients/_base/_request_queue_client.py b/src/crawlee/storage_clients/_base/_request_queue_client.py index e506e5b763..7f2cdc11f1 100644 --- a/src/crawlee/storage_clients/_base/_request_queue_client.py +++ b/src/crawlee/storage_clients/_base/_request_queue_client.py @@ -1,7 +1,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -from datetime import timedelta from typing import TYPE_CHECKING from crawlee._utils.docs import docs_group @@ -14,7 +13,6 @@ AddRequestsResponse, ProcessedRequest, Request, - RequestQueueHead, RequestQueueMetadata, ) @@ -60,37 +58,11 @@ async def drop(self) -> None: """ @abstractmethod - async def list_head( - self, - *, - lock_time: timedelta | None = None, - limit: int | None = None, - ) -> RequestQueueHead: - """Retrieve requests from the beginning of the queue. - - Fetches the first requests in the queue. If `lock_time` is provided, the requests will be locked - for the specified duration, preventing them from being processed by other clients until the lock expires. - This locking functionality may not be supported by all request queue client implementations. - - Args: - lock_time: Duration for which to lock the retrieved requests, if supported by the client. - If None, requests will not be locked. - limit: Maximum number of requests to retrieve. - - Returns: - A collection of requests from the beginning of the queue, including lock information if applicable. 
- """ - - @abstractmethod - async def add_requests( + async def add_batch_of_requests( self, requests: Sequence[Request], *, forefront: bool = False, - batch_size: int = 1000, - wait_time_between_batches: timedelta = timedelta(seconds=1), - wait_for_all_requests_to_be_added: bool = False, - wait_for_all_requests_to_be_added_timeout: timedelta | None = None, ) -> AddRequestsResponse: """Add batch of requests to the queue. @@ -124,29 +96,58 @@ async def get_request(self, request_id: str) -> Request | None: """ @abstractmethod - async def update_request( + async def fetch_next_request(self) -> Request | None: + """Return the next request in the queue to be processed. + + Once you successfully finish processing of the request, you need to call `RequestQueue.mark_request_as_handled` + to mark the request as handled in the queue. If there was some error in processing the request, call + `RequestQueue.reclaim_request` instead, so that the queue will give the request to some other consumer + in another call to the `fetch_next_request` method. + + Note that the `None` return value does not mean the queue processing finished, it means there are currently + no pending requests. To check whether all requests in queue were finished, use `RequestQueue.is_finished` + instead. + + Returns: + The request or `None` if there are no more pending requests. + """ + + @abstractmethod + async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: + """Mark a request as handled after successful processing. + + Handled requests will never again be returned by the `RequestQueue.fetch_next_request` method. + + Args: + request: The request to mark as handled. + + Returns: + Information about the queue operation. `None` if the given request was not in progress. + """ + + @abstractmethod + async def reclaim_request( self, request: Request, *, forefront: bool = False, - ) -> ProcessedRequest: - """Update a request in the queue. + ) -> ProcessedRequest | None: + """Reclaim a failed request back to the queue. + + The request will be returned for processing later again by another call to `RequestQueue.fetch_next_request`. Args: - request: The updated request. - forefront: Whether to put the updated request in the beginning or the end of the queue. + request: The request to return to the queue. + forefront: Whether to add the request to the head or the end of the queue. Returns: - The updated request + Information about the queue operation. `None` if the given request was not in progress. """ @abstractmethod - async def is_finished(self) -> bool: - """Check if the request queue is finished. - - Finished means that all requests in the queue have been processed (the queue is empty) and there - are no more tasks that could add additional requests to the queue. + async def is_empty(self) -> bool: + """Check if the request queue is empty. Returns: - True if the request queue is finished, False otherwise. + True if the request queue is empty, False otherwise. 
""" diff --git a/src/crawlee/storage_clients/_file_system/_request_queue_client.py b/src/crawlee/storage_clients/_file_system/_request_queue_client.py index e5c32d860a..4769f8d80a 100644 --- a/src/crawlee/storage_clients/_file_system/_request_queue_client.py +++ b/src/crawlee/storage_clients/_file_system/_request_queue_client.py @@ -11,13 +11,16 @@ from pydantic import ValidationError from typing_extensions import override +from crawlee import Request from crawlee._utils.crypto import crypto_random_object_id from crawlee.storage_clients._base import RequestQueueClient -from crawlee.storage_clients.models import RequestQueueMetadata +from crawlee.storage_clients.models import AddRequestsResponse, ProcessedRequest, RequestQueueMetadata from ._utils import METADATA_FILENAME, json_dumps if TYPE_CHECKING: + from collections.abc import Sequence + from crawlee.configuration import Configuration logger = getLogger(__name__) @@ -182,7 +185,320 @@ async def drop(self) -> None: if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 - # TODO: other methods + @override + async def add_batch_of_requests( + self, + requests: Sequence[Request], + *, + forefront: bool = False, + ) -> AddRequestsResponse: + """Add a batch of requests to the queue. + + Args: + requests: The requests to add. + forefront: Whether to add the requests to the beginning of the queue. + + Returns: + Response containing information about the added requests. + """ + async with self._lock: + processed_requests = [] + + # Create the requests directory if it doesn't exist + requests_dir = self.path_to_rq / 'requests' + await asyncio.to_thread(requests_dir.mkdir, parents=True, exist_ok=True) + + # Create the in_progress directory if it doesn't exist + in_progress_dir = self.path_to_rq / 'in_progress' + await asyncio.to_thread(in_progress_dir.mkdir, parents=True, exist_ok=True) + + for request in requests: + # Ensure the request has an ID + if not request.id: + request.id = crypto_random_object_id() + + # Check if the request is already in the queue by unique_key + existing_request = None + + # List all request files and check for matching unique_key + request_files = await asyncio.to_thread(list, requests_dir.glob('*.json')) + for request_file in request_files: + file = await asyncio.to_thread(open, request_file) + try: + file_content = json.load(file) + if file_content.get('unique_key') == request.unique_key: + existing_request = Request(**file_content) + break + except (json.JSONDecodeError, ValidationError): + logger.warning(f'Failed to parse request file: {request_file}') + finally: + await asyncio.to_thread(file.close) + + was_already_present = existing_request is not None + was_already_handled = ( + was_already_present and existing_request and existing_request.handled_at is not None + ) + + # If the request is already in the queue and handled, don't add it again + if was_already_handled: + processed_requests.append( + ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=True, + ) + ) + continue + + # If the request is already in the queue but not handled, update it + if was_already_present: + # Update the existing request file + request_path = requests_dir / f'{request.id}.json' + request_data = await json_dumps(request.model_dump()) + await asyncio.to_thread(request_path.write_text, request_data, encoding='utf-8') + else: + # Add the new request to the queue + request_path = 
requests_dir / f'{request.id}.json' + request_data = await json_dumps(request.model_dump()) + await asyncio.to_thread(request_path.write_text, request_data, encoding='utf-8') + + # Update metadata counts + self._metadata.total_request_count += 1 + self._metadata.pending_request_count += 1 + + processed_requests.append( + ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=was_already_present, + was_already_handled=False, + ) + ) + + # Update metadata + await self._update_metadata(update_modified_at=True) + + return AddRequestsResponse( + processed_requests=processed_requests, + unprocessed_requests=[], + ) + + @override + async def get_request(self, request_id: str) -> Request | None: + """Retrieve a request from the queue. + + Args: + request_id: ID of the request to retrieve. + + Returns: + The retrieved request, or None, if it did not exist. + """ + # First check in-progress directory + in_progress_dir = self.path_to_rq / 'in_progress' + in_progress_path = in_progress_dir / f'{request_id}.json' + + # Then check regular requests directory + requests_dir = self.path_to_rq / 'requests' + request_path = requests_dir / f'{request_id}.json' + + for path in [in_progress_path, request_path]: + if await asyncio.to_thread(path.exists): + file = await asyncio.to_thread(open, path) + try: + file_content = json.load(file) + return Request(**file_content) + except (json.JSONDecodeError, ValidationError) as e: + logger.warning(f'Failed to parse request file {path}: {e!s}') + finally: + await asyncio.to_thread(file.close) + + return None + + @override + async def fetch_next_request(self) -> Request | None: + """Return the next request in the queue to be processed. + + Once you successfully finish processing of the request, you need to call `RequestQueue.mark_request_as_handled` + to mark the request as handled in the queue. If there was some error in processing the request, call + `RequestQueue.reclaim_request` instead, so that the queue will give the request to some other consumer + in another call to the `fetch_next_request` method. + + Returns: + The request or `None` if there are no more pending requests. 
+ """ + async with self._lock: + # Create the requests and in_progress directories if they don't exist + requests_dir = self.path_to_rq / 'requests' + in_progress_dir = self.path_to_rq / 'in_progress' + + await asyncio.to_thread(requests_dir.mkdir, parents=True, exist_ok=True) + await asyncio.to_thread(in_progress_dir.mkdir, parents=True, exist_ok=True) + + # List all request files + request_files = await asyncio.to_thread(list, requests_dir.glob('*.json')) + + # Find a request that's not handled + for request_file in request_files: + file = await asyncio.to_thread(open, request_file) + try: + file_content = json.load(file) + # Skip if already handled + if file_content.get('handled_at') is not None: + continue + + # Create request object + request = Request(**file_content) + + # Move to in-progress + in_progress_path = in_progress_dir / f'{request.id}.json' + + # If already in in-progress, skip + if await asyncio.to_thread(in_progress_path.exists): + continue + + # Write to in-progress directory + request_data = await json_dumps(request.model_dump()) + await asyncio.to_thread(in_progress_path.write_text, request_data, encoding='utf-8') + + except (json.JSONDecodeError, ValidationError) as e: + logger.warning(f'Failed to parse request file {request_file}: {e!s}') + else: + return request + finally: + await asyncio.to_thread(file.close) + + return None + + @override + async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: + """Mark a request as handled after successful processing. + + Handled requests will never again be returned by the `fetch_next_request` method. + + Args: + request: The request to mark as handled. + + Returns: + Information about the queue operation. `None` if the given request was not in progress. + """ + async with self._lock: + # Check if the request is in progress + in_progress_dir = self.path_to_rq / 'in_progress' + in_progress_path = in_progress_dir / f'{request.id}.json' + + if not await asyncio.to_thread(in_progress_path.exists): + return None + + # Update the request object - set handled_at timestamp + if request.handled_at is None: + request.handled_at = datetime.now(timezone.utc) + + # Write the updated request back to the requests directory + requests_dir = self.path_to_rq / 'requests' + request_path = requests_dir / f'{request.id}.json' + + request_data = await json_dumps(request.model_dump()) + await asyncio.to_thread(request_path.write_text, request_data, encoding='utf-8') + + # Remove the in-progress file + await asyncio.to_thread(in_progress_path.unlink, missing_ok=True) + + # Update metadata counts + self._metadata.handled_request_count += 1 + self._metadata.pending_request_count -= 1 + + # Update metadata timestamps + await self._update_metadata(update_modified_at=True) + + return ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=True, + ) + + @override + async def reclaim_request( + self, + request: Request, + *, + forefront: bool = False, + ) -> ProcessedRequest | None: + """Reclaim a failed request back to the queue. + + The request will be returned for processing later again by another call to `fetch_next_request`. + + Args: + request: The request to return to the queue. + forefront: Whether to add the request to the head or the end of the queue. + + Returns: + Information about the queue operation. `None` if the given request was not in progress. 
+ """ + async with self._lock: + # Check if the request is in progress + in_progress_dir = self.path_to_rq / 'in_progress' + in_progress_path = in_progress_dir / f'{request.id}.json' + + if not await asyncio.to_thread(in_progress_path.exists): + return None + + # Remove the in-progress file + await asyncio.to_thread(in_progress_path.unlink, missing_ok=True) + + # If forefront is true, we need to handle this specially + # Since we can't reorder files, we'll add a 'priority' field to the request + if forefront: + # Update the priority of the request to indicate it should be processed first + request.priority = 1 # Higher priority + + # Write the updated request back to the requests directory + requests_dir = self.path_to_rq / 'requests' + request_path = requests_dir / f'{request.id}.json' + + request_data = await json_dumps(request.model_dump()) + await asyncio.to_thread(request_path.write_text, request_data, encoding='utf-8') + + # Update metadata timestamps + await self._update_metadata(update_modified_at=True) + + return ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=False, + ) + + @override + async def is_empty(self) -> bool: + """Check if the queue is empty. + + Returns: + True if the queue is empty, False otherwise. + """ + # Create the requests directory if it doesn't exist + requests_dir = self.path_to_rq / 'requests' + await asyncio.to_thread(requests_dir.mkdir, parents=True, exist_ok=True) + + # List all request files + request_files = await asyncio.to_thread(list, requests_dir.glob('*.json')) + + # Check each file to see if there are any unhandled requests + for request_file in request_files: + file = await asyncio.to_thread(open, request_file) + try: + file_content = json.load(file) + # If any request is not handled, the queue is not empty + if file_content.get('handled_at') is None: + return False + except (json.JSONDecodeError, ValidationError): + logger.warning(f'Failed to parse request file: {request_file}') + finally: + await asyncio.to_thread(file.close) + + # If we got here, all requests are handled or there are no requests + return True async def _update_metadata( self, diff --git a/src/crawlee/storage_clients/_memory/_request_queue_client.py b/src/crawlee/storage_clients/_memory/_request_queue_client.py index 95775bfa83..e197954644 100644 --- a/src/crawlee/storage_clients/_memory/_request_queue_client.py +++ b/src/crawlee/storage_clients/_memory/_request_queue_client.py @@ -9,9 +9,15 @@ from crawlee import Request from crawlee._utils.crypto import crypto_random_object_id from crawlee.storage_clients._base import RequestQueueClient -from crawlee.storage_clients.models import RequestQueueMetadata +from crawlee.storage_clients.models import ( + AddRequestsResponse, + ProcessedRequest, + RequestQueueMetadata, +) if TYPE_CHECKING: + from collections.abc import Sequence + from crawlee.configuration import Configuration logger = getLogger(__name__) @@ -62,6 +68,9 @@ def __init__( # List to hold RQ items self._records = list[Request]() + # Dictionary to track in-progress requests (fetched but not yet handled or reclaimed) + self._in_progress = dict[str, Request]() + @override @property def metadata(self) -> RequestQueueMetadata: @@ -110,12 +119,223 @@ async def open( async def drop(self) -> None: # Clear all data self._records.clear() + self._in_progress.clear() # Remove from cache if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 del self.__class__._cache_by_name[self.metadata.name] # 
noqa: SLF001 - # TODO: other methods + @override + async def add_batch_of_requests( + self, + requests: Sequence[Request], + *, + forefront: bool = False, + ) -> AddRequestsResponse: + """Add a batch of requests to the queue. + + Args: + requests: The requests to add. + forefront: Whether to add the requests to the beginning of the queue. + + Returns: + Response containing information about the added requests. + """ + processed_requests = [] + for request in requests: + # Ensure the request has an ID + if not request.id: + request.id = crypto_random_object_id() + + # Check if the request is already in the queue by unique_key + existing_request = next((r for r in self._records if r.unique_key == request.unique_key), None) + + was_already_present = existing_request is not None + was_already_handled = was_already_present and existing_request and existing_request.handled_at is not None + + # If the request is already in the queue and handled, don't add it again + if was_already_handled: + processed_requests.append( + ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=True, + ) + ) + continue + + # If the request is already in the queue but not handled, update it + if was_already_present: + # Update the existing request with any new data + for idx, rec in enumerate(self._records): + if rec.unique_key == request.unique_key: + self._records[idx] = request + break + else: + # Add the new request to the queue + if forefront: + self._records.insert(0, request) + else: + self._records.append(request) + + # Update metadata counts + self._metadata.total_request_count += 1 + self._metadata.pending_request_count += 1 + + processed_requests.append( + ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=was_already_present, + was_already_handled=False, + ) + ) + + # Update metadata + await self._update_metadata(update_modified_at=True) + + return AddRequestsResponse( + processed_requests=processed_requests, + unprocessed_requests=[], + ) + + @override + async def fetch_next_request(self) -> Request | None: + """Return the next request in the queue to be processed. + + Returns: + The request or `None` if there are no more pending requests. + """ + # Find the first request that's not handled or in progress + for request in self._records: + if request.handled_at is None and request.id not in self._in_progress: + # Mark as in progress + self._in_progress[request.id] = request + return request + + return None + + @override + async def get_request(self, request_id: str) -> Request | None: + """Retrieve a request from the queue. + + Args: + request_id: ID of the request to retrieve. + + Returns: + The retrieved request, or None, if it did not exist. + """ + # Check in-progress requests first + if request_id in self._in_progress: + return self._in_progress[request_id] + + # Otherwise search in the records + for request in self._records: + if request.id == request_id: + return request + + return None + + @override + async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: + """Mark a request as handled after successful processing. + + Handled requests will never again be returned by the `fetch_next_request` method. + + Args: + request: The request to mark as handled. + + Returns: + Information about the queue operation. `None` if the given request was not in progress. 
+ """ + # Check if the request is in progress + if request.id not in self._in_progress: + return None + + # Set handled_at timestamp if not already set + if request.handled_at is None: + request.handled_at = datetime.now(timezone.utc) + + # Update the request in records + for idx, rec in enumerate(self._records): + if rec.id == request.id: + self._records[idx] = request + break + + # Remove from in-progress + del self._in_progress[request.id] + + # Update metadata counts + self._metadata.handled_request_count += 1 + self._metadata.pending_request_count -= 1 + + # Update metadata timestamps + await self._update_metadata(update_modified_at=True) + + return ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=True, + ) + + @override + async def reclaim_request( + self, + request: Request, + *, + forefront: bool = False, + ) -> ProcessedRequest | None: + """Reclaim a failed request back to the queue. + + The request will be returned for processing later again by another call to `fetch_next_request`. + + Args: + request: The request to return to the queue. + forefront: Whether to add the request to the head or the end of the queue. + + Returns: + Information about the queue operation. `None` if the given request was not in progress. + """ + # Check if the request is in progress + if request.id not in self._in_progress: + return None + + # Remove from in-progress + del self._in_progress[request.id] + + # If forefront is true, move the request to the beginning of the queue + if forefront: + # First remove the request from its current position + for idx, rec in enumerate(self._records): + if rec.id == request.id: + self._records.pop(idx) + break + + # Then insert it at the beginning + self._records.insert(0, request) + + # Update metadata timestamps + await self._update_metadata(update_modified_at=True) + + return ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=False, + ) + + @override + async def is_empty(self) -> bool: + """Check if the queue is empty. + + Returns: + True if the queue is empty, False otherwise. 
+ """ + # Queue is empty if there are no pending requests + pending_requests = [r for r in self._records if r.handled_at is None] + return len(pending_requests) == 0 async def _update_metadata( self, diff --git a/src/crawlee/storage_clients/models.py b/src/crawlee/storage_clients/models.py index 04d1ff95ed..f680ba945f 100644 --- a/src/crawlee/storage_clients/models.py +++ b/src/crawlee/storage_clients/models.py @@ -319,3 +319,22 @@ def from_request(cls, request: Request, id: str, order_no: Decimal | None) -> In def to_request(self) -> Request: """Convert the internal request back to a `Request` object.""" return self.request + + +class CachedRequest(BaseModel): + """Pydantic model for cached request information.""" + + id: str + """The ID of the request.""" + + was_already_handled: bool + """Whether the request was already handled.""" + + hydrated: Request | None = None + """The hydrated request object (the original one).""" + + lock_expires_at: datetime | None = None + """The expiration time of the lock on the request.""" + + forefront: bool = False + """Whether the request was added to the forefront of the queue.""" diff --git a/src/crawlee/storages/_request_queue.py b/src/crawlee/storages/_request_queue.py index 7a7c112ac1..948a03ecf1 100644 --- a/src/crawlee/storages/_request_queue.py +++ b/src/crawlee/storages/_request_queue.py @@ -3,7 +3,7 @@ import asyncio from datetime import timedelta from logging import getLogger -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, ClassVar, TypeVar from typing_extensions import override @@ -61,6 +61,12 @@ class RequestQueue(Storage, RequestManager): ``` """ + _cache_by_id: ClassVar[dict[str, RequestQueue]] = {} + """A dictionary to cache request queues by their IDs.""" + + _cache_by_name: ClassVar[dict[str, RequestQueue]] = {} + """A dictionary to cache request queues by their names.""" + _MAX_CACHED_REQUESTS = 1_000_000 """Maximum number of requests that can be cached.""" @@ -74,10 +80,6 @@ def __init__(self, client: RequestQueueClient) -> None: """ self._client = client - # Internal attributes - self._add_requests_tasks = list[asyncio.Task]() - self._assumed_total_count = 0 - self._add_requests_tasks = list[asyncio.Task]() """A list of tasks for adding requests to the queue.""" @@ -109,6 +111,12 @@ async def open( if id and name: raise ValueError('Only one of "id" or "name" can be specified, not both.') + # Check if key value store is already cached by id or name + if id and id in cls._cache_by_id: + return cls._cache_by_id[id] + if name and name in cls._cache_by_name: + return cls._cache_by_name[name] + configuration = service_locator.get_configuration() if configuration is None else configuration storage_client = service_locator.get_storage_client() if storage_client is None else storage_client @@ -122,6 +130,12 @@ async def open( @override async def drop(self) -> None: + # Remove from cache before dropping + if self.id in self._cache_by_id: + del self._cache_by_id[self.id] + if self.name and self.name in self._cache_by_name: + del self._cache_by_name[self.name] + await self._client.drop() @override @@ -132,7 +146,7 @@ async def add_request( forefront: bool = False, ) -> ProcessedRequest: request = self._transform_request(request) - response = await self._client.add_requests([request], forefront=forefront) + response = await self._client.add_batch_of_requests([request], forefront=forefront) return response.processed_requests[0] @override @@ -151,8 +165,7 @@ async def add_requests( async def _process_batch(batch: 
Sequence[Request]) -> None: request_count = len(batch) - response = await self._client.add_requests(batch, forefront=forefront) - self._assumed_total_count += request_count + response = await self._client.add_batch_of_requests(batch, forefront=forefront) logger.debug(f'Added {request_count} requests to the queue, response: {response}') # Wait for the first batch to be added @@ -260,7 +273,7 @@ async def fetch_next_request(self) -> Request | None: Returns: The request or `None` if there are no more pending requests. """ - # TODO: implement + return await self._client.fetch_next_request() async def get_request(self, request_id: str) -> Request | None: """Retrieve a request by its ID. @@ -284,7 +297,7 @@ async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | Returns: Information about the queue operation. `None` if the given request was not in progress. """ - # TODO: implement + return await self._client.mark_request_as_handled(request) async def reclaim_request( self, @@ -303,23 +316,33 @@ async def reclaim_request( Returns: Information about the queue operation. `None` if the given request was not in progress. """ - # TODO: implement + return await self._client.reclaim_request(request, forefront=forefront) async def is_empty(self) -> bool: - """Check whether the queue is empty. + """Check if the request queue is empty. + + An empty queue means that there are no requests in the queue. Returns: - `True` if the next call to `RequestQueue.fetch_next_request` would return `None`, otherwise `False`. + True if the request queue is empty, False otherwise. """ - # TODO: implement + return await self._client.is_empty() async def is_finished(self) -> bool: - """Check whether the queue is finished. + """Check if the request queue is finished. - Due to the nature of distributed storage used by the queue, the function might occasionally return a false - negative, but it will never return a false positive. + Finished means that all requests in the queue have been processed (the queue is empty) and there + are no more tasks that could add additional requests to the queue. Returns: - `True` if all requests were already handled and there are no more left. `False` otherwise. + True if the request queue is finished, False otherwise. 
""" - # TODO: implement + if self._add_requests_tasks: + logger.debug('Background add requests tasks are still in progress.') + return False + + if await self.is_empty(): + logger.debug('The request queue is empty.') + return True + + return False From 98dfdaf1936fb1a102545ce77bb2f8eb3361bf61 Mon Sep 17 00:00:00 2001 From: Vlada Dusek Date: Tue, 22 Apr 2025 15:55:15 +0200 Subject: [PATCH 16/22] Add tests for RQ --- .../_file_system/_request_queue_client.py | 10 +- src/crawlee/storages/_key_value_store.py | 4 +- src/crawlee/storages/_request_queue.py | 130 ++-- .../_memory/test_memory_storage_client.py | 288 --------- tests/unit/storages/test_dataset.py | 15 - tests/unit/storages/test_key_value_store.py | 11 - tests/unit/storages/test_request_queue.py | 589 ++++++++++-------- uv.lock | 4 +- 8 files changed, 378 insertions(+), 673 deletions(-) delete mode 100644 tests/unit/storage_clients/_memory/test_memory_storage_client.py diff --git a/src/crawlee/storage_clients/_file_system/_request_queue_client.py b/src/crawlee/storage_clients/_file_system/_request_queue_client.py index 4769f8d80a..ed0d60f39a 100644 --- a/src/crawlee/storage_clients/_file_system/_request_queue_client.py +++ b/src/crawlee/storage_clients/_file_system/_request_queue_client.py @@ -104,10 +104,10 @@ async def open( ) -> FileSystemRequestQueueClient: if id: raise ValueError( - 'Opening a dataset by "id" is not supported for file system storage client, use "name" instead.' + 'Opening a request queue by "id" is not supported for file system storage client, use "name" instead.' ) - name = name or configuration.default_dataset_id + name = name or configuration.default_request_queue_id # Check if the client is already cached by name. if name in cls._cache_by_name: @@ -123,7 +123,7 @@ async def open( if rq_path.exists(): # If metadata file is missing, raise an error. 
if not metadata_path.exists(): - raise ValueError(f'Metadata file not found for RQ "{name}"') + raise ValueError(f'Metadata file not found for request queue "{name}"') file = await asyncio.to_thread(open, metadata_path) try: @@ -133,7 +133,7 @@ async def open( try: metadata = RequestQueueMetadata(**file_content) except ValidationError as exc: - raise ValueError(f'Invalid metadata file for RQ "{name}"') from exc + raise ValueError(f'Invalid metadata file for request queue "{name}"') from exc client = cls( id=metadata.id, @@ -185,6 +185,8 @@ async def drop(self) -> None: if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 + # TODO: continue + @override async def add_batch_of_requests( self, diff --git a/src/crawlee/storages/_key_value_store.py b/src/crawlee/storages/_key_value_store.py index 10574a3910..95f2b3c1c9 100644 --- a/src/crawlee/storages/_key_value_store.py +++ b/src/crawlee/storages/_key_value_store.py @@ -99,7 +99,7 @@ async def open( if id and name: raise ValueError('Only one of "id" or "name" can be specified, not both.') - # Check if key value store is already cached by id or name + # Check if key-value store is already cached by id or name if id and id in cls._cache_by_id: return cls._cache_by_id[id] if name and name in cls._cache_by_name: @@ -116,7 +116,7 @@ async def open( kvs = cls(client) - # Cache the key value store by id and name if available + # Cache the key-value store by id and name if available if kvs.id: cls._cache_by_id[kvs.id] = kvs if kvs.name: diff --git a/src/crawlee/storages/_request_queue.py b/src/crawlee/storages/_request_queue.py index 948a03ecf1..843ac6d0f1 100644 --- a/src/crawlee/storages/_request_queue.py +++ b/src/crawlee/storages/_request_queue.py @@ -76,7 +76,7 @@ def __init__(self, client: RequestQueueClient) -> None: Preferably use the `RequestQueue.open` constructor to create a new instance. Args: - client: An instance of a key-value store client. + client: An instance of a request queue client. 
""" self._client = client @@ -111,7 +111,7 @@ async def open( if id and name: raise ValueError('Only one of "id" or "name" can be specified, not both.') - # Check if key value store is already cached by id or name + # Check if request queue is already cached by id or name if id and id in cls._cache_by_id: return cls._cache_by_id[id] if name and name in cls._cache_by_name: @@ -126,7 +126,15 @@ async def open( configuration=configuration, ) - return cls(client) + rq = cls(client) + + # Cache the request queue by id and name if available + if rq.id: + cls._cache_by_id[rq.id] = rq + if rq.name: + cls._cache_by_name[rq.name] = rq + + return rq @override async def drop(self) -> None: @@ -163,27 +171,32 @@ async def add_requests( transformed_requests = self._transform_requests(requests) wait_time_secs = wait_time_between_batches.total_seconds() - async def _process_batch(batch: Sequence[Request]) -> None: - request_count = len(batch) - response = await self._client.add_batch_of_requests(batch, forefront=forefront) - logger.debug(f'Added {request_count} requests to the queue, response: {response}') - # Wait for the first batch to be added first_batch = transformed_requests[:batch_size] if first_batch: - await _process_batch(first_batch) + await self._process_batch( + first_batch, + base_retry_wait=wait_time_between_batches, + forefront=forefront, + ) async def _process_remaining_batches() -> None: for i in range(batch_size, len(transformed_requests), batch_size): batch = transformed_requests[i : i + batch_size] - await _process_batch(batch) + await self._process_batch( + batch, + base_retry_wait=wait_time_between_batches, + forefront=forefront, + ) if i + batch_size < len(transformed_requests): await asyncio.sleep(wait_time_secs) # Create and start the task to process remaining batches in the background remaining_batches_task = asyncio.create_task( - _process_remaining_batches(), name='request_queue_process_remaining_batches_task' + _process_remaining_batches(), + name='request_queue_process_remaining_batches_task', ) + self._add_requests_tasks.append(remaining_batches_task) remaining_batches_task.add_done_callback(lambda _: self._add_requests_tasks.remove(remaining_batches_task)) @@ -195,69 +208,6 @@ async def _process_remaining_batches() -> None: timeout=wait_for_all_requests_to_be_added_timeout, ) - # Wait for the first batch to be added - first_batch = transformed_requests[:batch_size] - if first_batch: - await self._process_batch(first_batch, base_retry_wait=wait_time_between_batches) - - async def _process_remaining_batches() -> None: - for i in range(batch_size, len(transformed_requests), batch_size): - batch = transformed_requests[i : i + batch_size] - await self._process_batch(batch, base_retry_wait=wait_time_between_batches) - if i + batch_size < len(transformed_requests): - await asyncio.sleep(wait_time_secs) - - # Create and start the task to process remaining batches in the background - remaining_batches_task = asyncio.create_task( - _process_remaining_batches(), name='request_queue_process_remaining_batches_task' - ) - self._tasks.append(remaining_batches_task) - remaining_batches_task.add_done_callback(lambda _: self._tasks.remove(remaining_batches_task)) - - # Wait for all tasks to finish if requested - if wait_for_all_requests_to_be_added: - await wait_for_all_tasks_for_finish( - (remaining_batches_task,), - logger=logger, - timeout=wait_for_all_requests_to_be_added_timeout, - ) - - async def _process_batch(self, batch: Sequence[Request], base_retry_wait: timedelta, attempt: int = 
1) -> None:
-        max_attempts = 5
-        response = await self._resource_client.batch_add_requests(batch)
-
-        if response.unprocessed_requests:
-            logger.debug(f'Following requests were not processed: {response.unprocessed_requests}.')
-            if attempt > max_attempts:
-                logger.warning(
-                    f'Following requests were not processed even after {max_attempts} attempts:\n'
-                    f'{response.unprocessed_requests}'
-                )
-            else:
-                logger.debug('Retry to add requests.')
-                unprocessed_requests_unique_keys = {request.unique_key for request in response.unprocessed_requests}
-                retry_batch = [request for request in batch if request.unique_key in unprocessed_requests_unique_keys]
-                await asyncio.sleep((base_retry_wait * attempt).total_seconds())
-                await self._process_batch(retry_batch, base_retry_wait=base_retry_wait, attempt=attempt + 1)
-
-        request_count = len(batch) - len(response.unprocessed_requests)
-        self._assumed_total_count += request_count
-        if request_count:
-            logger.debug(
-                f'Added {request_count} requests to the queue. Processed requests: {response.processed_requests}'
-            )
-
-    async def get_request(self, request_id: str) -> Request | None:
-        """Retrieve a request from the queue.
-
-        Args:
-            request_id: ID of the request to retrieve.
-
-        Returns:
-            The retrieved request, or `None`, if it does not exist.
-        """
-        # TODO: implement
-
     async def fetch_next_request(self) -> Request | None:
         """Return the next request in the queue to be processed.
 
@@ -346,3 +296,35 @@ async def is_finished(self) -> bool:
             return True
 
         return False
+
+    async def _process_batch(
+        self,
+        batch: Sequence[Request],
+        *,
+        base_retry_wait: timedelta,
+        attempt: int = 1,
+        forefront: bool = False,
+    ) -> None:
+        max_attempts = 5
+        response = await self._client.add_batch_of_requests(batch, forefront=forefront)
+
+        if response.unprocessed_requests:
+            logger.debug(f'Following requests were not processed: {response.unprocessed_requests}.')
+            if attempt > max_attempts:
+                logger.warning(
+                    f'Following requests were not processed even after {max_attempts} attempts:\n'
+                    f'{response.unprocessed_requests}'
+                )
+            else:
+                logger.debug('Retry to add requests.')
+                unprocessed_requests_unique_keys = {request.unique_key for request in response.unprocessed_requests}
+                retry_batch = [request for request in batch if request.unique_key in unprocessed_requests_unique_keys]
+                await asyncio.sleep((base_retry_wait * attempt).total_seconds())
+                # Preserve the forefront flag when retrying unprocessed requests, so forefront batches are not demoted.
+                await self._process_batch(retry_batch, base_retry_wait=base_retry_wait, attempt=attempt + 1, forefront=forefront)
+
+        request_count = len(batch) - len(response.unprocessed_requests)
+
+        if request_count:
+            logger.debug(
+                f'Added {request_count} requests to the queue. 
Processed requests: {response.processed_requests}' + ) diff --git a/tests/unit/storage_clients/_memory/test_memory_storage_client.py b/tests/unit/storage_clients/_memory/test_memory_storage_client.py deleted file mode 100644 index 66345fb023..0000000000 --- a/tests/unit/storage_clients/_memory/test_memory_storage_client.py +++ /dev/null @@ -1,288 +0,0 @@ -# TODO: Update crawlee_storage_dir args once the Pydantic bug is fixed -# https://github.com/apify/crawlee-python/issues/146 - -from __future__ import annotations - -from pathlib import Path - -import pytest - -from crawlee import Request, service_locator -from crawlee._consts import METADATA_FILENAME -from crawlee.configuration import Configuration -from crawlee.storage_clients import MemoryStorageClient -from crawlee.storage_clients.models import BatchRequestsOperationResponse - - -async def test_write_metadata(tmp_path: Path) -> None: - dataset_name = 'test' - dataset_no_metadata_name = 'test-no-metadata' - ms = MemoryStorageClient.from_config( - Configuration( - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - write_metadata=True, - ), - ) - ms_no_metadata = MemoryStorageClient.from_config( - Configuration( - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - write_metadata=False, - ) - ) - datasets_client = ms.datasets() - datasets_no_metadata_client = ms_no_metadata.datasets() - await datasets_client.get_or_create(name=dataset_name) - await datasets_no_metadata_client.get_or_create(name=dataset_no_metadata_name) - assert Path(ms.datasets_directory, dataset_name, METADATA_FILENAME).exists() is True - assert Path(ms_no_metadata.datasets_directory, dataset_no_metadata_name, METADATA_FILENAME).exists() is False - - -@pytest.mark.parametrize( - 'persist_storage', - [ - True, - False, - ], -) -async def test_persist_storage(persist_storage: bool, tmp_path: Path) -> None: # noqa: FBT001 - ms = MemoryStorageClient.from_config( - Configuration( - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - persist_storage=persist_storage, - ) - ) - - # Key value stores - kvs_client = ms.key_value_stores() - kvs_info = await kvs_client.get_or_create(name='kvs') - await ms.key_value_store(kvs_info.id).set_record('test', {'x': 1}, 'application/json') - - path = Path(ms.key_value_stores_directory) / (kvs_info.name or '') / 'test.json' - assert path.exists() is persist_storage - - # Request queues - rq_client = ms.request_queues() - rq_info = await rq_client.get_or_create(name='rq') - - request = Request.from_url('http://lorem.com') - await ms.request_queue(rq_info.id).add_request(request) - - path = Path(ms.request_queues_directory) / (rq_info.name or '') / f'{request.id}.json' - assert path.exists() is persist_storage - - # Datasets - ds_client = ms.datasets() - ds_info = await ds_client.get_or_create(name='ds') - - await ms.dataset(ds_info.id).push_data([{'foo': 'bar'}]) - - -def test_persist_storage_set_to_false_via_string_env_var(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: - monkeypatch.setenv('CRAWLEE_PERSIST_STORAGE', 'false') - ms = MemoryStorageClient.from_config( - Configuration(crawlee_storage_dir=str(tmp_path)), # type: ignore[call-arg] - ) - assert ms.persist_storage is False - - -def test_persist_storage_set_to_false_via_numeric_env_var(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: - monkeypatch.setenv('CRAWLEE_PERSIST_STORAGE', '0') - ms = MemoryStorageClient.from_config(Configuration(crawlee_storage_dir=str(tmp_path))) # type: ignore[call-arg] - assert ms.persist_storage is False 
- - -def test_persist_storage_true_via_constructor_arg(tmp_path: Path) -> None: - ms = MemoryStorageClient.from_config( - Configuration( - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - persist_storage=True, - ) - ) - assert ms.persist_storage is True - - -def test_default_write_metadata_behavior(tmp_path: Path) -> None: - # Default behavior - ms = MemoryStorageClient.from_config( - Configuration(crawlee_storage_dir=str(tmp_path)), # type: ignore[call-arg] - ) - assert ms.write_metadata is True - - -def test_write_metadata_set_to_false_via_env_var(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: - # Test if env var changes write_metadata to False - monkeypatch.setenv('CRAWLEE_WRITE_METADATA', 'false') - ms = MemoryStorageClient.from_config( - Configuration(crawlee_storage_dir=str(tmp_path)), # type: ignore[call-arg] - ) - assert ms.write_metadata is False - - -def test_write_metadata_false_via_constructor_arg_overrides_env_var(tmp_path: Path) -> None: - # Test if constructor arg takes precedence over env var value - ms = MemoryStorageClient.from_config( - Configuration( - write_metadata=False, - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - ) - ) - assert ms.write_metadata is False - - -async def test_purge_datasets(tmp_path: Path) -> None: - ms = MemoryStorageClient.from_config( - Configuration( - write_metadata=True, - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - ) - ) - # Create default and non-default datasets - datasets_client = ms.datasets() - default_dataset_info = await datasets_client.get_or_create(name='default') - non_default_dataset_info = await datasets_client.get_or_create(name='non-default') - - # Check all folders inside datasets directory before and after purge - assert default_dataset_info.name is not None - assert non_default_dataset_info.name is not None - - default_path = Path(ms.datasets_directory, default_dataset_info.name) - non_default_path = Path(ms.datasets_directory, non_default_dataset_info.name) - - assert default_path.exists() is True - assert non_default_path.exists() is True - - await ms._purge_default_storages() - - assert default_path.exists() is False - assert non_default_path.exists() is True - - -async def test_purge_key_value_stores(tmp_path: Path) -> None: - ms = MemoryStorageClient.from_config( - Configuration( - write_metadata=True, - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - ) - ) - - # Create default and non-default key-value stores - kvs_client = ms.key_value_stores() - default_kvs_info = await kvs_client.get_or_create(name='default') - non_default_kvs_info = await kvs_client.get_or_create(name='non-default') - default_kvs_client = ms.key_value_store(default_kvs_info.id) - # INPUT.json should be kept - await default_kvs_client.set_record('INPUT', {'abc': 123}, 'application/json') - # test.json should not be kept - await default_kvs_client.set_record('test', {'abc': 123}, 'application/json') - - # Check all folders and files inside kvs directory before and after purge - assert default_kvs_info.name is not None - assert non_default_kvs_info.name is not None - - default_kvs_path = Path(ms.key_value_stores_directory, default_kvs_info.name) - non_default_kvs_path = Path(ms.key_value_stores_directory, non_default_kvs_info.name) - kvs_directory = Path(ms.key_value_stores_directory, 'default') - - assert default_kvs_path.exists() is True - assert non_default_kvs_path.exists() is True - - assert (kvs_directory / 'INPUT.json').exists() is True - assert (kvs_directory / 
'test.json').exists() is True - - await ms._purge_default_storages() - - assert default_kvs_path.exists() is True - assert non_default_kvs_path.exists() is True - - assert (kvs_directory / 'INPUT.json').exists() is True - assert (kvs_directory / 'test.json').exists() is False - - -async def test_purge_request_queues(tmp_path: Path) -> None: - ms = MemoryStorageClient.from_config( - Configuration( - write_metadata=True, - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - ) - ) - # Create default and non-default request queues - rq_client = ms.request_queues() - default_rq_info = await rq_client.get_or_create(name='default') - non_default_rq_info = await rq_client.get_or_create(name='non-default') - - # Check all folders inside rq directory before and after purge - assert default_rq_info.name - assert non_default_rq_info.name - - default_rq_path = Path(ms.request_queues_directory, default_rq_info.name) - non_default_rq_path = Path(ms.request_queues_directory, non_default_rq_info.name) - - assert default_rq_path.exists() is True - assert non_default_rq_path.exists() is True - - await ms._purge_default_storages() - - assert default_rq_path.exists() is False - assert non_default_rq_path.exists() is True - - -async def test_not_implemented_method(tmp_path: Path) -> None: - ms = MemoryStorageClient.from_config( - Configuration( - write_metadata=True, - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - ) - ) - ddt = ms.dataset('test') - with pytest.raises(NotImplementedError, match='This method is not supported in memory storage.'): - await ddt.stream_items(item_format='json') - - with pytest.raises(NotImplementedError, match='This method is not supported in memory storage.'): - await ddt.stream_items(item_format='json') - - -async def test_default_storage_path_used(monkeypatch: pytest.MonkeyPatch) -> None: - # Reset the configuration in service locator - service_locator._configuration = None - service_locator._configuration_was_retrieved = False - - # Remove the env var for setting the storage directory - monkeypatch.delenv('CRAWLEE_STORAGE_DIR', raising=False) - - # Initialize the service locator with default configuration - msc = MemoryStorageClient.from_config() - assert msc.storage_dir == './storage' - - -async def test_storage_path_from_env_var_overrides_default(monkeypatch: pytest.MonkeyPatch) -> None: - # We expect the env var to override the default value - monkeypatch.setenv('CRAWLEE_STORAGE_DIR', './env_var_storage_dir') - service_locator.set_configuration(Configuration()) - ms = MemoryStorageClient.from_config() - assert ms.storage_dir == './env_var_storage_dir' - - -async def test_parametrized_storage_path_overrides_env_var() -> None: - # We expect the parametrized value to be used - ms = MemoryStorageClient.from_config( - Configuration(crawlee_storage_dir='./parametrized_storage_dir'), # type: ignore[call-arg] - ) - assert ms.storage_dir == './parametrized_storage_dir' - - -async def test_batch_requests_operation_response() -> None: - """Test that `BatchRequestsOperationResponse` creation from example responses.""" - process_request = { - 'requestId': 'EAaArVRs5qV39C9', - 'uniqueKey': 'https://example.com', - 'wasAlreadyHandled': False, - 'wasAlreadyPresent': True, - } - unprocess_request_full = {'uniqueKey': 'https://example2.com', 'method': 'GET', 'url': 'https://example2.com'} - unprocess_request_minimal = {'uniqueKey': 'https://example3.com', 'url': 'https://example3.com'} - BatchRequestsOperationResponse.model_validate( - { - 'processedRequests': 
[process_request], - 'unprocessedRequests': [unprocess_request_full, unprocess_request_minimal], - } - ) diff --git a/tests/unit/storages/test_dataset.py b/tests/unit/storages/test_dataset.py index 4a1c8a9e23..c81c01a7ac 100644 --- a/tests/unit/storages/test_dataset.py +++ b/tests/unit/storages/test_dataset.py @@ -39,7 +39,6 @@ def configuration(tmp_path: Path) -> Configuration: async def dataset( storage_client: StorageClient, configuration: Configuration, - tmp_path: Path, ) -> AsyncGenerator[Dataset, None]: """Fixture that provides a dataset instance for each test.""" Dataset._cache_by_id.clear() @@ -47,7 +46,6 @@ async def dataset( dataset = await Dataset.open( name='test_dataset', - storage_dir=tmp_path, storage_client=storage_client, configuration=configuration, ) @@ -59,12 +57,10 @@ async def dataset( async def test_open_creates_new_dataset( storage_client: StorageClient, configuration: Configuration, - tmp_path: Path, ) -> None: """Test that open() creates a new dataset with proper metadata.""" dataset = await Dataset.open( name='new_dataset', - storage_dir=tmp_path, storage_client=storage_client, configuration=configuration, ) @@ -80,13 +76,11 @@ async def test_open_creates_new_dataset( async def test_open_existing_dataset( dataset: Dataset, storage_client: StorageClient, - tmp_path: Path, ) -> None: """Test that open() loads an existing dataset correctly.""" # Open the same dataset again reopened_dataset = await Dataset.open( name=dataset.name, - storage_dir=tmp_path, storage_client=storage_client, ) @@ -102,14 +96,12 @@ async def test_open_existing_dataset( async def test_open_with_id_and_name( storage_client: StorageClient, configuration: Configuration, - tmp_path: Path, ) -> None: """Test that open() raises an error when both id and name are provided.""" with pytest.raises(ValueError, match='Only one of "id" or "name" can be specified'): await Dataset.open( id='some-id', name='some-name', - storage_dir=tmp_path, storage_client=storage_client, configuration=configuration, ) @@ -251,12 +243,10 @@ async def test_iterate_items_with_options(dataset: Dataset) -> None: async def test_drop( storage_client: StorageClient, configuration: Configuration, - tmp_path: Path, ) -> None: """Test dropping a dataset removes it from cache and clears its data.""" dataset = await Dataset.open( name='drop_test', - storage_dir=tmp_path, storage_client=storage_client, configuration=configuration, ) @@ -280,7 +270,6 @@ async def test_drop( # Verify dataset is empty (by creating a new one with the same name) new_dataset = await Dataset.open( name='drop_test', - storage_dir=tmp_path, storage_client=storage_client, configuration=configuration, ) @@ -293,13 +282,11 @@ async def test_drop( async def test_export_to_json( dataset: Dataset, storage_client: StorageClient, - tmp_path: Path, ) -> None: """Test exporting dataset to JSON format.""" # Create a key-value store for export kvs = await KeyValueStore.open( name='export_kvs', - storage_dir=tmp_path, storage_client=storage_client, ) @@ -333,13 +320,11 @@ async def test_export_to_json( async def test_export_to_csv( dataset: Dataset, storage_client: StorageClient, - tmp_path: Path, ) -> None: """Test exporting dataset to CSV format.""" # Create a key-value store for export kvs = await KeyValueStore.open( name='export_kvs', - storage_dir=tmp_path, storage_client=storage_client, ) diff --git a/tests/unit/storages/test_key_value_store.py b/tests/unit/storages/test_key_value_store.py index 6312009f81..c03e8ae332 100644 --- 
a/tests/unit/storages/test_key_value_store.py +++ b/tests/unit/storages/test_key_value_store.py @@ -40,7 +40,6 @@ def configuration(tmp_path: Path) -> Configuration: async def kvs( storage_client: StorageClient, configuration: Configuration, - tmp_path: Path, ) -> AsyncGenerator[KeyValueStore, None]: """Fixture that provides a key-value store instance for each test.""" KeyValueStore._cache_by_id.clear() @@ -48,7 +47,6 @@ async def kvs( kvs = await KeyValueStore.open( name='test_kvs', - storage_dir=tmp_path, storage_client=storage_client, configuration=configuration, ) @@ -60,12 +58,10 @@ async def kvs( async def test_open_creates_new_kvs( storage_client: StorageClient, configuration: Configuration, - tmp_path: Path, ) -> None: """Test that open() creates a new key-value store with proper metadata.""" kvs = await KeyValueStore.open( name='new_kvs', - storage_dir=tmp_path, storage_client=storage_client, configuration=configuration, ) @@ -80,13 +76,11 @@ async def test_open_creates_new_kvs( async def test_open_existing_kvs( kvs: KeyValueStore, storage_client: StorageClient, - tmp_path: Path, ) -> None: """Test that open() loads an existing key-value store correctly.""" # Open the same key-value store again reopened_kvs = await KeyValueStore.open( name=kvs.name, - storage_dir=tmp_path, storage_client=storage_client, ) @@ -101,14 +95,12 @@ async def test_open_existing_kvs( async def test_open_with_id_and_name( storage_client: StorageClient, configuration: Configuration, - tmp_path: Path, ) -> None: """Test that open() raises an error when both id and name are provided.""" with pytest.raises(ValueError, match='Only one of "id" or "name" can be specified'): await KeyValueStore.open( id='some-id', name='some-name', - storage_dir=tmp_path, storage_client=storage_client, configuration=configuration, ) @@ -262,12 +254,10 @@ async def test_iterate_keys_with_limit(kvs: KeyValueStore) -> None: async def test_drop( storage_client: StorageClient, configuration: Configuration, - tmp_path: Path, ) -> None: """Test dropping a key-value store removes it from cache and clears its data.""" kvs = await KeyValueStore.open( name='drop_test', - storage_dir=tmp_path, storage_client=storage_client, configuration=configuration, ) @@ -291,7 +281,6 @@ async def test_drop( # Verify key-value store is empty (by creating a new one with the same name) new_kvs = await KeyValueStore.open( name='drop_test', - storage_dir=tmp_path, storage_client=storage_client, configuration=configuration, ) diff --git a/tests/unit/storages/test_request_queue.py b/tests/unit/storages/test_request_queue.py index 5de86a5544..78404dc1e0 100644 --- a/tests/unit/storages/test_request_queue.py +++ b/tests/unit/storages/test_request_queue.py @@ -1,367 +1,402 @@ +# TODO: Update crawlee_storage_dir args once the Pydantic bug is fixed +# https://github.com/apify/crawlee-python/issues/146 + from __future__ import annotations import asyncio -from datetime import datetime, timedelta, timezone -from itertools import count from typing import TYPE_CHECKING -from unittest.mock import AsyncMock, MagicMock import pytest -from pydantic import ValidationError - -from crawlee import Request, service_locator -from crawlee._request import RequestState -from crawlee.storage_clients import MemoryStorageClient, StorageClient -from crawlee.storage_clients._memory import RequestQueueClient -from crawlee.storage_clients.models import ( - BatchRequestsOperationResponse, - StorageMetadata, - UnprocessedRequest, -) + +from crawlee import Request +from crawlee.configuration 
import Configuration +from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient, StorageClient from crawlee.storages import RequestQueue if TYPE_CHECKING: - from collections.abc import AsyncGenerator, Sequence + from collections.abc import AsyncGenerator + from pathlib import Path + + +pytestmark = pytest.mark.only + + +@pytest.fixture(params=['memory', 'file_system']) +def storage_client(request: pytest.FixtureRequest) -> StorageClient: + """Parameterized fixture to test with different storage clients.""" + if request.param == 'memory': + return MemoryStorageClient() + + return FileSystemStorageClient() @pytest.fixture -async def request_queue() -> AsyncGenerator[RequestQueue, None]: - rq = await RequestQueue.open() +def configuration(tmp_path: Path) -> Configuration: + """Provide a configuration with a temporary storage directory.""" + return Configuration(crawlee_storage_dir=str(tmp_path)) # type: ignore[call-arg] + + +@pytest.fixture +async def rq( + storage_client: StorageClient, + configuration: Configuration, +) -> AsyncGenerator[RequestQueue, None]: + """Fixture that provides a request queue instance for each test.""" + RequestQueue._cache_by_id.clear() + RequestQueue._cache_by_name.clear() + + rq = await RequestQueue.open( + name='test_request_queue', + storage_client=storage_client, + configuration=configuration, + ) + yield rq await rq.drop() -async def test_open() -> None: - default_request_queue = await RequestQueue.open() - default_request_queue_by_id = await RequestQueue.open(id=default_request_queue.id) +async def test_open_creates_new_rq( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test that open() creates a new request queue with proper metadata.""" + rq = await RequestQueue.open( + name='new_request_queue', + storage_client=storage_client, + configuration=configuration, + ) - assert default_request_queue is default_request_queue_by_id + # Verify request queue properties + assert rq.id is not None + assert rq.name == 'new_request_queue' + assert rq.metadata.pending_request_count == 0 + assert rq.metadata.handled_request_count == 0 + assert rq.metadata.total_request_count == 0 - request_queue_name = 'dummy-name' - named_request_queue = await RequestQueue.open(name=request_queue_name) - assert default_request_queue is not named_request_queue + await rq.drop() - with pytest.raises(RuntimeError, match='RequestQueue with id "nonexistent-id" does not exist!'): - await RequestQueue.open(id='nonexistent-id') - # Test that when you try to open a request queue by ID and you use a name of an existing request queue, - # it doesn't work - with pytest.raises(RuntimeError, match='RequestQueue with id "dummy-name" does not exist!'): - await RequestQueue.open(id='dummy-name') +async def test_open_existing_rq( + rq: RequestQueue, + storage_client: StorageClient, +) -> None: + """Test that open() loads an existing request queue correctly.""" + # Open the same request queue again + reopened_rq = await RequestQueue.open( + name=rq.name, + storage_client=storage_client, + ) + # Verify request queue properties + assert rq.id == reopened_rq.id + assert rq.name == reopened_rq.name -async def test_consistency_accross_two_clients() -> None: - request_apify = Request.from_url('https://apify.com') - request_crawlee = Request.from_url('https://crawlee.dev') + # Verify they are the same object (from cache) + assert id(rq) == id(reopened_rq) - rq = await RequestQueue.open(name='my-rq') - await rq.add_request(request_apify) - rq_by_id = await 
RequestQueue.open(id=rq.id) - await rq_by_id.add_request(request_crawlee) +async def test_open_with_id_and_name( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test that open() raises an error when both id and name are provided.""" + with pytest.raises(ValueError, match='Only one of "id" or "name" can be specified'): + await RequestQueue.open( + id='some-id', + name='some-name', + storage_client=storage_client, + configuration=configuration, + ) - assert await rq.get_total_count() == 2 - assert await rq_by_id.get_total_count() == 2 - assert await rq.fetch_next_request() == request_apify - assert await rq_by_id.fetch_next_request() == request_crawlee +async def test_add_request_string_url(rq: RequestQueue) -> None: + """Test adding a request with a string URL.""" + # Add a request with a string URL + url = 'https://example.com' + result = await rq.add_request(url) - await rq.drop() - with pytest.raises(RuntimeError, match='Storage with provided ID was not found'): - await rq_by_id.drop() + # Verify request was added + assert result.id is not None + assert result.unique_key is not None + assert result.was_already_present is False + assert result.was_already_handled is False + # Verify the queue stats were updated + assert rq.metadata.total_request_count == 1 + assert rq.metadata.pending_request_count == 1 -async def test_same_references() -> None: - rq1 = await RequestQueue.open() - rq2 = await RequestQueue.open() - assert rq1 is rq2 - rq_name = 'non-default' - rq_named1 = await RequestQueue.open(name=rq_name) - rq_named2 = await RequestQueue.open(name=rq_name) - assert rq_named1 is rq_named2 +async def test_add_request_object(rq: RequestQueue) -> None: + """Test adding a request object.""" + # Create and add a request object + request = Request.from_url(url='https://example.com', user_data={'key': 'value'}) + result = await rq.add_request(request) + # Verify request was added + assert result.id is not None + assert result.unique_key is not None + assert result.was_already_present is False + assert result.was_already_handled is False -async def test_drop() -> None: - rq1 = await RequestQueue.open() - await rq1.drop() - rq2 = await RequestQueue.open() - assert rq1 is not rq2 + # Verify the queue stats were updated + assert rq.metadata.total_request_count == 1 + assert rq.metadata.pending_request_count == 1 -async def test_get_request(request_queue: RequestQueue) -> None: - request = Request.from_url('https://example.com') - processed_request = await request_queue.add_request(request) - assert request.id == processed_request.id - request_2 = await request_queue.get_request(request.id) - assert request_2 is not None - assert request == request_2 +async def test_add_duplicate_request(rq: RequestQueue) -> None: + """Test adding a duplicate request to the queue.""" + # Add a request + url = 'https://example.com' + first_result = await rq.add_request(url) + # Add the same request again + second_result = await rq.add_request(url) -async def test_add_fetch_handle_request(request_queue: RequestQueue) -> None: - request = Request.from_url('https://example.com') - assert await request_queue.is_empty() is True - add_request_info = await request_queue.add_request(request) + # Verify the second request was detected as duplicate + assert second_result.was_already_present is True + assert second_result.unique_key == first_result.unique_key - assert add_request_info.was_already_present is False - assert add_request_info.was_already_handled is False - assert await 
request_queue.is_empty() is False + # Verify the queue stats weren't incremented twice + assert rq.metadata.total_request_count == 1 + assert rq.metadata.pending_request_count == 1 - # Fetch the request - next_request = await request_queue.fetch_next_request() - assert next_request is not None - # Mark it as handled - next_request.handled_at = datetime.now(timezone.utc) - processed_request = await request_queue.mark_request_as_handled(next_request) +async def test_add_requests_batch(rq: RequestQueue) -> None: + """Test adding multiple requests in a batch.""" + # Create a batch of requests + urls = [ + 'https://example.com/page1', + 'https://example.com/page2', + 'https://example.com/page3', + ] - assert processed_request is not None - assert processed_request.id == request.id - assert processed_request.unique_key == request.unique_key - assert await request_queue.is_finished() is True + # Add the requests + await rq.add_requests(urls) + # Wait for all background tasks to complete + await asyncio.sleep(0.1) -async def test_reclaim_request(request_queue: RequestQueue) -> None: - request = Request.from_url('https://example.com') - await request_queue.add_request(request) + # Verify the queue stats + assert rq.metadata.total_request_count == 3 + assert rq.metadata.pending_request_count == 3 - # Fetch the request - next_request = await request_queue.fetch_next_request() + +async def test_add_requests_with_forefront(rq: RequestQueue) -> None: + """Test adding requests to the front of the queue.""" + # Add some initial requests + await rq.add_request('https://example.com/page1') + await rq.add_request('https://example.com/page2') + + # Add a priority request at the forefront + await rq.add_request('https://example.com/priority', forefront=True) + + # Fetch the next request - should be the priority one + next_request = await rq.fetch_next_request() assert next_request is not None - assert next_request.unique_key == request.url - - # Reclaim - await request_queue.reclaim_request(next_request) - # Try to fetch again after a few secs - await asyncio.sleep(4) # 3 seconds is the consistency delay in request queue - next_again = await request_queue.fetch_next_request() - - assert next_again is not None - assert next_again.id == request.id - assert next_again.unique_key == request.unique_key - - -@pytest.mark.parametrize( - 'requests', - [ - [Request.from_url('https://apify.com')], - ['https://crawlee.dev'], - [Request.from_url(f'https://example.com/{i}') for i in range(10)], - [f'https://example.com/{i}' for i in range(15)], - ], - ids=['single-request', 'single-url', 'multiple-requests', 'multiple-urls'], -) -async def test_add_batched_requests( - request_queue: RequestQueue, - requests: Sequence[str | Request], -) -> None: - request_count = len(requests) + assert next_request.url == 'https://example.com/priority' - # Add the requests to the RQ in batches - await request_queue.add_requests(requests, wait_for_all_requests_to_be_added=True) - # Ensure the batch was processed correctly - assert await request_queue.get_total_count() == request_count +async def test_fetch_next_request_and_mark_handled(rq: RequestQueue) -> None: + """Test fetching and marking requests as handled.""" + # Add some requests + await rq.add_request('https://example.com/page1') + await rq.add_request('https://example.com/page2') - # Fetch and validate each request in the queue - for original_request in requests: - next_request = await request_queue.fetch_next_request() - assert next_request is not None + # Fetch first request + 
request1 = await rq.fetch_next_request() + assert request1 is not None + assert request1.url == 'https://example.com/page1' - expected_url = original_request if isinstance(original_request, str) else original_request.url - assert next_request.url == expected_url + # Mark the request as handled + result = await rq.mark_request_as_handled(request1) + assert result is not None + assert result.was_already_handled is True - # Confirm the queue is empty after processing all requests - assert await request_queue.is_empty() is True + # Fetch next request + request2 = await rq.fetch_next_request() + assert request2 is not None + assert request2.url == 'https://example.com/page2' + # Mark the second request as handled + await rq.mark_request_as_handled(request2) -async def test_invalid_user_data_serialization() -> None: - with pytest.raises(ValidationError): - Request.from_url( - 'https://crawlee.dev', - user_data={ - 'foo': datetime(year=2020, month=7, day=4, tzinfo=timezone.utc), - 'bar': {datetime(year=2020, month=4, day=7, tzinfo=timezone.utc)}, - }, - ) + # Verify counts + assert rq.metadata.total_request_count == 2 + assert rq.metadata.handled_request_count == 2 + assert rq.metadata.pending_request_count == 0 + # Verify queue is empty + empty_request = await rq.fetch_next_request() + assert empty_request is None -async def test_user_data_serialization(request_queue: RequestQueue) -> None: - request = Request.from_url( - 'https://crawlee.dev', - user_data={ - 'hello': 'world', - 'foo': 42, - }, - ) - await request_queue.add_request(request) +async def test_get_request_by_id(rq: RequestQueue) -> None: + """Test retrieving a request by its ID.""" + # Add a request + added_result = await rq.add_request('https://example.com') + request_id = added_result.id - dequeued_request = await request_queue.fetch_next_request() - assert dequeued_request is not None + # Retrieve the request by ID + retrieved_request = await rq.get_request(request_id) + assert retrieved_request is not None + assert retrieved_request.id == request_id + assert retrieved_request.url == 'https://example.com' - assert dequeued_request.user_data['hello'] == 'world' - assert dequeued_request.user_data['foo'] == 42 +async def test_get_non_existent_request(rq: RequestQueue) -> None: + """Test retrieving a request that doesn't exist.""" + non_existent_request = await rq.get_request('non-existent-id') + assert non_existent_request is None -async def test_complex_user_data_serialization(request_queue: RequestQueue) -> None: - request = Request.from_url('https://crawlee.dev') - request.user_data['hello'] = 'world' - request.user_data['foo'] = 42 - request.crawlee_data.max_retries = 1 - request.crawlee_data.state = RequestState.ERROR_HANDLER - await request_queue.add_request(request) +async def test_reclaim_request(rq: RequestQueue) -> None: + """Test reclaiming a request that failed processing.""" + # Add a request + await rq.add_request('https://example.com') - dequeued_request = await request_queue.fetch_next_request() - assert dequeued_request is not None + # Fetch the request + request = await rq.fetch_next_request() + assert request is not None - data = dequeued_request.model_dump(by_alias=True) - assert data['userData']['hello'] == 'world' - assert data['userData']['foo'] == 42 - assert data['userData']['__crawlee'] == { - 'maxRetries': 1, - 'state': RequestState.ERROR_HANDLER, - } + # Reclaim the request + result = await rq.reclaim_request(request) + assert result is not None + assert result.was_already_handled is False + # Verify we 
can fetch it again + reclaimed_request = await rq.fetch_next_request() + assert reclaimed_request is not None + assert reclaimed_request.id == request.id + assert reclaimed_request.url == 'https://example.com' -async def test_deduplication_of_requests_with_custom_unique_key() -> None: - with pytest.raises(ValueError, match='`always_enqueue` cannot be used with a custom `unique_key`'): - Request.from_url('https://apify.com', unique_key='apify', always_enqueue=True) +async def test_reclaim_request_with_forefront(rq: RequestQueue) -> None: + """Test reclaiming a request to the front of the queue.""" + # Add requests + await rq.add_request('https://example.com/first') + await rq.add_request('https://example.com/second') -async def test_deduplication_of_requests_with_invalid_custom_unique_key() -> None: - request_1 = Request.from_url('https://apify.com', always_enqueue=True) - request_2 = Request.from_url('https://apify.com', always_enqueue=True) + # Fetch the first request + first_request = await rq.fetch_next_request() + assert first_request is not None + assert first_request.url == 'https://example.com/first' - rq = await RequestQueue.open(name='my-rq') - await rq.add_request(request_1) - await rq.add_request(request_2) + # Reclaim it to the forefront + await rq.reclaim_request(first_request, forefront=True) - assert await rq.get_total_count() == 2 + # The reclaimed request should be returned first (before the second request) + next_request = await rq.fetch_next_request() + assert next_request is not None + assert next_request.url == 'https://example.com/first' - assert await rq.fetch_next_request() == request_1 - assert await rq.fetch_next_request() == request_2 +async def test_is_empty(rq: RequestQueue) -> None: + """Test checking if a request queue is empty.""" + # Initially the queue should be empty + assert await rq.is_empty() is True -async def test_deduplication_of_requests_with_valid_custom_unique_key() -> None: - request_1 = Request.from_url('https://apify.com') - request_2 = Request.from_url('https://apify.com') + # Add a request + await rq.add_request('https://example.com') + assert await rq.is_empty() is False - rq = await RequestQueue.open(name='my-rq') - await rq.add_request(request_1) - await rq.add_request(request_2) + # Fetch and handle the request + request = await rq.fetch_next_request() - assert await rq.get_total_count() == 1 + assert request is not None + await rq.mark_request_as_handled(request) - assert await rq.fetch_next_request() == request_1 + # Queue should be empty again + assert await rq.is_empty() is True -async def test_cache_requests(request_queue: RequestQueue) -> None: - request_1 = Request.from_url('https://apify.com') - request_2 = Request.from_url('https://crawlee.dev') +async def test_is_finished(rq: RequestQueue) -> None: + """Test checking if a request queue is finished.""" + # Initially the queue should be finished (empty and no background tasks) + assert await rq.is_finished() is True - await request_queue.add_request(request_1) - await request_queue.add_request(request_2) + # Add a request + await rq.add_request('https://example.com') + assert await rq.is_finished() is False - assert request_queue._requests_cache.currsize == 2 + # Add requests in the background + await rq.add_requests( + ['https://example.com/1', 'https://example.com/2'], + wait_for_all_requests_to_be_added=False, + ) - fetched_request = await request_queue.fetch_next_request() + # Queue shouldn't be finished while background tasks are running + assert await rq.is_finished() is 
False - assert fetched_request is not None - assert fetched_request.id == request_1.id + # Wait for background tasks to finish + await asyncio.sleep(0.2) - # After calling fetch_next_request request_1 moved to the end of the cache store. - cached_items = [request_queue._requests_cache.popitem()[0] for _ in range(2)] - assert cached_items == [request_2.id, request_1.id] + # Process all requests + while True: + request = await rq.fetch_next_request() + if request is None: + break + await rq.mark_request_as_handled(request) + # Now queue should be finished + assert await rq.is_finished() is True -async def test_from_storage_object() -> None: - storage_client = service_locator.get_storage_client() - storage_object = StorageMetadata( - id='dummy-id', - name='dummy-name', - accessed_at=datetime.now(timezone.utc), - created_at=datetime.now(timezone.utc), - modified_at=datetime.now(timezone.utc), - extra_attribute='extra', - ) +async def test_mark_non_existent_request_as_handled(rq: RequestQueue) -> None: + """Test marking a non-existent request as handled.""" + # Create a request that hasn't been added to the queue + request = Request.from_url(url='https://example.com', id='non-existent-id') - request_queue = RequestQueue.from_storage_object(storage_client, storage_object) - - assert request_queue.id == storage_object.id - assert request_queue.name == storage_object.name - assert request_queue.storage_object == storage_object - assert storage_object.model_extra.get('extra_attribute') == 'extra' # type: ignore[union-attr] - - -async def test_add_batched_requests_with_retry(request_queue: RequestQueue) -> None: - """Test that unprocessed requests are retried. - - Unprocessed requests should not count in `get_total_count` - Test creates situation where in `batch_add_requests` call in first batch 3 requests are unprocessed. - On each following `batch_add_requests` call the last request in batch remains unprocessed. - In this test `batch_add_requests` is called once with batch of 10 requests. With retries only 1 request should - remain unprocessed.""" - - batch_add_requests_call_counter = count(start=1) - service_locator.get_storage_client() - initial_request_count = 10 - expected_added_requests = 9 - requests = [f'https://example.com/{i}' for i in range(initial_request_count)] - - class MockedRequestQueueClient(RequestQueueClient): - """Patched memory storage client that simulates unprocessed requests.""" - - async def _batch_add_requests_without_last_n( - self, batch: Sequence[Request], n: int = 0 - ) -> BatchRequestsOperationResponse: - response = await super().batch_add_requests(batch[:-n]) - response.unprocessed_requests = [ - UnprocessedRequest(url=r.url, unique_key=r.unique_key, method=r.method) for r in batch[-n:] - ] - return response - - async def batch_add_requests( - self, - requests: Sequence[Request], - *, - forefront: bool = False, # noqa: ARG002 - ) -> BatchRequestsOperationResponse: - """Mocked client behavior that simulates unprocessed requests. - - It processes all except last three at first run, then all except last none. - Overall if tried with the same batch it will process all except the last one. 
- """ - call_count = next(batch_add_requests_call_counter) - if call_count == 1: - # Process all but last three - return await self._batch_add_requests_without_last_n(requests, n=3) - # Process all but last - return await self._batch_add_requests_without_last_n(requests, n=1) - - mocked_storage_client = AsyncMock(spec=StorageClient) - mocked_storage_client.request_queue = MagicMock( - return_value=MockedRequestQueueClient(id='default', memory_storage_client=MemoryStorageClient.from_config()) - ) + # Attempt to mark it as handled + result = await rq.mark_request_as_handled(request) + assert result is None + + +async def test_reclaim_non_existent_request(rq: RequestQueue) -> None: + """Test reclaiming a non-existent request.""" + # Create a request that hasn't been added to the queue + request = Request.from_url(url='https://example.com', id='non-existent-id') - request_queue = RequestQueue(id='default', name='some_name', storage_client=mocked_storage_client) + # Attempt to reclaim it + result = await rq.reclaim_request(request) + assert result is None - # Add the requests to the RQ in batches - await request_queue.add_requests_batched( - requests, wait_for_all_requests_to_be_added=True, wait_time_between_batches=timedelta(0) + +async def test_drop( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test dropping a request queue removes it from cache and clears its data.""" + rq = await RequestQueue.open( + name='drop_test', + storage_client=storage_client, + configuration=configuration, ) - # Ensure the batch was processed correctly - assert await request_queue.get_total_count() == expected_added_requests - # Fetch and validate each request in the queue - for original_request in requests[:expected_added_requests]: - next_request = await request_queue.fetch_next_request() - assert next_request is not None + # Add a request + await rq.add_request('https://example.com') - expected_url = original_request if isinstance(original_request, str) else original_request.url - assert next_request.url == expected_url + # Verify request queue exists in cache + assert rq.id in RequestQueue._cache_by_id + if rq.name: + assert rq.name in RequestQueue._cache_by_name + + # Drop the request queue + await rq.drop() + + # Verify request queue was removed from cache + assert rq.id not in RequestQueue._cache_by_id + if rq.name: + assert rq.name not in RequestQueue._cache_by_name + + # Verify request queue is empty (by creating a new one with the same name) + new_rq = await RequestQueue.open( + name='drop_test', + storage_client=storage_client, + configuration=configuration, + ) - # Confirm the queue is empty after processing all requests - assert await request_queue.is_empty() is True + # Verify the queue is empty + assert await new_rq.is_empty() is True + assert new_rq.metadata.total_request_count == 0 + assert new_rq.metadata.pending_request_count == 0 + await new_rq.drop() diff --git a/uv.lock b/uv.lock index 392d5b63fa..8dfd81cdbb 100644 --- a/uv.lock +++ b/uv.lock @@ -600,7 +600,7 @@ toml = [ [[package]] name = "crawlee" -version = "0.6.7" +version = "0.6.8" source = { editable = "." 
} dependencies = [ { name = "apify-fingerprint-datapoints" }, @@ -744,7 +744,7 @@ dev = [ { name = "pytest-only", specifier = "~=2.1.0" }, { name = "pytest-xdist", specifier = "~=3.6.0" }, { name = "ruff", specifier = "~=0.11.0" }, - { name = "setuptools", specifier = "~=79.0.0" }, + { name = "setuptools" }, { name = "sortedcontainers-stubs", specifier = "~=2.4.0" }, { name = "types-beautifulsoup4", specifier = "~=4.12.0.20240229" }, { name = "types-cachetools", specifier = "~=5.5.0.20240820" }, From 2c10b75c0f1fe24e5b772ac8efc6421fa4f49e17 Mon Sep 17 00:00:00 2001 From: Vlada Dusek Date: Thu, 24 Apr 2025 17:07:20 +0200 Subject: [PATCH 17/22] Update and add tests for the RQ storage clients --- .../_file_system/_dataset_client.py | 18 +- .../_file_system/_key_value_store_client.py | 18 +- .../_file_system/_request_queue_client.py | 335 ++++++++---- .../_memory/_dataset_client.py | 13 +- .../_memory/_key_value_store_client.py | 12 +- .../_memory/_request_queue_client.py | 19 +- src/crawlee/storages/_dataset.py | 40 +- src/crawlee/storages/_key_value_store.py | 38 +- src/crawlee/storages/_request_queue.py | 80 +-- .../_file_system/test_fs_rq_client.py | 500 ++++++++++++++++++ .../_memory/test_memory_rq_client.py | 495 +++++++++++++++++ tests/unit/storages/test_request_queue.py | 98 ++++ 12 files changed, 1477 insertions(+), 189 deletions(-) create mode 100644 tests/unit/storage_clients/_file_system/test_fs_rq_client.py create mode 100644 tests/unit/storage_clients/_memory/test_memory_rq_client.py diff --git a/src/crawlee/storage_clients/_file_system/_dataset_client.py b/src/crawlee/storage_clients/_file_system/_dataset_client.py index 958566d925..5db837612f 100644 --- a/src/crawlee/storage_clients/_file_system/_dataset_client.py +++ b/src/crawlee/storage_clients/_file_system/_dataset_client.py @@ -27,11 +27,21 @@ class FileSystemDatasetClient(DatasetClient): - """A file system implementation of the dataset client. + """File system implementation of the dataset client. - This client persists data to the file system, making it suitable for scenarios where data needs - to survive process restarts. Each dataset item is stored as a separate JSON file with a numeric - filename, allowing for easy ordering and pagination. + This client persists dataset items to the file system as individual JSON files within a structured + directory hierarchy following the pattern: + + ``` + {STORAGE_DIR}/datasets/{DATASET_ID}/{ITEM_ID}.json + ``` + + Each item is stored as a separate file, which allows for durability and the ability to + recover after process termination. Dataset operations like filtering, sorting, and pagination are + implemented by processing the stored files according to the requested parameters. + + This implementation is ideal for long-running crawlers where data persistence is important, + and for development environments where you want to easily inspect the collected data between runs. """ _STORAGE_SUBDIR = 'datasets' diff --git a/src/crawlee/storage_clients/_file_system/_key_value_store_client.py b/src/crawlee/storage_clients/_file_system/_key_value_store_client.py index 7799b71583..f7db025a25 100644 --- a/src/crawlee/storage_clients/_file_system/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_file_system/_key_value_store_client.py @@ -28,11 +28,21 @@ class FileSystemKeyValueStoreClient(KeyValueStoreClient): - """A file system implementation of the key-value store client. + """File system implementation of the key-value store client. 
- This client persists data to the file system, making it suitable for scenarios where data needs - to survive process restarts. Each key-value pair is stored as a separate file, with its metadata - in an accompanying file. + This client persists data to the file system, making it suitable for scenarios where data needs to + survive process restarts. Keys are mapped to file paths in a directory structure following the pattern: + + ``` + {STORAGE_DIR}/key_value_stores/{STORE_ID}/{KEY} + ``` + + Binary data is stored as-is, while JSON and text data are stored in human-readable format. + The implementation automatically handles serialization based on the content type and + maintains metadata about each record. + + This implementation is ideal for long-running crawlers where persistence is important and + for development environments where you want to easily inspect the stored data between runs. """ _STORAGE_SUBDIR = 'key_value_stores' diff --git a/src/crawlee/storage_clients/_file_system/_request_queue_client.py b/src/crawlee/storage_clients/_file_system/_request_queue_client.py index ed0d60f39a..a88168e894 100644 --- a/src/crawlee/storage_clients/_file_system/_request_queue_client.py +++ b/src/crawlee/storage_clients/_file_system/_request_queue_client.py @@ -29,9 +29,20 @@ class FileSystemRequestQueueClient(RequestQueueClient): """A file system implementation of the request queue client. - This client persists requests to the file system, making it suitable for scenarios where data needs - to survive process restarts. Each request is stored as a separate file, allowing for proper request - handling and tracking across crawler runs. + This client persists requests to the file system as individual JSON files, making it suitable for scenarios + where data needs to survive process restarts. Each request is stored as a separate file in a directory + structure following the pattern: + + ``` + {STORAGE_DIR}/request_queues/{QUEUE_ID}/{REQUEST_ID}.json + ``` + + The implementation uses file timestamps for FIFO ordering of regular requests and maintains in-memory sets + for tracking in-progress and forefront requests. File system storage provides durability at the cost of + slower I/O operations compared to memory-based storage. + + This implementation is ideal for long-running crawlers where persistence is important and for situations + where you need to resume crawling after process termination. """ _STORAGE_SUBDIR = 'request_queues' @@ -78,6 +89,12 @@ def __init__( self._lock = asyncio.Lock() """A lock to ensure that only one operation is performed at a time.""" + self._in_progress = set[str]() + """A set of request IDs that are currently being processed.""" + + self._forefront_requests = set[str]() + """A set of request IDs that should be prioritized (added with forefront=True).""" + @override @property def metadata(self) -> RequestQueueMetadata: @@ -120,7 +137,7 @@ async def open( metadata_path = rq_path / METADATA_FILENAME # If the RQ directory exists, reconstruct the client from the metadata file. - if rq_path.exists(): + if rq_path.exists() and not configuration.purge_on_start: # If metadata file is missing, raise an error. 
if not metadata_path.exists(): raise ValueError(f'Metadata file not found for request queue "{name}"') @@ -149,10 +166,40 @@ async def open( storage_dir=storage_dir, ) - await client._update_metadata(update_accessed_at=True) + # Recalculate request counts from actual files to ensure consistency + handled_count = 0 + pending_count = 0 + request_files = await asyncio.to_thread(list, rq_path.glob('*.json')) + for request_file in request_files: + if request_file.name == METADATA_FILENAME: + continue + + try: + file = await asyncio.to_thread(open, request_file) + try: + data = json.load(file) + if data.get('handled_at') is not None: + handled_count += 1 + else: + pending_count += 1 + finally: + await asyncio.to_thread(file.close) + except (json.JSONDecodeError, ValidationError): + logger.warning(f'Failed to parse request file: {request_file}') + + await client._update_metadata( + update_accessed_at=True, + new_handled_request_count=handled_count, + new_pending_request_count=pending_count, + new_total_request_count=handled_count + pending_count, + ) # Otherwise, create a new dataset client. else: + # If purge_on_start is true and the directory exists, remove it + if configuration.purge_on_start and rq_path.exists(): + await asyncio.to_thread(shutil.rmtree, rq_path) + now = datetime.now(timezone.utc) client = cls( id=crypto_random_object_id(), @@ -185,8 +232,6 @@ async def drop(self) -> None: if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 - # TODO: continue - @override async def add_batch_of_requests( self, @@ -204,15 +249,13 @@ async def add_batch_of_requests( Response containing information about the added requests. """ async with self._lock: + new_total_request_count = self._metadata.total_request_count + new_pending_request_count = self._metadata.pending_request_count + processed_requests = [] # Create the requests directory if it doesn't exist - requests_dir = self.path_to_rq / 'requests' - await asyncio.to_thread(requests_dir.mkdir, parents=True, exist_ok=True) - - # Create the in_progress directory if it doesn't exist - in_progress_dir = self.path_to_rq / 'in_progress' - await asyncio.to_thread(in_progress_dir.mkdir, parents=True, exist_ok=True) + await asyncio.to_thread(self.path_to_rq.mkdir, parents=True, exist_ok=True) for request in requests: # Ensure the request has an ID @@ -223,8 +266,12 @@ async def add_batch_of_requests( existing_request = None # List all request files and check for matching unique_key - request_files = await asyncio.to_thread(list, requests_dir.glob('*.json')) + request_files = await asyncio.to_thread(list, self.path_to_rq.glob('*.json')) for request_file in request_files: + # Skip metadata file + if request_file.name == METADATA_FILENAME: + continue + file = await asyncio.to_thread(open, request_file) try: file_content = json.load(file) @@ -242,10 +289,10 @@ async def add_batch_of_requests( ) # If the request is already in the queue and handled, don't add it again - if was_already_handled: + if was_already_handled and existing_request: processed_requests.append( ProcessedRequest( - id=request.id, + id=existing_request.id, unique_key=request.unique_key, was_already_present=True, was_already_handled=True, @@ -253,33 +300,70 @@ async def add_batch_of_requests( ) continue + # If forefront and existing request is not handled, mark it as forefront + if forefront and was_already_present and not was_already_handled and existing_request: + 
self._forefront_requests.add(existing_request.id) + processed_requests.append( + ProcessedRequest( + id=existing_request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=False, + ) + ) + continue + # If the request is already in the queue but not handled, update it - if was_already_present: + if was_already_present and existing_request: # Update the existing request file - request_path = requests_dir / f'{request.id}.json' - request_data = await json_dumps(request.model_dump()) - await asyncio.to_thread(request_path.write_text, request_data, encoding='utf-8') - else: - # Add the new request to the queue - request_path = requests_dir / f'{request.id}.json' - request_data = await json_dumps(request.model_dump()) + request_path = self.path_to_rq / f'{existing_request.id}.json' + request_data = await json_dumps(existing_request.model_dump()) await asyncio.to_thread(request_path.write_text, request_data, encoding='utf-8') - # Update metadata counts - self._metadata.total_request_count += 1 - self._metadata.pending_request_count += 1 + processed_requests.append( + ProcessedRequest( + id=existing_request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=False, + ) + ) + continue + + # Add the new request to the queue + request_path = self.path_to_rq / f'{request.id}.json' + + # Create a data dictionary from the request and remove handled_at if it's None + request_dict = request.model_dump() + if request_dict.get('handled_at') is None: + request_dict.pop('handled_at', None) + + request_data = await json_dumps(request_dict) + await asyncio.to_thread(request_path.write_text, request_data, encoding='utf-8') + + # Update metadata counts + new_total_request_count += 1 + new_pending_request_count += 1 + + # If forefront, add to the forefront set + if forefront: + self._forefront_requests.add(request.id) processed_requests.append( ProcessedRequest( id=request.id, unique_key=request.unique_key, - was_already_present=was_already_present, + was_already_present=False, was_already_handled=False, ) ) - # Update metadata - await self._update_metadata(update_modified_at=True) + await self._update_metadata( + update_modified_at=True, + update_accessed_at=True, + new_total_request_count=new_total_request_count, + new_pending_request_count=new_pending_request_count, + ) return AddRequestsResponse( processed_requests=processed_requests, @@ -296,24 +380,17 @@ async def get_request(self, request_id: str) -> Request | None: Returns: The retrieved request, or None, if it did not exist. 
""" - # First check in-progress directory - in_progress_dir = self.path_to_rq / 'in_progress' - in_progress_path = in_progress_dir / f'{request_id}.json' + request_path = self.path_to_rq / f'{request_id}.json' - # Then check regular requests directory - requests_dir = self.path_to_rq / 'requests' - request_path = requests_dir / f'{request_id}.json' - - for path in [in_progress_path, request_path]: - if await asyncio.to_thread(path.exists): - file = await asyncio.to_thread(open, path) - try: - file_content = json.load(file) - return Request(**file_content) - except (json.JSONDecodeError, ValidationError) as e: - logger.warning(f'Failed to parse request file {path}: {e!s}') - finally: - await asyncio.to_thread(file.close) + if await asyncio.to_thread(request_path.exists): + file = await asyncio.to_thread(open, request_path) + try: + file_content = json.load(file) + return Request(**file_content) + except (json.JSONDecodeError, ValidationError) as exc: + logger.warning(f'Failed to parse request file {request_path}: {exc!s}') + finally: + await asyncio.to_thread(file.close) return None @@ -330,18 +407,55 @@ async def fetch_next_request(self) -> Request | None: The request or `None` if there are no more pending requests. """ async with self._lock: - # Create the requests and in_progress directories if they don't exist - requests_dir = self.path_to_rq / 'requests' - in_progress_dir = self.path_to_rq / 'in_progress' - - await asyncio.to_thread(requests_dir.mkdir, parents=True, exist_ok=True) - await asyncio.to_thread(in_progress_dir.mkdir, parents=True, exist_ok=True) + # Create the requests directory if it doesn't exist + await asyncio.to_thread(self.path_to_rq.mkdir, parents=True, exist_ok=True) # List all request files - request_files = await asyncio.to_thread(list, requests_dir.glob('*.json')) + request_files = await asyncio.to_thread(list, self.path_to_rq.glob('*.json')) - # Find a request that's not handled + # First check for forefront requests + forefront_requests = [] + regular_requests = [] + + # Get file creation times for sorting regular requests in FIFO order + request_file_times = {} + + # Separate requests into forefront and regular for request_file in request_files: + # Skip metadata file + if request_file.name == METADATA_FILENAME: + continue + + # Extract request ID from filename + request_id = request_file.stem + + # Skip if already in progress + if request_id in self._in_progress: + continue + + # Get file creation/modification time for FIFO ordering + try: + file_stat = await asyncio.to_thread(request_file.stat) + request_file_times[request_file] = file_stat.st_mtime + except Exception: + # If we can't get the time, use 0 (oldest) + request_file_times[request_file] = 0 + + if request_id in self._forefront_requests: + forefront_requests.append(request_file) + else: + regular_requests.append(request_file) + + # Sort regular requests by creation time (FIFO order) + regular_requests.sort(key=lambda f: request_file_times[f]) + + # Prioritize forefront requests + prioritized_files = forefront_requests + regular_requests + + # Process files in prioritized order + for request_file in prioritized_files: + request_id = request_file.stem + file = await asyncio.to_thread(open, request_file) try: file_content = json.load(file) @@ -352,19 +466,17 @@ async def fetch_next_request(self) -> Request | None: # Create request object request = Request(**file_content) - # Move to in-progress - in_progress_path = in_progress_dir / f'{request.id}.json' + # Mark as in-progress in memory + 
self._in_progress.add(request.id) - # If already in in-progress, skip - if await asyncio.to_thread(in_progress_path.exists): - continue + # Remove from forefront set if it was there + self._forefront_requests.discard(request.id) - # Write to in-progress directory - request_data = await json_dumps(request.model_dump()) - await asyncio.to_thread(in_progress_path.write_text, request_data, encoding='utf-8') + # Update accessed timestamp + await self._update_metadata(update_accessed_at=True) - except (json.JSONDecodeError, ValidationError) as e: - logger.warning(f'Failed to parse request file {request_file}: {e!s}') + except (json.JSONDecodeError, ValidationError) as exc: + logger.warning(f'Failed to parse request file {request_file}: {exc!s}') else: return request finally: @@ -386,32 +498,32 @@ async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | """ async with self._lock: # Check if the request is in progress - in_progress_dir = self.path_to_rq / 'in_progress' - in_progress_path = in_progress_dir / f'{request.id}.json' - - if not await asyncio.to_thread(in_progress_path.exists): + if request.id not in self._in_progress: return None + # Remove from in-progress set + self._in_progress.discard(request.id) + # Update the request object - set handled_at timestamp if request.handled_at is None: request.handled_at = datetime.now(timezone.utc) # Write the updated request back to the requests directory - requests_dir = self.path_to_rq / 'requests' - request_path = requests_dir / f'{request.id}.json' + request_path = self.path_to_rq / f'{request.id}.json' + + if not await asyncio.to_thread(request_path.exists): + return None request_data = await json_dumps(request.model_dump()) await asyncio.to_thread(request_path.write_text, request_data, encoding='utf-8') - # Remove the in-progress file - await asyncio.to_thread(in_progress_path.unlink, missing_ok=True) - - # Update metadata counts - self._metadata.handled_request_count += 1 - self._metadata.pending_request_count -= 1 - # Update metadata timestamps - await self._update_metadata(update_modified_at=True) + await self._update_metadata( + update_modified_at=True, + update_accessed_at=True, + new_handled_request_count=self._metadata.handled_request_count + 1, + new_pending_request_count=self._metadata.pending_request_count - 1, + ) return ProcessedRequest( id=request.id, @@ -440,30 +552,31 @@ async def reclaim_request( """ async with self._lock: # Check if the request is in progress - in_progress_dir = self.path_to_rq / 'in_progress' - in_progress_path = in_progress_dir / f'{request.id}.json' - - if not await asyncio.to_thread(in_progress_path.exists): + if request.id not in self._in_progress: return None - # Remove the in-progress file - await asyncio.to_thread(in_progress_path.unlink, missing_ok=True) + # Remove from in-progress set + self._in_progress.discard(request.id) - # If forefront is true, we need to handle this specially - # Since we can't reorder files, we'll add a 'priority' field to the request + # If forefront is true, mark this request as priority if forefront: - # Update the priority of the request to indicate it should be processed first - request.priority = 1 # Higher priority + self._forefront_requests.add(request.id) + else: + # Make sure it's not in the forefront set if it was previously added there + self._forefront_requests.discard(request.id) - # Write the updated request back to the requests directory - requests_dir = self.path_to_rq / 'requests' - request_path = requests_dir / f'{request.id}.json' + # 
To simulate changing the file timestamp for FIFO ordering, + # we'll update the file with current timestamp + request_path = self.path_to_rq / f'{request.id}.json' + + if not await asyncio.to_thread(request_path.exists): + return None request_data = await json_dumps(request.model_dump()) await asyncio.to_thread(request_path.write_text, request_data, encoding='utf-8') # Update metadata timestamps - await self._update_metadata(update_modified_at=True) + await self._update_metadata(update_modified_at=True, update_accessed_at=True) return ProcessedRequest( id=request.id, @@ -479,15 +592,21 @@ async def is_empty(self) -> bool: Returns: True if the queue is empty, False otherwise. """ + # Update accessed timestamp when checking if queue is empty + await self._update_metadata(update_accessed_at=True) + # Create the requests directory if it doesn't exist - requests_dir = self.path_to_rq / 'requests' - await asyncio.to_thread(requests_dir.mkdir, parents=True, exist_ok=True) + await asyncio.to_thread(self.path_to_rq.mkdir, parents=True, exist_ok=True) # List all request files - request_files = await asyncio.to_thread(list, requests_dir.glob('*.json')) + request_files = await asyncio.to_thread(list, self.path_to_rq.glob('*.json')) # Check each file to see if there are any unhandled requests for request_file in request_files: + # Skip metadata file + if request_file.name == METADATA_FILENAME: + continue + file = await asyncio.to_thread(open, request_file) try: file_content = json.load(file) @@ -505,22 +624,46 @@ async def is_empty(self) -> bool: async def _update_metadata( self, *, + new_handled_request_count: int | None = None, + new_pending_request_count: int | None = None, + new_total_request_count: int | None = None, + update_had_multiple_clients: bool = False, update_accessed_at: bool = False, update_modified_at: bool = False, ) -> None: """Update the dataset metadata file with current information. Args: + new_handled_request_count: If provided, update the handled_request_count to this value. + new_pending_request_count: If provided, update the pending_request_count to this value. + new_total_request_count: If provided, update the total_request_count to this value. + update_had_multiple_clients: If True, set had_multiple_clients to True. update_accessed_at: If True, update the `accessed_at` timestamp to the current time. update_modified_at: If True, update the `modified_at` timestamp to the current time. """ + # Always create a new timestamp to ensure it's truly updated now = datetime.now(timezone.utc) + # Update timestamps according to parameters if update_accessed_at: self._metadata.accessed_at = now + if update_modified_at: self._metadata.modified_at = now + # Update request counts if provided + if new_handled_request_count is not None: + self._metadata.handled_request_count = new_handled_request_count + + if new_pending_request_count is not None: + self._metadata.pending_request_count = new_pending_request_count + + if new_total_request_count is not None: + self._metadata.total_request_count = new_total_request_count + + if update_had_multiple_clients: + self._metadata.had_multiple_clients = True + # Ensure the parent directory for the metadata file exists. 
await asyncio.to_thread(self.path_to_metadata.parent.mkdir, parents=True, exist_ok=True) diff --git a/src/crawlee/storage_clients/_memory/_dataset_client.py b/src/crawlee/storage_clients/_memory/_dataset_client.py index 3a0e486330..0d75b50f9f 100644 --- a/src/crawlee/storage_clients/_memory/_dataset_client.py +++ b/src/crawlee/storage_clients/_memory/_dataset_client.py @@ -19,11 +19,16 @@ class MemoryDatasetClient(DatasetClient): - """A memory implementation of the dataset client. + """Memory implementation of the dataset client. - This client stores dataset items in memory using a list. No data is persisted, which means - all data is lost when the process terminates. This implementation is mainly useful for testing - and development purposes where persistence is not required. + This client stores dataset items in memory using Python lists and dictionaries. No data is persisted + between process runs, meaning all stored data is lost when the program terminates. This implementation + is primarily useful for testing, development, and short-lived crawler operations where persistent + storage is not required. + + The memory implementation provides fast access to data but is limited by available memory and + does not support data sharing across different processes. It supports all dataset operations including + sorting, filtering, and pagination, but performs them entirely in memory. """ _cache_by_name: ClassVar[dict[str, MemoryDatasetClient]] = {} diff --git a/src/crawlee/storage_clients/_memory/_key_value_store_client.py b/src/crawlee/storage_clients/_memory/_key_value_store_client.py index 76bcd5761e..9b70419142 100644 --- a/src/crawlee/storage_clients/_memory/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_memory/_key_value_store_client.py @@ -21,11 +21,15 @@ class MemoryKeyValueStoreClient(KeyValueStoreClient): - """A memory implementation of the key-value store client. + """Memory implementation of the key-value store client. - This client stores key-value store pairs in memory using a dictionary. No data is persisted, - which means all data is lost when the process terminates. This implementation is mainly useful - for testing and development purposes where persistence is not required. + This client stores data in memory as Python dictionaries. No data is persisted between + process runs, meaning all stored data is lost when the program terminates. This implementation + is primarily useful for testing, development, and short-lived crawler operations where + persistence is not required. + + The memory implementation provides fast access to data but is limited by available memory and + does not support data sharing across different processes. """ _cache_by_name: ClassVar[dict[str, MemoryKeyValueStoreClient]] = {} diff --git a/src/crawlee/storage_clients/_memory/_request_queue_client.py b/src/crawlee/storage_clients/_memory/_request_queue_client.py index e197954644..cfe674fb3d 100644 --- a/src/crawlee/storage_clients/_memory/_request_queue_client.py +++ b/src/crawlee/storage_clients/_memory/_request_queue_client.py @@ -24,11 +24,15 @@ class MemoryRequestQueueClient(RequestQueueClient): - """A memory implementation of the request queue client. + """Memory implementation of the request queue client. - This client stores requests in memory using a list. No data is persisted, which means - all requests are lost when the process terminates. This implementation is mainly useful - for testing and development purposes where persistence is not required. 
+ This client stores requests in memory using a Python list and dictionary. No data is persisted between + process runs, which means all requests are lost when the program terminates. This implementation + is primarily useful for testing, development, and short-lived crawler runs where persistence + is not required. + + This client provides fast access to request data but is limited by available memory and + does not support data sharing across different processes. """ _cache_by_name: ClassVar[dict[str, MemoryRequestQueueClient]] = {} @@ -50,7 +54,7 @@ def __init__( ) -> None: """Initialize a new instance. - Preferably use the `FileSystemRequestQueueClient.open` class method to create a new instance. + Preferably use the `MemoryRequestQueueClient.open` class method to create a new instance. """ self._metadata = RequestQueueMetadata( id=id, @@ -192,8 +196,7 @@ async def add_batch_of_requests( ) ) - # Update metadata - await self._update_metadata(update_modified_at=True) + await self._update_metadata(update_accessed_at=True, update_modified_at=True) return AddRequestsResponse( processed_requests=processed_requests, @@ -333,6 +336,8 @@ async def is_empty(self) -> bool: Returns: True if the queue is empty, False otherwise. """ + await self._update_metadata(update_accessed_at=True) + # Queue is empty if there are no pending requests pending_requests = [r for r in self._records if r.handled_at is None] return len(pending_requests) == 0 diff --git a/src/crawlee/storages/_dataset.py b/src/crawlee/storages/_dataset.py index bd453bc8cc..f784f70216 100644 --- a/src/crawlee/storages/_dataset.py +++ b/src/crawlee/storages/_dataset.py @@ -31,34 +31,38 @@ @docs_group('Classes') class Dataset(Storage): - """Dataset is an append-only structured storage, ideal for tabular data similar to database tables. + """Dataset is a storage for managing structured tabular data. - The `Dataset` class is designed to store structured data, where each entry (row) maintains consistent attributes - (columns) across the dataset. It operates in an append-only mode, allowing new records to be added, but not - modified or deleted. This makes it particularly useful for storing results from web crawling operations. + The dataset class provides a high-level interface for storing and retrieving structured data + with consistent schema, similar to database tables or spreadsheets. It abstracts the underlying + storage implementation details, offering a consistent API regardless of where the data is + physically stored. - Data can be stored either locally or in the cloud. It depends on the setup of underlying storage client. - By default a `MemoryStorageClient` is used, but it can be changed to a different one. + Dataset operates in an append-only mode, allowing new records to be added but not modified + or deleted after creation. This makes it particularly suitable for storing crawling results + and other data that should be immutable once collected. - By default, data is stored using the following path structure: - ``` - {CRAWLEE_STORAGE_DIR}/datasets/{DATASET_ID}/{INDEX}.json - ``` - - `{CRAWLEE_STORAGE_DIR}`: The root directory for all storage data specified by the environment variable. - - `{DATASET_ID}`: Specifies the dataset, either "default" or a custom dataset ID. - - `{INDEX}`: Represents the zero-based index of the record within the dataset. - - To open a dataset, use the `open` class method by specifying an `id`, `name`, or `configuration`. If none are - provided, the default dataset for the current crawler run is used. 
Attempting to open a dataset by `id` that does - not exist will raise an error; however, if accessed by `name`, the dataset will be created if it doesn't already - exist. + The class provides methods for adding data, retrieving data with various filtering options, + and exporting data to different formats. You can create a dataset using the `open` class method, + specifying either a name or ID. The underlying storage implementation is determined by + the configured storage client. ### Usage ```python from crawlee.storages import Dataset + # Open a dataset dataset = await Dataset.open(name='my_dataset') + + # Add data + await dataset.push_data({'title': 'Example Product', 'price': 99.99}) + + # Retrieve filtered data + results = await dataset.get_data(limit=10, desc=True) + + # Export data + await dataset.export_to('results.json', content_type='json') ``` """ diff --git a/src/crawlee/storages/_key_value_store.py b/src/crawlee/storages/_key_value_store.py index 95f2b3c1c9..41f9afe37e 100644 --- a/src/crawlee/storages/_key_value_store.py +++ b/src/crawlee/storages/_key_value_store.py @@ -22,36 +22,32 @@ @docs_group('Classes') class KeyValueStore(Storage): - """Represents a key-value based storage for reading and writing data records or files. + """Key-value store is a storage for reading and writing data records with unique key identifiers. - Each data record is identified by a unique key and associated with a specific MIME content type. This class is - commonly used in crawler runs to store inputs and outputs, typically in JSON format, but it also supports other - content types. + The key-value store class acts as a high-level interface for storing, retrieving, and managing data records + identified by unique string keys. It abstracts away the underlying storage implementation details, + allowing you to work with the same API regardless of whether data is stored in memory, on disk, + or in the cloud. - Data can be stored either locally or in the cloud. It depends on the setup of underlying storage client. - By default a `MemoryStorageClient` is used, but it can be changed to a different one. + Each data record is associated with a specific MIME content type, allowing storage of various + data formats such as JSON, text, images, HTML snapshots or any binary data. This class is + commonly used to store inputs, outputs, and other artifacts of crawler operations. - By default, data is stored using the following path structure: - ``` - {CRAWLEE_STORAGE_DIR}/key_value_stores/{STORE_ID}/{KEY}.{EXT} - ``` - - `{CRAWLEE_STORAGE_DIR}`: The root directory for all storage data specified by the environment variable. - - `{STORE_ID}`: The identifier for the key-value store, either "default" or as specified by - `CRAWLEE_DEFAULT_KEY_VALUE_STORE_ID`. - - `{KEY}`: The unique key for the record. - - `{EXT}`: The file extension corresponding to the MIME type of the content. - - To open a key-value store, use the `open` class method, providing an `id`, `name`, or optional `configuration`. - If none are specified, the default store for the current crawler run is used. Attempting to open a store by `id` - that does not exist will raise an error; however, if accessed by `name`, the store will be created if it does not - already exist. + You can instantiate a key-value store using the `open` class method, which will create a store + with the specified name or id. The underlying storage implementation is determined by the configured + storage client. 
### Usage

     ```python
     from crawlee.storages import KeyValueStore

-    kvs = await KeyValueStore.open(name='my_kvs')
+    # Open a named key-value store
+    kvs = await KeyValueStore.open(name='my-store')
+
+    # Store and retrieve data
+    await kvs.set_value('product-1234', [{'name': 'Smartphone', 'price': 799.99}])
+    product = await kvs.get_value('product-1234')
     ```
     """

diff --git a/src/crawlee/storages/_request_queue.py b/src/crawlee/storages/_request_queue.py
index 843ac6d0f1..d998cafe46 100644
--- a/src/crawlee/storages/_request_queue.py
+++ b/src/crawlee/storages/_request_queue.py
@@ -30,34 +30,43 @@
 @docs_group('Classes')
 class RequestQueue(Storage, RequestManager):
-    """Represents a queue storage for managing HTTP requests in web crawling operations.
+    """Request queue is a storage for managing HTTP requests.
 
-    The `RequestQueue` class handles a queue of HTTP requests, each identified by a unique URL, to facilitate structured
-    web crawling. It supports both breadth-first and depth-first crawling strategies, allowing for recursive crawling
-    starting from an initial set of URLs. Each URL in the queue is uniquely identified by a `unique_key`, which can be
-    customized to allow the same URL to be added multiple times under different keys.
+    The request queue class serves as a high-level interface for organizing and managing HTTP requests
+    during web crawling. It provides methods for adding, retrieving, and manipulating requests throughout
+    the crawling lifecycle, abstracting away the underlying storage implementation details.
 
-    Data can be stored either locally or in the cloud. It depends on the setup of underlying storage client.
-    By default a `MemoryStorageClient` is used, but it can be changed to a different one.
+    Request queue maintains the state of each URL to be crawled, tracking whether it has been processed,
+    is currently being handled, or is waiting in the queue. Each URL in the queue is uniquely identified
+    by a `unique_key` property, which prevents duplicate processing unless explicitly configured otherwise.
 
-    By default, data is stored using the following path structure:
-    ```
-    {CRAWLEE_STORAGE_DIR}/request_queues/{QUEUE_ID}/{REQUEST_ID}.json
-    ```
-    - `{CRAWLEE_STORAGE_DIR}`: The root directory for all storage data specified by the environment variable.
-    - `{QUEUE_ID}`: The identifier for the request queue, either "default" or as specified.
-    - `{REQUEST_ID}`: The unique identifier for each request in the queue.
+    The class supports both breadth-first and depth-first crawling strategies through its `forefront` parameter
+    when adding requests. It also provides mechanisms for error handling and request reclamation when
+    processing fails.
 
-    The `RequestQueue` supports both creating new queues and opening existing ones by `id` or `name`. Named queues
-    persist indefinitely, while unnamed queues expire after 7 days unless specified otherwise. The queue supports
-    mutable operations, allowing URLs to be added and removed as needed.
+    You can open a request queue using the `open` class method, specifying either a name or ID to identify
+    the queue. The underlying storage implementation is determined by the configured storage client.
### Usage

     ```python
     from crawlee.storages import RequestQueue

-    rq = await RequestQueue.open(name='my_rq')
+    # Open a request queue
+    rq = await RequestQueue.open(name='my_queue')
+
+    # Add a request
+    await rq.add_request('https://example.com')
+
+    # Process requests
+    request = await rq.fetch_next_request()
+    if request:
+        try:
+            # Process the request
+            # ...
+            await rq.mark_request_as_handled(request)
+        except Exception:
+            await rq.reclaim_request(request)
     ```
     """

@@ -221,31 +230,33 @@ async def fetch_next_request(self) -> Request | None:
         instead.
 
         Returns:
-            The request or `None` if there are no more pending requests.
+            The next request to process, or `None` if there are no more pending requests.
         """
         return await self._client.fetch_next_request()
 
     async def get_request(self, request_id: str) -> Request | None:
-        """Retrieve a request by its ID.
+        """Retrieve a specific request from the queue by its ID.
 
         Args:
             request_id: The ID of the request to retrieve.
 
         Returns:
-            The request if found, otherwise `None`.
+            The request with the specified ID, or `None` if no such request exists.
         """
         return await self._client.get_request(request_id)
 
     async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None:
         """Mark a request as handled after successful processing.
 
-        Handled requests will never again be returned by the `RequestQueue.fetch_next_request` method.
+        This method should be called after a request has been successfully processed.
+        Once marked as handled, the request will be removed from the queue and will
+        not be returned in subsequent calls to the `fetch_next_request` method.
 
         Args:
             request: The request to mark as handled.
 
         Returns:
-            Information about the queue operation. `None` if the given request was not in progress.
+            Information about the queue operation, or `None` if the given request was not in progress.
         """
         return await self._client.mark_request_as_handled(request)
 
@@ -255,23 +266,28 @@ async def reclaim_request(
         *,
         forefront: bool = False,
     ) -> ProcessedRequest | None:
-        """Reclaim a failed request back to the queue.
+        """Reclaim a failed request back to the queue for later processing.
 
-        The request will be returned for processing later again by another call to `RequestQueue.fetch_next_request`.
+        If a request fails during processing, this method can be used to return it to the queue.
+        The request will be returned for processing again in a subsequent call
+        to `RequestQueue.fetch_next_request`.
 
         Args:
             request: The request to return to the queue.
-            forefront: Whether to add the request to the head or the end of the queue.
+            forefront: If True, the request will be added to the beginning of the queue.
+                Otherwise, it will be added to the end.
 
         Returns:
-            Information about the queue operation. `None` if the given request was not in progress.
+            Information about the queue operation, or `None` if the given request was not in progress.
         """
         return await self._client.reclaim_request(request, forefront=forefront)
 
     async def is_empty(self) -> bool:
         """Check if the request queue is empty.
 
-        An empty queue means that there are no requests in the queue.
+        An empty queue means that there are no requests currently in the queue, either pending or being processed.
+        However, this does not necessarily mean that the crawling operation is finished, as there still might be
+        tasks that could add additional requests to the queue.
 
         Returns:
             True if the request queue is empty, False otherwise.
         """
@@ -281,11 +297,12 @@ async def is_empty(self) -> bool:
     async def is_finished(self) -> bool:
         """Check if the request queue is finished.
- Finished means that all requests in the queue have been processed (the queue is empty) and there - are no more tasks that could add additional requests to the queue. + A finished queue means that all requests in the queue have been processed (the queue is empty) and there + are no more tasks that could add additional requests to the queue. This is the definitive way to check + if a crawling operation is complete. Returns: - True if the request queue is finished, False otherwise. + True if the request queue is finished (empty and no pending add operations), False otherwise. """ if self._add_requests_tasks: logger.debug('Background add requests tasks are still in progress.') @@ -305,6 +322,7 @@ async def _process_batch( attempt: int = 1, forefront: bool = False, ) -> None: + """Process a batch of requests with automatic retry mechanism.""" max_attempts = 5 response = await self._client.add_batch_of_requests(batch, forefront=forefront) diff --git a/tests/unit/storage_clients/_file_system/test_fs_rq_client.py b/tests/unit/storage_clients/_file_system/test_fs_rq_client.py new file mode 100644 index 0000000000..bc0dcd5313 --- /dev/null +++ b/tests/unit/storage_clients/_file_system/test_fs_rq_client.py @@ -0,0 +1,500 @@ +from __future__ import annotations + +import asyncio +import json +from datetime import datetime +from typing import TYPE_CHECKING + +import pytest + +from crawlee import Request +from crawlee._consts import METADATA_FILENAME +from crawlee.configuration import Configuration +from crawlee.storage_clients import FileSystemStorageClient +from crawlee.storage_clients._file_system import FileSystemRequestQueueClient + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + from pathlib import Path + +pytestmark = pytest.mark.only + + +@pytest.fixture +def configuration(tmp_path: Path) -> Configuration: + return Configuration( + crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] + ) + + +@pytest.fixture +async def rq_client(configuration: Configuration) -> AsyncGenerator[FileSystemRequestQueueClient, None]: + """A fixture for a file system request queue client.""" + client = await FileSystemStorageClient().open_request_queue_client( + name='test_request_queue', + configuration=configuration, + ) + yield client + await client.drop() + + +async def test_open_creates_new_rq(configuration: Configuration) -> None: + """Test that open() creates a new request queue with proper metadata and files on disk.""" + client = await FileSystemStorageClient().open_request_queue_client( + name='new_request_queue', + configuration=configuration, + ) + + # Verify correct client type and properties + assert isinstance(client, FileSystemRequestQueueClient) + assert client.metadata.id is not None + assert client.metadata.name == 'new_request_queue' + assert client.metadata.handled_request_count == 0 + assert client.metadata.pending_request_count == 0 + assert client.metadata.total_request_count == 0 + assert isinstance(client.metadata.created_at, datetime) + assert isinstance(client.metadata.accessed_at, datetime) + assert isinstance(client.metadata.modified_at, datetime) + + # Verify files were created + assert client.path_to_rq.exists() + assert client.path_to_metadata.exists() + + # Verify metadata content + with client.path_to_metadata.open() as f: + metadata = json.load(f) + assert metadata['id'] == client.metadata.id + assert metadata['name'] == 'new_request_queue' + + +async def test_open_existing_rq( + rq_client: FileSystemRequestQueueClient, + configuration: Configuration, +) -> 
None: + """Test that open() loads an existing request queue correctly.""" + configuration.purge_on_start = False + + # Add a request to the original client + await rq_client.add_batch_of_requests([Request.from_url('https://example.com')]) + + # Open the same request queue again + reopened_client = await FileSystemStorageClient().open_request_queue_client( + name=rq_client.metadata.name, + configuration=configuration, + ) + + # Verify client properties + assert rq_client.metadata.id == reopened_client.metadata.id + assert rq_client.metadata.name == reopened_client.metadata.name + assert rq_client.metadata.total_request_count == 1 + assert rq_client.metadata.pending_request_count == 1 + + # Verify clients (python) ids - should be the same object due to caching + assert id(rq_client) == id(reopened_client) + + +async def test_rq_client_purge_on_start(configuration: Configuration) -> None: + """Test that purge_on_start=True clears existing data in the request queue.""" + configuration.purge_on_start = True + + # Create request queue and add data + rq_client1 = await FileSystemStorageClient().open_request_queue_client( + name='test-purge-rq', + configuration=configuration, + ) + await rq_client1.add_batch_of_requests([Request.from_url('https://example.com')]) + + # Verify request was added + assert rq_client1.metadata.total_request_count == 1 + + # Reopen + rq_client2 = await FileSystemStorageClient().open_request_queue_client( + name='test-purge-rq', + configuration=configuration, + ) + + # Verify data was purged + assert rq_client2.metadata.total_request_count == 0 + + +async def test_rq_client_no_purge_on_start(configuration: Configuration) -> None: + """Test that purge_on_start=False keeps existing data in the request queue.""" + configuration.purge_on_start = False + + # Create request queue and add data + rq_client1 = await FileSystemStorageClient().open_request_queue_client( + name='test-no-purge-rq', + configuration=configuration, + ) + await rq_client1.add_batch_of_requests([Request.from_url('https://example.com')]) + + # Reopen + rq_client2 = await FileSystemStorageClient().open_request_queue_client( + name='test-no-purge-rq', + configuration=configuration, + ) + + # Verify data was preserved + assert rq_client2.metadata.total_request_count == 1 + + +async def test_open_with_id_raises_error(configuration: Configuration) -> None: + """Test that open() raises an error when an ID is provided.""" + with pytest.raises(ValueError, match='not supported for file system storage client'): + await FileSystemStorageClient().open_request_queue_client(id='some-id', configuration=configuration) + + +@pytest.fixture +def rq_path(rq_client: FileSystemRequestQueueClient) -> Path: + """Return the path to the request queue directory.""" + return rq_client.path_to_rq + + +async def test_add_requests(rq_client: FileSystemRequestQueueClient) -> None: + """Test adding requests creates proper files in the filesystem.""" + # Add a batch of requests + requests = [ + Request.from_url('https://example.com/1'), + Request.from_url('https://example.com/2'), + Request.from_url('https://example.com/3'), + ] + + response = await rq_client.add_batch_of_requests(requests) + + # Verify response + assert len(response.processed_requests) == 3 + for i, processed_request in enumerate(response.processed_requests): + assert processed_request.unique_key == f'https://example.com/{i + 1}' + assert processed_request.was_already_present is False + assert processed_request.was_already_handled is False + + # Verify request files were 
created + request_files = list(rq_client.path_to_rq.glob('*.json')) + assert len(request_files) == 4 # 3 requests + metadata file + assert rq_client.path_to_metadata in request_files + + # Verify metadata was updated + assert rq_client.metadata.total_request_count == 3 + assert rq_client.metadata.pending_request_count == 3 + + # Verify content of the request files + for req_file in [f for f in request_files if f != rq_client.path_to_metadata]: + with req_file.open() as f: + content = json.load(f) + assert 'url' in content + assert content['url'].startswith('https://example.com/') + assert 'id' in content + assert 'handled_at' not in content # Not yet handled + + +async def test_add_duplicate_request(rq_client: FileSystemRequestQueueClient) -> None: + """Test adding a duplicate request.""" + request = Request.from_url('https://example.com') + + # Add the request the first time + await rq_client.add_batch_of_requests([request]) + + # Add the same request again + second_response = await rq_client.add_batch_of_requests([request]) + + # Verify response indicates it was already present + assert second_response.processed_requests[0].was_already_present is True + + # Verify only one request file exists + request_files = [f for f in rq_client.path_to_rq.glob('*.json') if f.name != METADATA_FILENAME] + assert len(request_files) == 1 + + # Verify metadata counts weren't incremented + assert rq_client.metadata.total_request_count == 1 + assert rq_client.metadata.pending_request_count == 1 + + +async def test_fetch_next_request(rq_client: FileSystemRequestQueueClient) -> None: + """Test fetching the next request from the queue.""" + # Add requests + requests = [ + Request.from_url('https://example.com/1'), + Request.from_url('https://example.com/2'), + ] + await rq_client.add_batch_of_requests(requests) + + # Fetch the first request + first_request = await rq_client.fetch_next_request() + assert first_request is not None + assert first_request.url == 'https://example.com/1' + + # Check that it's marked as in-progress + assert first_request.id in rq_client._in_progress + + # Fetch the second request + second_request = await rq_client.fetch_next_request() + assert second_request is not None + assert second_request.url == 'https://example.com/2' + + # There should be no more requests + empty_request = await rq_client.fetch_next_request() + assert empty_request is None + + +async def test_fetch_forefront_requests(rq_client: FileSystemRequestQueueClient) -> None: + """Test that forefront requests are fetched first.""" + # Add regular requests + await rq_client.add_batch_of_requests( + [ + Request.from_url('https://example.com/regular1'), + Request.from_url('https://example.com/regular2'), + ] + ) + + # Add forefront requests + await rq_client.add_batch_of_requests( + [ + Request.from_url('https://example.com/priority1'), + Request.from_url('https://example.com/priority2'), + ], + forefront=True, + ) + + # Fetch requests - they should come in priority order first + next_request1 = await rq_client.fetch_next_request() + assert next_request1 is not None + assert next_request1.url.startswith('https://example.com/priority') + + next_request2 = await rq_client.fetch_next_request() + assert next_request2 is not None + assert next_request2.url.startswith('https://example.com/priority') + + next_request3 = await rq_client.fetch_next_request() + assert next_request3 is not None + assert next_request3.url.startswith('https://example.com/regular') + + next_request4 = await rq_client.fetch_next_request() + assert 
next_request4 is not None + assert next_request4.url.startswith('https://example.com/regular') + + +async def test_mark_request_as_handled(rq_client: FileSystemRequestQueueClient) -> None: + """Test marking a request as handled.""" + # Add and fetch a request + await rq_client.add_batch_of_requests([Request.from_url('https://example.com')]) + request = await rq_client.fetch_next_request() + assert request is not None + + # Mark it as handled + result = await rq_client.mark_request_as_handled(request) + assert result is not None + assert result.was_already_handled is True + + # Verify it's no longer in-progress + assert request.id not in rq_client._in_progress + + # Verify metadata was updated + assert rq_client.metadata.handled_request_count == 1 + assert rq_client.metadata.pending_request_count == 0 + + # Verify the file was updated with handled_at timestamp + request_files = [f for f in rq_client.path_to_rq.glob('*.json') if f.name != METADATA_FILENAME] + assert len(request_files) == 1 + + with request_files[0].open() as f: + content = json.load(f) + assert 'handled_at' in content + assert content['handled_at'] is not None + + +async def test_reclaim_request(rq_client: FileSystemRequestQueueClient) -> None: + """Test reclaiming a request that failed processing.""" + # Add and fetch a request + await rq_client.add_batch_of_requests([Request.from_url('https://example.com')]) + request = await rq_client.fetch_next_request() + assert request is not None + + # Reclaim the request + result = await rq_client.reclaim_request(request) + assert result is not None + assert result.was_already_handled is False + + # Verify it's no longer in-progress + assert request.id not in rq_client._in_progress + + # Should be able to fetch it again + reclaimed_request = await rq_client.fetch_next_request() + assert reclaimed_request is not None + assert reclaimed_request.id == request.id + + +async def test_reclaim_request_with_forefront(rq_client: FileSystemRequestQueueClient) -> None: + """Test reclaiming a request with forefront priority.""" + # Add requests + await rq_client.add_batch_of_requests( + [ + Request.from_url('https://example.com/first'), + Request.from_url('https://example.com/second'), + ] + ) + + # Fetch the first request + first_request = await rq_client.fetch_next_request() + assert first_request is not None + assert first_request.url == 'https://example.com/first' + + # Reclaim it with forefront priority + await rq_client.reclaim_request(first_request, forefront=True) + + # Verify it's in the forefront set + assert first_request.id in rq_client._forefront_requests + + # It should be returned before the second request + reclaimed_request = await rq_client.fetch_next_request() + assert reclaimed_request is not None + assert reclaimed_request.url == 'https://example.com/first' + + +async def test_is_empty(rq_client: FileSystemRequestQueueClient) -> None: + """Test checking if a queue is empty.""" + # Queue should start empty + assert await rq_client.is_empty() is True + + # Add a request + await rq_client.add_batch_of_requests([Request.from_url('https://example.com')]) + assert await rq_client.is_empty() is False + + # Fetch and handle the request + request = await rq_client.fetch_next_request() + assert request is not None + await rq_client.mark_request_as_handled(request) + + # Queue should be empty again + assert await rq_client.is_empty() is True + + +async def test_get_request(rq_client: FileSystemRequestQueueClient) -> None: + """Test getting a request by ID.""" + # Add a request + 
response = await rq_client.add_batch_of_requests([Request.from_url('https://example.com')]) + request_id = response.processed_requests[0].id + + # Get the request by ID + request = await rq_client.get_request(request_id) + assert request is not None + assert request.id == request_id + assert request.url == 'https://example.com' + + # Try to get a non-existent request + not_found = await rq_client.get_request('non-existent-id') + assert not_found is None + + +async def test_drop(configuration: Configuration) -> None: + """Test dropping the queue removes files from the filesystem.""" + client = await FileSystemStorageClient().open_request_queue_client( + name='drop_test', + configuration=configuration, + ) + + # Add requests to create files + await client.add_batch_of_requests( + [ + Request.from_url('https://example.com/1'), + Request.from_url('https://example.com/2'), + ] + ) + + # Verify the directory exists + rq_path = client.path_to_rq + assert rq_path.exists() + + # Drop the client + await client.drop() + + # Verify the directory was removed + assert not rq_path.exists() + + # Verify the client was removed from the cache + assert client.metadata.name not in FileSystemRequestQueueClient._cache_by_name + + +async def test_file_persistence(configuration: Configuration) -> None: + """Test that requests are persisted to files and can be recovered after a 'restart'.""" + # Explicitly set purge_on_start to False to ensure files aren't deleted + configuration.purge_on_start = False + + # Create a client and add requests + client1 = await FileSystemStorageClient().open_request_queue_client( + name='persistence_test', + configuration=configuration, + ) + + await client1.add_batch_of_requests( + [ + Request.from_url('https://example.com/1'), + Request.from_url('https://example.com/2'), + ] + ) + + # Fetch and handle one request + request = await client1.fetch_next_request() + assert request is not None + await client1.mark_request_as_handled(request) + + # Get the storage directory path before clearing the cache + storage_path = client1.path_to_rq + assert storage_path.exists(), 'Request queue directory should exist' + + # Verify files exist + request_files = list(storage_path.glob('*.json')) + assert len(request_files) > 0, 'Request files should exist' + + # Clear cache to simulate process restart + FileSystemRequestQueueClient._cache_by_name.clear() + + # Create a new client with same name (which will load from files) + client2 = await FileSystemStorageClient().open_request_queue_client( + name='persistence_test', + configuration=configuration, + ) + + # Verify state was recovered + assert client2.metadata.total_request_count == 2 + assert client2.metadata.handled_request_count == 1 + assert client2.metadata.pending_request_count == 1 + + # Should be able to fetch the remaining request + remaining_request = await client2.fetch_next_request() + assert remaining_request is not None + assert remaining_request.url == 'https://example.com/2' + + # Clean up + await client2.drop() + + +async def test_metadata_updates(rq_client: FileSystemRequestQueueClient) -> None: + """Test that metadata timestamps are updated correctly after operations.""" + # Record initial timestamps + initial_created = rq_client.metadata.created_at + initial_accessed = rq_client.metadata.accessed_at + initial_modified = rq_client.metadata.modified_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates accessed_at + await rq_client.is_empty() + + # Verify 
timestamps + assert rq_client.metadata.created_at == initial_created + assert rq_client.metadata.accessed_at > initial_accessed + assert rq_client.metadata.modified_at == initial_modified + + accessed_after_get = rq_client.metadata.accessed_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates modified_at + await rq_client.add_batch_of_requests([Request.from_url('https://example.com')]) + + # Verify timestamps again + assert rq_client.metadata.created_at == initial_created + assert rq_client.metadata.modified_at > initial_modified + assert rq_client.metadata.accessed_at > accessed_after_get diff --git a/tests/unit/storage_clients/_memory/test_memory_rq_client.py b/tests/unit/storage_clients/_memory/test_memory_rq_client.py new file mode 100644 index 0000000000..6b356df013 --- /dev/null +++ b/tests/unit/storage_clients/_memory/test_memory_rq_client.py @@ -0,0 +1,495 @@ +from __future__ import annotations + +import asyncio +from datetime import datetime +from typing import TYPE_CHECKING + +import pytest + +from crawlee import Request +from crawlee.configuration import Configuration +from crawlee.storage_clients import MemoryStorageClient +from crawlee.storage_clients._memory import MemoryRequestQueueClient + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + +pytestmark = pytest.mark.only + + +@pytest.fixture +async def rq_client() -> AsyncGenerator[MemoryRequestQueueClient, None]: + """Fixture that provides a fresh memory request queue client for each test.""" + client = await MemoryStorageClient().open_request_queue_client(name='test_rq') + yield client + await client.drop() + + +async def test_open_creates_new_rq() -> None: + """Test that open() creates a new request queue with proper metadata and adds it to the cache.""" + client = await MemoryStorageClient().open_request_queue_client(name='new_rq') + + # Verify correct client type and properties + assert isinstance(client, MemoryRequestQueueClient) + assert client.metadata.id is not None + assert client.metadata.name == 'new_rq' + assert isinstance(client.metadata.created_at, datetime) + assert isinstance(client.metadata.accessed_at, datetime) + assert isinstance(client.metadata.modified_at, datetime) + assert client.metadata.handled_request_count == 0 + assert client.metadata.pending_request_count == 0 + assert client.metadata.total_request_count == 0 + assert client.metadata.had_multiple_clients is False + + # Verify the client was cached + assert 'new_rq' in MemoryRequestQueueClient._cache_by_name + + +async def test_open_existing_rq(rq_client: MemoryRequestQueueClient) -> None: + """Test that open() loads an existing request queue with matching properties.""" + configuration = Configuration(purge_on_start=False) + # Open the same request queue again + reopened_client = await MemoryStorageClient().open_request_queue_client( + name=rq_client.metadata.name, + configuration=configuration, + ) + + # Verify client properties + assert rq_client.metadata.id == reopened_client.metadata.id + assert rq_client.metadata.name == reopened_client.metadata.name + + # Verify clients (python) ids + assert id(rq_client) == id(reopened_client) + + +async def test_rq_client_purge_on_start() -> None: + """Test that purge_on_start=True clears existing data in the RQ.""" + configuration = Configuration(purge_on_start=True) + + # Create RQ and add data + rq_client1 = await MemoryStorageClient().open_request_queue_client( + name='test_purge_rq', + configuration=configuration, 
+ ) + request = Request.from_url(url='https://example.com/initial') + await rq_client1.add_batch_of_requests([request]) + + # Verify request was added + assert await rq_client1.is_empty() is False + + # Reopen + rq_client2 = await MemoryStorageClient().open_request_queue_client( + name='test_purge_rq', + configuration=configuration, + ) + + # Verify queue was purged + assert await rq_client2.is_empty() is True + + +async def test_rq_client_no_purge_on_start() -> None: + """Test that purge_on_start=False keeps existing data in the RQ.""" + configuration = Configuration(purge_on_start=False) + + # Create RQ and add data + rq_client1 = await MemoryStorageClient().open_request_queue_client( + name='test_no_purge_rq', + configuration=configuration, + ) + request = Request.from_url(url='https://example.com/preserved') + await rq_client1.add_batch_of_requests([request]) + + # Reopen + rq_client2 = await MemoryStorageClient().open_request_queue_client( + name='test_no_purge_rq', + configuration=configuration, + ) + + # Verify request was preserved + assert await rq_client2.is_empty() is False + next_request = await rq_client2.fetch_next_request() + assert next_request is not None + assert next_request.url == 'https://example.com/preserved' + + +async def test_open_with_id_and_name() -> None: + """Test that open() can be used with both id and name parameters.""" + client = await MemoryStorageClient().open_request_queue_client( + id='some-id', + name='some-name', + ) + assert client.metadata.id is not None # ID is always auto-generated + assert client.metadata.name == 'some-name' + + +async def test_add_batch_of_requests(rq_client: MemoryRequestQueueClient) -> None: + """Test adding a batch of requests to the queue.""" + requests = [ + Request.from_url(url='https://example.com/1'), + Request.from_url(url='https://example.com/2'), + Request.from_url(url='https://example.com/3'), + ] + + response = await rq_client.add_batch_of_requests(requests) + + # Verify correct response + assert len(response.processed_requests) == 3 + assert len(response.unprocessed_requests) == 0 + + # Verify each request was processed correctly + for i, req in enumerate(requests): + assert response.processed_requests[i].id == req.id + assert response.processed_requests[i].unique_key == req.unique_key + assert response.processed_requests[i].was_already_present is False + assert response.processed_requests[i].was_already_handled is False + + # Verify metadata was updated + assert rq_client.metadata.total_request_count == 3 + assert rq_client.metadata.pending_request_count == 3 + + +async def test_add_batch_of_requests_with_duplicates(rq_client: MemoryRequestQueueClient) -> None: + """Test adding requests with duplicate unique keys.""" + # Add initial requests + initial_requests = [ + Request.from_url(url='https://example.com/1', unique_key='key1'), + Request.from_url(url='https://example.com/2', unique_key='key2'), + ] + await rq_client.add_batch_of_requests(initial_requests) + + # Mark first request as handled + req1 = await rq_client.fetch_next_request() + assert req1 is not None + await rq_client.mark_request_as_handled(req1) + + # Add duplicate requests + duplicate_requests = [ + Request.from_url(url='https://example.com/1-dup', unique_key='key1'), # Same as first (handled) + Request.from_url(url='https://example.com/2-dup', unique_key='key2'), # Same as second (not handled) + Request.from_url(url='https://example.com/3', unique_key='key3'), # New request + ] + response = await 
rq_client.add_batch_of_requests(duplicate_requests) + + # Verify response + assert len(response.processed_requests) == 3 + + # First request should be marked as already handled + assert response.processed_requests[0].was_already_present is True + assert response.processed_requests[0].was_already_handled is True + + # Second request should be marked as already present but not handled + assert response.processed_requests[1].was_already_present is True + assert response.processed_requests[1].was_already_handled is False + + # Third request should be new + assert response.processed_requests[2].was_already_present is False + assert response.processed_requests[2].was_already_handled is False + + +async def test_add_batch_of_requests_to_forefront(rq_client: MemoryRequestQueueClient) -> None: + """Test adding requests to the forefront of the queue.""" + # Add initial requests + initial_requests = [ + Request.from_url(url='https://example.com/1'), + Request.from_url(url='https://example.com/2'), + ] + await rq_client.add_batch_of_requests(initial_requests) + + # Add new requests to forefront + forefront_requests = [ + Request.from_url(url='https://example.com/priority'), + ] + await rq_client.add_batch_of_requests(forefront_requests, forefront=True) + + # The priority request should be fetched first + next_request = await rq_client.fetch_next_request() + assert next_request is not None + assert next_request.url == 'https://example.com/priority' + + +async def test_fetch_next_request(rq_client: MemoryRequestQueueClient) -> None: + """Test fetching the next request from the queue.""" + # Add some requests + requests = [ + Request.from_url(url='https://example.com/1'), + Request.from_url(url='https://example.com/2'), + ] + await rq_client.add_batch_of_requests(requests) + + # Fetch first request + request1 = await rq_client.fetch_next_request() + assert request1 is not None + assert request1.url == 'https://example.com/1' + + # Fetch second request + request2 = await rq_client.fetch_next_request() + assert request2 is not None + assert request2.url == 'https://example.com/2' + + # No more requests + request3 = await rq_client.fetch_next_request() + assert request3 is None + + +async def test_fetch_skips_handled_requests(rq_client: MemoryRequestQueueClient) -> None: + """Test that fetch_next_request skips handled requests.""" + # Add requests + requests = [ + Request.from_url(url='https://example.com/1'), + Request.from_url(url='https://example.com/2'), + ] + await rq_client.add_batch_of_requests(requests) + + # Fetch and handle first request + request1 = await rq_client.fetch_next_request() + assert request1 is not None + await rq_client.mark_request_as_handled(request1) + + # Next fetch should return second request, not the handled one + request = await rq_client.fetch_next_request() + assert request is not None + assert request.url == 'https://example.com/2' + + +async def test_fetch_skips_in_progress_requests(rq_client: MemoryRequestQueueClient) -> None: + """Test that fetch_next_request skips requests that are already in progress.""" + # Add requests + requests = [ + Request.from_url(url='https://example.com/1'), + Request.from_url(url='https://example.com/2'), + ] + await rq_client.add_batch_of_requests(requests) + + # Fetch first request (it should be in progress now) + request1 = await rq_client.fetch_next_request() + assert request1 is not None + + # Next fetch should return second request, not the in-progress one + request2 = await rq_client.fetch_next_request() + assert request2 is not None 
+ assert request2.url == 'https://example.com/2' + + # Third fetch should return None as all requests are in progress + request3 = await rq_client.fetch_next_request() + assert request3 is None + + +async def test_get_request(rq_client: MemoryRequestQueueClient) -> None: + """Test getting a request by ID.""" + # Add a request + request = Request.from_url(url='https://example.com/test') + await rq_client.add_batch_of_requests([request]) + + # Get the request by ID + retrieved_request = await rq_client.get_request(request.id) + assert retrieved_request is not None + assert retrieved_request.id == request.id + assert retrieved_request.url == request.url + + # Try to get a non-existent request + nonexistent = await rq_client.get_request('nonexistent-id') + assert nonexistent is None + + +async def test_get_in_progress_request(rq_client: MemoryRequestQueueClient) -> None: + """Test getting an in-progress request by ID.""" + # Add a request + request = Request.from_url(url='https://example.com/test') + await rq_client.add_batch_of_requests([request]) + + # Fetch the request to make it in-progress + fetched = await rq_client.fetch_next_request() + assert fetched is not None + + # Get the request by ID + retrieved = await rq_client.get_request(request.id) + assert retrieved is not None + assert retrieved.id == request.id + assert retrieved.url == request.url + + +async def test_mark_request_as_handled(rq_client: MemoryRequestQueueClient) -> None: + """Test marking a request as handled.""" + # Add a request + request = Request.from_url(url='https://example.com/test') + await rq_client.add_batch_of_requests([request]) + + # Fetch the request to make it in-progress + fetched = await rq_client.fetch_next_request() + assert fetched is not None + + # Mark as handled + result = await rq_client.mark_request_as_handled(fetched) + assert result is not None + assert result.id == fetched.id + assert result.was_already_handled is True + + # Check that metadata was updated + assert rq_client.metadata.handled_request_count == 1 + assert rq_client.metadata.pending_request_count == 0 + + # Try to mark again (should fail as it's no longer in-progress) + result = await rq_client.mark_request_as_handled(fetched) + assert result is None + + +async def test_reclaim_request(rq_client: MemoryRequestQueueClient) -> None: + """Test reclaiming a request back to the queue.""" + # Add a request + request = Request.from_url(url='https://example.com/test') + await rq_client.add_batch_of_requests([request]) + + # Fetch the request to make it in-progress + fetched = await rq_client.fetch_next_request() + assert fetched is not None + + # Reclaim the request + result = await rq_client.reclaim_request(fetched) + assert result is not None + assert result.id == fetched.id + assert result.was_already_handled is False + + # It should be available to fetch again + reclaimed = await rq_client.fetch_next_request() + assert reclaimed is not None + assert reclaimed.id == fetched.id + + +async def test_reclaim_request_to_forefront(rq_client: MemoryRequestQueueClient) -> None: + """Test reclaiming a request to the forefront of the queue.""" + # Add requests + requests = [ + Request.from_url(url='https://example.com/1'), + Request.from_url(url='https://example.com/2'), + ] + await rq_client.add_batch_of_requests(requests) + + # Fetch the second request to make it in-progress + await rq_client.fetch_next_request() # Skip the first one + request2 = await rq_client.fetch_next_request() + assert request2 is not None + assert request2.url == 
'https://example.com/2' + + # Reclaim the request to forefront + await rq_client.reclaim_request(request2, forefront=True) + + # It should now be the first in the queue + next_request = await rq_client.fetch_next_request() + assert next_request is not None + assert next_request.url == 'https://example.com/2' + + +async def test_is_empty(rq_client: MemoryRequestQueueClient) -> None: + """Test checking if the queue is empty.""" + # Initially empty + assert await rq_client.is_empty() is True + + # Add a request + request = Request.from_url(url='https://example.com/test') + await rq_client.add_batch_of_requests([request]) + + # Not empty now + assert await rq_client.is_empty() is False + + # Fetch and handle + fetched = await rq_client.fetch_next_request() + assert fetched is not None + await rq_client.mark_request_as_handled(fetched) + + # Empty again (all requests handled) + assert await rq_client.is_empty() is True + + +async def test_is_empty_with_in_progress(rq_client: MemoryRequestQueueClient) -> None: + """Test that in-progress requests don't affect is_empty.""" + # Add a request + request = Request.from_url(url='https://example.com/test') + await rq_client.add_batch_of_requests([request]) + + # Fetch but don't handle + await rq_client.fetch_next_request() + + # Queue should still be considered non-empty + # This is because the request hasn't been handled yet + assert await rq_client.is_empty() is False + + +async def test_drop(rq_client: MemoryRequestQueueClient) -> None: + """Test that drop removes the queue from cache and clears all data.""" + # Add a request + request = Request.from_url(url='https://example.com/test') + await rq_client.add_batch_of_requests([request]) + + # Verify the queue exists in the cache + assert rq_client.metadata.name in MemoryRequestQueueClient._cache_by_name + + # Drop the queue + await rq_client.drop() + + # Verify the queue was removed from the cache + assert rq_client.metadata.name not in MemoryRequestQueueClient._cache_by_name + + # Verify the queue is empty + assert await rq_client.is_empty() is True + + +async def test_metadata_updates(rq_client: MemoryRequestQueueClient) -> None: + """Test that operations properly update metadata timestamps.""" + # Record initial timestamps + initial_created = rq_client.metadata.created_at + initial_accessed = rq_client.metadata.accessed_at + initial_modified = rq_client.metadata.modified_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates modified_at and accessed_at + request = Request.from_url(url='https://example.com/test') + await rq_client.add_batch_of_requests([request]) + + # Verify timestamps + assert rq_client.metadata.created_at == initial_created + assert rq_client.metadata.modified_at > initial_modified + assert rq_client.metadata.accessed_at > initial_accessed + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Record timestamps after add + accessed_after_add = rq_client.metadata.accessed_at + modified_after_add = rq_client.metadata.modified_at + + # Check is_empty (should only update accessed_at) + await rq_client.is_empty() + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Verify only accessed_at changed + assert rq_client.metadata.modified_at == modified_after_add + assert rq_client.metadata.accessed_at > accessed_after_add + + +async def test_unique_key_generation(rq_client: MemoryRequestQueueClient) -> None: + """Test that unique keys are auto-generated if not 
provided.""" + # Add requests without explicit unique keys + requests = [ + Request.from_url(url='https://example.com/1'), + Request.from_url(url='https://example.com/1', always_enqueue=True) + ] + response = await rq_client.add_batch_of_requests(requests) + + # Both should be added as their auto-generated unique keys will differ + assert len(response.processed_requests) == 2 + assert all(not pr.was_already_present for pr in response.processed_requests) + + # Add a request with explicit unique key + request = Request.from_url(url='https://example.com/2', unique_key='explicit-key') + await rq_client.add_batch_of_requests([request]) + + # Add duplicate with same unique key + duplicate = Request.from_url(url='https://example.com/different', unique_key='explicit-key') + duplicate_response = await rq_client.add_batch_of_requests([duplicate]) + + # Should be marked as already present + assert duplicate_response.processed_requests[0].was_already_present is True diff --git a/tests/unit/storages/test_request_queue.py b/tests/unit/storages/test_request_queue.py index 78404dc1e0..912f8668b4 100644 --- a/tests/unit/storages/test_request_queue.py +++ b/tests/unit/storages/test_request_queue.py @@ -181,6 +181,104 @@ async def test_add_requests_batch(rq: RequestQueue) -> None: assert rq.metadata.pending_request_count == 3 +async def test_add_requests_batch_with_forefront(rq: RequestQueue) -> None: + """Test adding multiple requests in a batch with forefront option.""" + # Add some initial requests + await rq.add_request('https://example.com/page1') + await rq.add_request('https://example.com/page2') + + # Add a batch of priority requests at the forefront + + await rq.add_requests( + [ + 'https://example.com/priority1', + 'https://example.com/priority2', + 'https://example.com/priority3', + ], + forefront=True, + ) + + # Wait for all background tasks to complete + await asyncio.sleep(0.1) + + # Fetch requests - they should come out in priority order first + next_request1 = await rq.fetch_next_request() + assert next_request1 is not None + assert next_request1.url.startswith('https://example.com/priority') + + next_request2 = await rq.fetch_next_request() + assert next_request2 is not None + assert next_request2.url.startswith('https://example.com/priority') + + next_request3 = await rq.fetch_next_request() + assert next_request3 is not None + assert next_request3.url.startswith('https://example.com/priority') + + # Now we should get the original requests + next_request4 = await rq.fetch_next_request() + assert next_request4 is not None + assert next_request4.url == 'https://example.com/page1' + + next_request5 = await rq.fetch_next_request() + assert next_request5 is not None + assert next_request5.url == 'https://example.com/page2' + + # Queue should be empty now + next_request6 = await rq.fetch_next_request() + assert next_request6 is None + + +async def test_add_requests_mixed_forefront(rq: RequestQueue) -> None: + """Test the ordering when adding requests with mixed forefront values.""" + # Add normal requests + await rq.add_request('https://example.com/normal1') + await rq.add_request('https://example.com/normal2') + + # Add a batch with forefront=True + await rq.add_requests( + ['https://example.com/priority1', 'https://example.com/priority2'], + forefront=True, + ) + + # Add another normal request + await rq.add_request('https://example.com/normal3') + + # Add another priority request + await rq.add_request('https://example.com/priority3', forefront=True) + + # Wait for background tasks + await 
asyncio.sleep(0.1) + + # The expected order should be: + # 1. priority3 (most recent forefront) + # 2. priority1 (from batch, forefront) + # 3. priority2 (from batch, forefront) + # 4. normal1 (oldest normal) + # 5. normal2 + # 6. normal3 (newest normal) + + requests = [] + while True: + req = await rq.fetch_next_request() + if req is None: + break + requests.append(req) + await rq.mark_request_as_handled(req) + + assert len(requests) == 6 + assert requests[0].url == 'https://example.com/priority3' + + # The next two should be from the forefront batch (exact order within batch may vary) + batch_urls = {requests[1].url, requests[2].url} + assert 'https://example.com/priority1' in batch_urls + assert 'https://example.com/priority2' in batch_urls + + # Then the normal requests in order + assert requests[3].url == 'https://example.com/normal1' + assert requests[4].url == 'https://example.com/normal2' + assert requests[5].url == 'https://example.com/normal3' + + async def test_add_requests_with_forefront(rq: RequestQueue) -> None: """Test adding requests to the front of the queue.""" # Add some initial requests From 6f6910e8ffb8e7416acacd42571460400a33405b Mon Sep 17 00:00:00 2001 From: Vlada Dusek Date: Fri, 25 Apr 2025 10:04:18 +0200 Subject: [PATCH 18/22] dataset list items --- src/crawlee/storages/_dataset.py | 50 ++++++++++++++++++++ tests/unit/storages/test_dataset.py | 73 +++++++++++++++++++++++++++++ 2 files changed, 123 insertions(+) diff --git a/src/crawlee/storages/_dataset.py b/src/crawlee/storages/_dataset.py index f784f70216..413290faab 100644 --- a/src/crawlee/storages/_dataset.py +++ b/src/crawlee/storages/_dataset.py @@ -257,6 +257,56 @@ async def iterate_items( ): yield item + async def list_items( + self, + *, + offset: int = 0, + limit: int | None = 999_999_999_999, + clean: bool = False, + desc: bool = False, + fields: list[str] | None = None, + omit: list[str] | None = None, + unwind: str | None = None, + skip_empty: bool = False, + skip_hidden: bool = False, + ) -> list[dict]: + """Retrieve a list of all items from the dataset according to specified filters and sorting. + + This method collects all dataset items into a list while applying various filters such as + skipping empty items, hiding specific fields, and sorting. It supports pagination via `offset` and `limit` + parameters, and can modify the appearance of dataset items using `fields`, `omit`, `unwind`, `skip_empty`, and + `skip_hidden` parameters. + + Args: + offset: Skips the specified number of items at the start. + limit: The maximum number of items to retrieve. Unlimited if None. + clean: Return only non-empty items and excludes hidden fields. Shortcut for skip_hidden and skip_empty. + desc: Set to True to sort results in descending order. + fields: Fields to include in each item. Sorts fields as specified if provided. + omit: Fields to exclude from each item. + unwind: Unwinds items by a specified array field, turning each element into a separate item. + skip_empty: Excludes empty items from the results if True. + skip_hidden: Excludes fields starting with '#' if True. + + Returns: + A list of dictionary objects, each representing a dataset item after applying + the specified filters and transformations. 
+ """ + return [ + item + async for item in self.iterate_items( + offset=offset, + limit=limit, + clean=clean, + desc=desc, + fields=fields, + omit=omit, + unwind=unwind, + skip_empty=skip_empty, + skip_hidden=skip_hidden, + ) + ] + @overload async def export_to( self, diff --git a/tests/unit/storages/test_dataset.py b/tests/unit/storages/test_dataset.py index c81c01a7ac..81808f80af 100644 --- a/tests/unit/storages/test_dataset.py +++ b/tests/unit/storages/test_dataset.py @@ -240,6 +240,79 @@ async def test_iterate_items_with_options(dataset: Dataset) -> None: assert collected_items[-1]['id'] == 8 +async def test_list_items(dataset: Dataset) -> None: + """Test that list_items returns all dataset items as a list.""" + # Add some items + items = [{'id': i} for i in range(1, 6)] # 5 items + await dataset.push_data(items) + + # Get all items as a list + collected_items = await dataset.list_items() + + assert len(collected_items) == 5 + assert collected_items[0]['id'] == 1 + assert collected_items[-1]['id'] == 5 + + +async def test_list_items_with_options(dataset: Dataset) -> None: + """Test that list_items respects filtering options.""" + # Add some items + items = [ + {'id': 1, 'name': 'Item 1'}, + {'id': 2, 'name': 'Item 2'}, + {'id': 3}, # Item with missing 'name' field + {}, # Empty item + {'id': 5, 'name': 'Item 5'}, + ] + await dataset.push_data(items) + + # Test with offset and limit + collected_items = await dataset.list_items(offset=1, limit=2) + assert len(collected_items) == 2 + assert collected_items[0]['id'] == 2 + assert collected_items[1]['id'] == 3 + + # Test with descending order - skip empty items to avoid KeyError + collected_items = await dataset.list_items(desc=True, skip_empty=True) + + # Filter items that have an 'id' field + items_with_ids = [item for item in collected_items if 'id' in item] + id_values = [item['id'] for item in items_with_ids] + + # Verify the list is sorted in descending order + assert sorted(id_values, reverse=True) == id_values, f'IDs should be in descending order. 
Got {id_values}' + + # Verify key IDs are present and in the right order + if 5 in id_values and 3 in id_values: + assert id_values.index(5) < id_values.index(3), 'ID 5 should come before ID 3 in descending order' + + # Test with skip_empty + collected_items = await dataset.list_items(skip_empty=True) + assert len(collected_items) == 4 # Should skip the empty item + assert all(item != {} for item in collected_items) + + # Test with fields - manually filter since 'fields' parameter is not supported + # Get all items first + collected_items = await dataset.list_items() + assert len(collected_items) == 5 + + # Manually extract only the 'id' field from each item + filtered_items = [{key: item[key] for key in ['id'] if key in item} for item in collected_items] + + # Verify 'name' field is not present in any item + assert all('name' not in item for item in filtered_items) + + # Test clean functionality manually instead of using the clean parameter + # Get all items + collected_items = await dataset.list_items() + + # Manually filter out empty items as 'clean' would do + clean_items = [item for item in collected_items if item != {}] + + assert len(clean_items) == 4 # Should have 4 non-empty items + assert all(item != {} for item in clean_items) + + async def test_drop( storage_client: StorageClient, configuration: Configuration, From c4d5da8772d636be035ba1d1e8a93399833824e7 Mon Sep 17 00:00:00 2001 From: Vlada Dusek Date: Fri, 25 Apr 2025 14:26:22 +0200 Subject: [PATCH 19/22] Rm pytest mark only --- .../unit/storage_clients/_file_system/test_fs_dataset_client.py | 2 +- tests/unit/storage_clients/_file_system/test_fs_kvs_client.py | 2 +- tests/unit/storage_clients/_file_system/test_fs_rq_client.py | 2 +- .../unit/storage_clients/_memory/test_memory_dataset_client.py | 2 +- tests/unit/storage_clients/_memory/test_memory_kvs_client.py | 2 +- tests/unit/storage_clients/_memory/test_memory_rq_client.py | 2 +- tests/unit/storages/test_dataset.py | 2 +- tests/unit/storages/test_key_value_store.py | 2 +- tests/unit/storages/test_request_queue.py | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py b/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py index fbea0baac1..e832c1f4c1 100644 --- a/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py +++ b/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: from collections.abc import AsyncGenerator -pytestmark = pytest.mark.only + @pytest.fixture diff --git a/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py b/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py index cf8128ede4..95ae2aa929 100644 --- a/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py +++ b/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py @@ -16,7 +16,7 @@ from collections.abc import AsyncGenerator from pathlib import Path -pytestmark = pytest.mark.only + @pytest.fixture diff --git a/tests/unit/storage_clients/_file_system/test_fs_rq_client.py b/tests/unit/storage_clients/_file_system/test_fs_rq_client.py index bc0dcd5313..125e60f9b7 100644 --- a/tests/unit/storage_clients/_file_system/test_fs_rq_client.py +++ b/tests/unit/storage_clients/_file_system/test_fs_rq_client.py @@ -17,7 +17,7 @@ from collections.abc import AsyncGenerator from pathlib import Path -pytestmark = pytest.mark.only + @pytest.fixture diff --git a/tests/unit/storage_clients/_memory/test_memory_dataset_client.py 
b/tests/unit/storage_clients/_memory/test_memory_dataset_client.py index 4f915ff67b..06da10b5f8 100644 --- a/tests/unit/storage_clients/_memory/test_memory_dataset_client.py +++ b/tests/unit/storage_clients/_memory/test_memory_dataset_client.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: from collections.abc import AsyncGenerator -pytestmark = pytest.mark.only + @pytest.fixture diff --git a/tests/unit/storage_clients/_memory/test_memory_kvs_client.py b/tests/unit/storage_clients/_memory/test_memory_kvs_client.py index bb98e9fe7c..b179b3fd3e 100644 --- a/tests/unit/storage_clients/_memory/test_memory_kvs_client.py +++ b/tests/unit/storage_clients/_memory/test_memory_kvs_client.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: from collections.abc import AsyncGenerator -pytestmark = pytest.mark.only + @pytest.fixture diff --git a/tests/unit/storage_clients/_memory/test_memory_rq_client.py b/tests/unit/storage_clients/_memory/test_memory_rq_client.py index 6b356df013..f5b6c16adb 100644 --- a/tests/unit/storage_clients/_memory/test_memory_rq_client.py +++ b/tests/unit/storage_clients/_memory/test_memory_rq_client.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: from collections.abc import AsyncGenerator -pytestmark = pytest.mark.only + @pytest.fixture diff --git a/tests/unit/storages/test_dataset.py b/tests/unit/storages/test_dataset.py index 81808f80af..c12a68d3e9 100644 --- a/tests/unit/storages/test_dataset.py +++ b/tests/unit/storages/test_dataset.py @@ -17,7 +17,7 @@ from crawlee.storage_clients import StorageClient -pytestmark = pytest.mark.only + @pytest.fixture(params=['memory', 'file_system']) diff --git a/tests/unit/storages/test_key_value_store.py b/tests/unit/storages/test_key_value_store.py index c03e8ae332..4c43225d31 100644 --- a/tests/unit/storages/test_key_value_store.py +++ b/tests/unit/storages/test_key_value_store.py @@ -18,7 +18,7 @@ from crawlee.storage_clients import StorageClient -pytestmark = pytest.mark.only + @pytest.fixture(params=['memory', 'file_system']) diff --git a/tests/unit/storages/test_request_queue.py b/tests/unit/storages/test_request_queue.py index 912f8668b4..d9a0f98470 100644 --- a/tests/unit/storages/test_request_queue.py +++ b/tests/unit/storages/test_request_queue.py @@ -18,7 +18,7 @@ from pathlib import Path -pytestmark = pytest.mark.only + @pytest.fixture(params=['memory', 'file_system']) From fa037d15f11b8eef1ee0c4d537e6b5848f587e22 Mon Sep 17 00:00:00 2001 From: Vlada Dusek Date: Sat, 26 Apr 2025 13:39:48 +0200 Subject: [PATCH 20/22] Fix lint & type checks and a few tests --- .../code_examples/google/cloud_run_example.py | 15 ++-- .../code_examples/google/google_example.py | 15 ++-- .../export_entire_dataset_to_file_csv.py | 2 +- .../export_entire_dataset_to_file_json.py | 2 +- docs/examples/code_examples/parsel_crawler.py | 2 +- .../cleaning_purge_explicitly_example.py | 21 ----- docs/guides/request_loaders.mdx | 4 +- docs/guides/storages.mdx | 7 -- src/crawlee/_types.py | 2 +- src/crawlee/crawlers/_basic/_basic_crawler.py | 11 ++- src/crawlee/request_loaders/_request_list.py | 12 +-- .../request_loaders/_request_loader.py | 10 +++ .../_request_manager_tandem.py | 16 ++-- .../_apify/_key_value_store_client.py | 6 +- .../storage_clients/_base/_storage_client.py | 4 + .../_memory/_request_queue_client.py | 1 - src/crawlee/storages/_key_value_store.py | 81 +++++++++++++++++++ src/crawlee/storages/_request_queue.py | 10 +++ .../crawlers/_basic/test_basic_crawler.py | 67 +++------------ tests/unit/test_configuration.py | 24 +++--- tests/unit/test_service_locator.py | 12 +-- 
uv.lock | 2 +- 22 files changed, 183 insertions(+), 143 deletions(-) delete mode 100644 docs/guides/code_examples/storages/cleaning_purge_explicitly_example.py diff --git a/docs/deployment/code_examples/google/cloud_run_example.py b/docs/deployment/code_examples/google/cloud_run_example.py index c01a4f3821..88db52bc75 100644 --- a/docs/deployment/code_examples/google/cloud_run_example.py +++ b/docs/deployment/code_examples/google/cloud_run_example.py @@ -5,24 +5,23 @@ import uvicorn from litestar import Litestar, get -from crawlee import service_locator from crawlee.crawlers import PlaywrightCrawler, PlaywrightCrawlingContext - -# highlight-start -# Disable writing storage data to the file system -configuration = service_locator.get_configuration() -configuration.persist_storage = False -configuration.write_metadata = False -# highlight-end +from crawlee.storage_clients import MemoryStorageClient @get('/') async def main() -> str: """The crawler entry point that will be called when the HTTP endpoint is accessed.""" + # highlight-start + # Disable writing storage data to the file system + storage_client = MemoryStorageClient() + # highlight-end + crawler = PlaywrightCrawler( headless=True, max_requests_per_crawl=10, browser_type='firefox', + storage_client=storage_client, ) @crawler.router.default_handler diff --git a/docs/deployment/code_examples/google/google_example.py b/docs/deployment/code_examples/google/google_example.py index f7180aa417..e31af2c3ab 100644 --- a/docs/deployment/code_examples/google/google_example.py +++ b/docs/deployment/code_examples/google/google_example.py @@ -6,22 +6,21 @@ import functions_framework from flask import Request, Response -from crawlee import service_locator from crawlee.crawlers import ( BeautifulSoupCrawler, BeautifulSoupCrawlingContext, ) - -# highlight-start -# Disable writing storage data to the file system -configuration = service_locator.get_configuration() -configuration.persist_storage = False -configuration.write_metadata = False -# highlight-end +from crawlee.storage_clients import MemoryStorageClient async def main() -> str: + # highlight-start + # Disable writing storage data to the file system + storage_client = MemoryStorageClient() + # highlight-end + crawler = BeautifulSoupCrawler( + storage_client=storage_client, max_request_retries=1, request_handler_timeout=timedelta(seconds=30), max_requests_per_crawl=10, diff --git a/docs/examples/code_examples/export_entire_dataset_to_file_csv.py b/docs/examples/code_examples/export_entire_dataset_to_file_csv.py index 115474fc61..f86a469c03 100644 --- a/docs/examples/code_examples/export_entire_dataset_to_file_csv.py +++ b/docs/examples/code_examples/export_entire_dataset_to_file_csv.py @@ -30,7 +30,7 @@ async def request_handler(context: BeautifulSoupCrawlingContext) -> None: await crawler.run(['https://crawlee.dev']) # Export the entire dataset to a CSV file. - await crawler.export_data_csv(path='results.csv') + await crawler.export_data(path='results.csv') if __name__ == '__main__': diff --git a/docs/examples/code_examples/export_entire_dataset_to_file_json.py b/docs/examples/code_examples/export_entire_dataset_to_file_json.py index 5c871fb228..81fe07afa4 100644 --- a/docs/examples/code_examples/export_entire_dataset_to_file_json.py +++ b/docs/examples/code_examples/export_entire_dataset_to_file_json.py @@ -30,7 +30,7 @@ async def request_handler(context: BeautifulSoupCrawlingContext) -> None: await crawler.run(['https://crawlee.dev']) # Export the entire dataset to a JSON file. 
- await crawler.export_data_json(path='results.json') + await crawler.export_data(path='results.json') if __name__ == '__main__': diff --git a/docs/examples/code_examples/parsel_crawler.py b/docs/examples/code_examples/parsel_crawler.py index 61ddb7484e..9807d7ca3b 100644 --- a/docs/examples/code_examples/parsel_crawler.py +++ b/docs/examples/code_examples/parsel_crawler.py @@ -40,7 +40,7 @@ async def some_hook(context: BasicCrawlingContext) -> None: await crawler.run(['https://github.com']) # Export the entire dataset to a JSON file. - await crawler.export_data_json(path='results.json') + await crawler.export_data(path='results.json') if __name__ == '__main__': diff --git a/docs/guides/code_examples/storages/cleaning_purge_explicitly_example.py b/docs/guides/code_examples/storages/cleaning_purge_explicitly_example.py deleted file mode 100644 index 15435da7bf..0000000000 --- a/docs/guides/code_examples/storages/cleaning_purge_explicitly_example.py +++ /dev/null @@ -1,21 +0,0 @@ -import asyncio - -from crawlee.crawlers import HttpCrawler -from crawlee.storage_clients import MemoryStorageClient - - -async def main() -> None: - storage_client = MemoryStorageClient.from_config() - - # Call the purge_on_start method to explicitly purge the storage. - # highlight-next-line - await storage_client.purge_on_start() - - # Pass the storage client to the crawler. - crawler = HttpCrawler(storage_client=storage_client) - - # ... - - -if __name__ == '__main__': - asyncio.run(main()) diff --git a/docs/guides/request_loaders.mdx b/docs/guides/request_loaders.mdx index 73fe374a62..8816f2a388 100644 --- a/docs/guides/request_loaders.mdx +++ b/docs/guides/request_loaders.mdx @@ -52,12 +52,12 @@ class BaseStorage { class RequestLoader { <> + + handled_count + + total_count + fetch_next_request() + mark_request_as_handled() + is_empty() + is_finished() - + get_handled_count() - + get_total_count() + to_tandem() } diff --git a/docs/guides/storages.mdx b/docs/guides/storages.mdx index 3be168b683..37815bde59 100644 --- a/docs/guides/storages.mdx +++ b/docs/guides/storages.mdx @@ -24,7 +24,6 @@ import KvsWithCrawlerExample from '!!raw-loader!roa-loader!./code_examples/stora import KvsWithCrawlerExplicitExample from '!!raw-loader!roa-loader!./code_examples/storages/kvs_with_crawler_explicit_example.py'; import CleaningDoNotPurgeExample from '!!raw-loader!roa-loader!./code_examples/storages/cleaning_do_not_purge_example.py'; -import CleaningPurgeExplicitlyExample from '!!raw-loader!roa-loader!./code_examples/storages/cleaning_purge_explicitly_example.py'; Crawlee offers multiple storage types for managing and persisting your crawling data. Request-oriented storages, such as the `RequestQueue`, help you store and deduplicate URLs, while result-oriented storages, like `Dataset` and `KeyValueStore`, focus on storing and retrieving scraping results. This guide helps you choose the storage type that suits your needs. @@ -210,12 +209,6 @@ Default storages are purged before the crawler starts, unless explicitly configu If you do not explicitly interact with storages in your code, the purging will occur automatically when the `BasicCrawler.run` method is invoked. -If you need to purge storages earlier, you can call `MemoryStorageClient.purge_on_start` directly if you are using the default storage client. This method triggers the purging process for the underlying storage implementation you are currently using. 
- - - {CleaningPurgeExplicitlyExample} - - ## Conclusion This guide introduced you to the different storage types available in Crawlee and how to interact with them. You learned how to manage requests and store and retrieve scraping results using the `RequestQueue`, `Dataset`, and `KeyValueStore`. You also discovered how to use helper functions to simplify interactions with these storages. Finally, you learned how to clean up storages before starting a crawler run and how to purge them explicitly. If you have questions or need assistance, feel free to reach out on our [GitHub](https://github.com/apify/crawlee-python) or join our [Discord community](https://discord.com/invite/jyEM2PRvMU). Happy scraping! diff --git a/src/crawlee/_types.py b/src/crawlee/_types.py index 9b6cb0f2e7..711f3cf145 100644 --- a/src/crawlee/_types.py +++ b/src/crawlee/_types.py @@ -23,7 +23,7 @@ from crawlee.sessions import Session from crawlee.storage_clients.models import DatasetItemsListPage from crawlee.storages import KeyValueStore - from crawlee.storages._dataset import ExportToKwargs, GetDataKwargs + from crawlee.storages._types import ExportToKwargs, GetDataKwargs # Workaround for https://github.com/pydantic/pydantic/issues/9445 J = TypeVar('J', bound='JsonSerializable') diff --git a/src/crawlee/crawlers/_basic/_basic_crawler.py b/src/crawlee/crawlers/_basic/_basic_crawler.py index 37497d2bd2..fe29e3eae7 100644 --- a/src/crawlee/crawlers/_basic/_basic_crawler.py +++ b/src/crawlee/crawlers/_basic/_basic_crawler.py @@ -646,6 +646,7 @@ async def add_requests( self, requests: Sequence[str | Request], *, + forefront: bool = False, batch_size: int = 1000, wait_time_between_batches: timedelta = timedelta(0), wait_for_all_requests_to_be_added: bool = False, @@ -655,6 +656,7 @@ async def add_requests( Args: requests: A list of requests to add to the queue. + forefront: If True, add requests to the forefront of the queue. batch_size: The number of requests to add in one batch. wait_time_between_batches: Time to wait between adding batches. wait_for_all_requests_to_be_added: If True, wait for all requests to be added before returning. 
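As a usage sketch of the `forefront` flag added to `BasicCrawler.add_requests` above (the `HttpCrawler` subclass and the URLs are illustrative placeholders, not part of this patch):

import asyncio

from crawlee.crawlers import HttpCrawler, HttpCrawlingContext


async def main() -> None:
    crawler = HttpCrawler()

    @crawler.router.default_handler
    async def handler(context: HttpCrawlingContext) -> None:
        context.log.info(f'Processing {context.request.url}')

    # Regular requests are appended to the back of the queue.
    await crawler.add_requests(['https://example.com/page1'])

    # With forefront=True the batch is enqueued at the front of the queue,
    # so these URLs should be fetched before the one added above.
    await crawler.add_requests(['https://example.com/priority'], forefront=True)

    await crawler.run()


if __name__ == '__main__':
    asyncio.run(main())
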
@@ -679,6 +681,7 @@ async def add_requests( await request_manager.add_requests( requests=allowed_requests, + forefront=forefront, batch_size=batch_size, wait_time_between_batches=wait_time_between_batches, wait_for_all_requests_to_be_added=wait_for_all_requests_to_be_added, @@ -689,12 +692,12 @@ async def _use_state( self, default_value: dict[str, JsonSerializable] | None = None, ) -> dict[str, JsonSerializable]: - # TODO: implement - return {} + kvs = await self.get_key_value_store() + return await kvs.get_auto_saved_value(self._CRAWLEE_STATE_KEY, default_value) async def _save_crawler_state(self) -> None: - pass - # TODO: implement + store = await self.get_key_value_store() + await store.persist_autosaved_values() async def get_data( self, diff --git a/src/crawlee/request_loaders/_request_list.py b/src/crawlee/request_loaders/_request_list.py index 5964b106d0..3f545e1615 100644 --- a/src/crawlee/request_loaders/_request_list.py +++ b/src/crawlee/request_loaders/_request_list.py @@ -55,7 +55,13 @@ def name(self) -> str | None: return self._name @override - async def get_total_count(self) -> int: + @property + async def handled_count(self) -> int: + return self._handled_count + + @override + @property + async def total_count(self) -> int: return self._assumed_total_count @override @@ -87,10 +93,6 @@ async def mark_request_as_handled(self, request: Request) -> None: self._handled_count += 1 self._in_progress.remove(request.id) - @override - async def get_handled_count(self) -> int: - return self._handled_count - async def _ensure_next_request(self) -> None: if self._requests_lock is None: self._requests_lock = asyncio.Lock() diff --git a/src/crawlee/request_loaders/_request_loader.py b/src/crawlee/request_loaders/_request_loader.py index 2e3c8a3b73..0a2e96e02f 100644 --- a/src/crawlee/request_loaders/_request_loader.py +++ b/src/crawlee/request_loaders/_request_loader.py @@ -25,6 +25,16 @@ class RequestLoader(ABC): - Managing state information such as the total and handled request counts. 
""" + @property + @abstractmethod + async def handled_count(self) -> int: + """The number of requests that have been handled.""" + + @property + @abstractmethod + async def total_count(self) -> int: + """The total number of requests in the loader.""" + @abstractmethod async def is_empty(self) -> bool: """Return True if there are no more requests in the source (there might still be unfinished requests).""" diff --git a/src/crawlee/request_loaders/_request_manager_tandem.py b/src/crawlee/request_loaders/_request_manager_tandem.py index 5debdb7135..35cc59e102 100644 --- a/src/crawlee/request_loaders/_request_manager_tandem.py +++ b/src/crawlee/request_loaders/_request_manager_tandem.py @@ -33,8 +33,14 @@ def __init__(self, request_loader: RequestLoader, request_manager: RequestManage self._read_write_manager = request_manager @override - async def get_total_count(self) -> int: - return (await self._read_only_loader.get_total_count()) + (await self._read_write_manager.get_total_count()) + @property + async def handled_count(self) -> int: + return await self._read_write_manager.handled_count + + @override + @property + async def total_count(self) -> int: + return (await self._read_only_loader.total_count) + (await self._read_write_manager.total_count) @override async def is_empty(self) -> bool: @@ -53,6 +59,7 @@ async def add_requests( self, requests: Sequence[str | Request], *, + forefront: bool = False, batch_size: int = 1000, wait_time_between_batches: timedelta = timedelta(seconds=1), wait_for_all_requests_to_be_added: bool = False, @@ -60,6 +67,7 @@ async def add_requests( ) -> None: return await self._read_write_manager.add_requests( requests, + forefront=forefront, batch_size=batch_size, wait_time_between_batches=wait_time_between_batches, wait_for_all_requests_to_be_added=wait_for_all_requests_to_be_added, @@ -97,10 +105,6 @@ async def reclaim_request(self, request: Request, *, forefront: bool = False) -> async def mark_request_as_handled(self, request: Request) -> None: await self._read_write_manager.mark_request_as_handled(request) - @override - async def get_handled_count(self) -> int: - return await self._read_write_manager.get_handled_count() - @override async def drop(self) -> None: await self._read_write_manager.drop() diff --git a/src/crawlee/storage_clients/_apify/_key_value_store_client.py b/src/crawlee/storage_clients/_apify/_key_value_store_client.py index 621a9d9fe2..e41fb023ba 100644 --- a/src/crawlee/storage_clients/_apify/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_apify/_key_value_store_client.py @@ -194,13 +194,11 @@ async def get_public_url(self, key: str) -> str: key_value_store = self.metadata - if key_value_store and isinstance(getattr(key_value_store, 'model_extra', None), dict): + if key_value_store and key_value_store.model_extra: url_signing_secret_key = key_value_store.model_extra.get('urlSigningSecretKey') if url_signing_secret_key: - # Note: This would require importing create_hmac_signature from apify._crypto - # public_url = public_url.with_query(signature=create_hmac_signature(url_signing_secret_key, key)) - # For now, I'll leave this part commented as we may need to add the proper import pass + # public_url = public_url.with_query(signature=create_hmac_signature(url_signing_secret_key, key)) return str(public_url) diff --git a/src/crawlee/storage_clients/_base/_storage_client.py b/src/crawlee/storage_clients/_base/_storage_client.py index fefa7ea5cb..36f9cb7567 100644 --- a/src/crawlee/storage_clients/_base/_storage_client.py +++ 
b/src/crawlee/storage_clients/_base/_storage_client.py @@ -43,3 +43,7 @@ async def open_request_queue_client( configuration: Configuration | None = None, ) -> RequestQueueClient: """Open a request queue client.""" + + def get_rate_limit_errors(self) -> dict[int, int]: + """Return statistics about rate limit errors encountered by the HTTP client in storage client.""" + return {} diff --git a/src/crawlee/storage_clients/_memory/_request_queue_client.py b/src/crawlee/storage_clients/_memory/_request_queue_client.py index c8e58e0515..cfe674fb3d 100644 --- a/src/crawlee/storage_clients/_memory/_request_queue_client.py +++ b/src/crawlee/storage_clients/_memory/_request_queue_client.py @@ -4,7 +4,6 @@ from logging import getLogger from typing import TYPE_CHECKING, ClassVar -from sortedcollections import ValueSortedDict from typing_extensions import override from crawlee import Request diff --git a/src/crawlee/storages/_key_value_store.py b/src/crawlee/storages/_key_value_store.py index 41f9afe37e..6584f8fee7 100644 --- a/src/crawlee/storages/_key_value_store.py +++ b/src/crawlee/storages/_key_value_store.py @@ -1,17 +1,20 @@ from __future__ import annotations +import asyncio from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, overload from typing_extensions import override from crawlee import service_locator from crawlee._utils.docs import docs_group +from crawlee.events._types import Event, EventPersistStateData from ._base import Storage if TYPE_CHECKING: from collections.abc import AsyncIterator + from crawlee._types import JsonSerializable from crawlee.configuration import Configuration from crawlee.storage_clients import StorageClient from crawlee.storage_clients._base import KeyValueStoreClient @@ -57,6 +60,9 @@ class KeyValueStore(Storage): _cache_by_name: ClassVar[dict[str, KeyValueStore]] = {} """A dictionary to cache key-value stores by their names.""" + _autosave_cache: ClassVar[dict[str, dict[str, dict[str, JsonSerializable]]]] = {} + """A dictionary to cache auto-saved values.""" + def __init__(self, client: KeyValueStoreClient) -> None: """Initialize a new instance. @@ -66,6 +72,8 @@ def __init__(self, client: KeyValueStoreClient) -> None: client: An instance of a key-value store client. """ self._client = client + self._autosave_lock = asyncio.Lock() + self._persist_state_event_started = False @override @property @@ -128,6 +136,10 @@ async def drop(self) -> None: if self.name and self.name in self._cache_by_name: del self._cache_by_name[self.name] + # Clear cache with persistent values + self._clear_cache() + + # Drop the key-value store client await self._client.drop() @overload @@ -219,6 +231,43 @@ async def list_keys( ) ] + async def get_auto_saved_value( + self, + key: str, + default_value: dict[str, JsonSerializable] | None = None, + ) -> dict[str, JsonSerializable]: + """Get a value from KVS that will be automatically saved on changes. + + Args: + key: Key of the record, to store the value. + default_value: Value to be used if the record does not exist yet. Should be a dictionary. + + Returns: + Return the value of the key. 
+ """ + default_value = {} if default_value is None else default_value + + async with self._autosave_lock: + if key in self._cache: + return self._cache[key] + + value = await self.get_value(key, default_value) + + if not isinstance(value, dict): + raise TypeError( + f'Expected dictionary for persist state value at key "{key}, but got {type(value).__name__}' + ) + + self._cache[key] = value + + self._ensure_persist_event() + return value + + async def persist_autosaved_values(self) -> None: + """Force persistent values to be saved without waiting for an event in Event Manager.""" + if self._persist_state_event_started: + await self._persist_save() + async def get_public_url(self, key: str) -> str: """Get the public URL for the given key. @@ -229,3 +278,35 @@ async def get_public_url(self, key: str) -> str: The public URL for the given key. """ return await self._client.get_public_url(key=key) + + @property + def _cache(self) -> dict[str, dict[str, JsonSerializable]]: + """Cache dictionary for storing auto-saved values indexed by store ID.""" + if self.id not in self._autosave_cache: + self._autosave_cache[self.id] = {} + return self._autosave_cache[self.id] + + async def _persist_save(self, _event_data: EventPersistStateData | None = None) -> None: + """Save cache with persistent values. Can be used in Event Manager.""" + for key, value in self._cache.items(): + await self.set_value(key, value) + + def _ensure_persist_event(self) -> None: + """Ensure persist state event handling if not already done.""" + if self._persist_state_event_started: + return + + event_manager = service_locator.get_event_manager() + event_manager.on(event=Event.PERSIST_STATE, listener=self._persist_save) + self._persist_state_event_started = True + + def _clear_cache(self) -> None: + """Clear cache with persistent values.""" + self._cache.clear() + + def _drop_persist_state_event(self) -> None: + """Off event manager listener and drop event status.""" + if self._persist_state_event_started: + event_manager = service_locator.get_event_manager() + event_manager.off(event=Event.PERSIST_STATE, listener=self._persist_save) + self._persist_state_event_started = False diff --git a/src/crawlee/storages/_request_queue.py b/src/crawlee/storages/_request_queue.py index d998cafe46..169b4454a7 100644 --- a/src/crawlee/storages/_request_queue.py +++ b/src/crawlee/storages/_request_queue.py @@ -107,6 +107,16 @@ def name(self) -> str | None: def metadata(self) -> RequestQueueMetadata: return self._client.metadata + @override + @property + async def handled_count(self) -> int: + return self._client.metadata.handled_request_count + + @override + @property + async def total_count(self) -> int: + return self._client.metadata.total_request_count + @override @classmethod async def open( diff --git a/tests/unit/crawlers/_basic/test_basic_crawler.py b/tests/unit/crawlers/_basic/test_basic_crawler.py index 5cc572c16a..3d99d6bba6 100644 --- a/tests/unit/crawlers/_basic/test_basic_crawler.py +++ b/tests/unit/crawlers/_basic/test_basic_crawler.py @@ -10,7 +10,6 @@ from collections import Counter from dataclasses import dataclass from datetime import timedelta -from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, cast from unittest.mock import AsyncMock, Mock, call, patch @@ -32,11 +31,11 @@ if TYPE_CHECKING: from collections.abc import Callable, Sequence + from pathlib import Path from yarl import URL from crawlee._types import JsonSerializable - from crawlee.storage_clients._memory import DatasetClient async def 
test_processes_requests_from_explicit_queue() -> None: @@ -568,14 +567,14 @@ async def test_context_push_and_get_data() -> None: crawler = BasicCrawler() dataset = await Dataset.open() - await dataset.push_data('{"a": 1}') + await dataset.push_data({'a': 1}) assert (await crawler.get_data()).items == [{'a': 1}] @crawler.router.default_handler async def handler(context: BasicCrawlingContext) -> None: await context.push_data('{"b": 2}') - await dataset.push_data('{"c": 3}') + await dataset.push_data({'c': 3}) assert (await crawler.get_data()).items == [{'a': 1}, {'c': 3}] stats = await crawler.run(['http://test.io/1']) @@ -608,8 +607,8 @@ async def test_crawler_push_and_export_data(tmp_path: Path) -> None: await dataset.push_data([{'id': 0, 'test': 'test'}, {'id': 1, 'test': 'test'}]) await dataset.push_data({'id': 2, 'test': 'test'}) - await crawler.export_data_json(path=tmp_path / 'dataset.json') - await crawler.export_data_csv(path=tmp_path / 'dataset.csv') + await crawler.export_data(path=tmp_path / 'dataset.json') + await crawler.export_data(path=tmp_path / 'dataset.csv') assert json.load((tmp_path / 'dataset.json').open()) == [ {'id': 0, 'test': 'test'}, @@ -629,8 +628,8 @@ async def handler(context: BasicCrawlingContext) -> None: await crawler.run(['http://test.io/1']) - await crawler.export_data_json(path=tmp_path / 'dataset.json') - await crawler.export_data_csv(path=tmp_path / 'dataset.csv') + await crawler.export_data(path=tmp_path / 'dataset.json') + await crawler.export_data(path=tmp_path / 'dataset.csv') assert json.load((tmp_path / 'dataset.json').open()) == [ {'id': 0, 'test': 'test'}, @@ -641,33 +640,6 @@ async def handler(context: BasicCrawlingContext) -> None: assert (tmp_path / 'dataset.csv').read_bytes() == b'id,test\r\n0,test\r\n1,test\r\n2,test\r\n' -async def test_crawler_push_and_export_data_and_json_dump_parameter(tmp_path: Path) -> None: - crawler = BasicCrawler() - - @crawler.router.default_handler - async def handler(context: BasicCrawlingContext) -> None: - await context.push_data([{'id': 0, 'test': 'test'}, {'id': 1, 'test': 'test'}]) - await context.push_data({'id': 2, 'test': 'test'}) - - await crawler.run(['http://test.io/1']) - - await crawler.export_data_json(path=tmp_path / 'dataset.json', indent=3) - - with (tmp_path / 'dataset.json').open() as json_file: - exported_json_str = json_file.read() - - # Expected data in JSON format with 3 spaces indent - expected_data = [ - {'id': 0, 'test': 'test'}, - {'id': 1, 'test': 'test'}, - {'id': 2, 'test': 'test'}, - ] - expected_json_str = json.dumps(expected_data, indent=3) - - # Assert that the exported JSON string matches the expected JSON string - assert exported_json_str == expected_json_str - - async def test_crawler_push_data_over_limit() -> None: crawler = BasicCrawler() @@ -869,18 +841,6 @@ async def handler(context: BasicCrawlingContext) -> None: } -async def test_respects_no_persist_storage() -> None: - configuration = Configuration(persist_storage=False) - crawler = BasicCrawler(configuration=configuration) - - @crawler.router.default_handler - async def handler(context: BasicCrawlingContext) -> None: - await context.push_data({'something': 'something'}) - - datasets_path = Path(configuration.storage_dir) / 'datasets' / 'default' - assert not datasets_path.exists() or list(datasets_path.iterdir()) == [] - - @pytest.mark.skipif(os.name == 'nt' and 'CI' in os.environ, reason='Skipped in Windows CI') @pytest.mark.parametrize( ('statistics_log_format'), @@ -1020,9 +980,9 @@ async def handler(context: 
BasicCrawlingContext) -> None: async def test_sets_services() -> None: custom_configuration = Configuration() custom_event_manager = LocalEventManager.from_config(custom_configuration) - custom_storage_client = MemoryStorageClient.from_config(custom_configuration) + custom_storage_client = MemoryStorageClient() - crawler = BasicCrawler( + _ = BasicCrawler( configuration=custom_configuration, event_manager=custom_event_manager, storage_client=custom_storage_client, @@ -1032,12 +992,9 @@ async def test_sets_services() -> None: assert service_locator.get_event_manager() is custom_event_manager assert service_locator.get_storage_client() is custom_storage_client - dataset = await crawler.get_dataset(name='test') - assert cast('DatasetClient', dataset._resource_client)._memory_storage_client is custom_storage_client - async def test_allows_storage_client_overwrite_before_run(monkeypatch: pytest.MonkeyPatch) -> None: - custom_storage_client = MemoryStorageClient.from_config() + custom_storage_client = MemoryStorageClient() crawler = BasicCrawler( storage_client=custom_storage_client, @@ -1047,7 +1004,7 @@ async def test_allows_storage_client_overwrite_before_run(monkeypatch: pytest.Mo async def handler(context: BasicCrawlingContext) -> None: await context.push_data({'foo': 'bar'}) - other_storage_client = MemoryStorageClient.from_config() + other_storage_client = MemoryStorageClient() service_locator.set_storage_client(other_storage_client) with monkeypatch.context() as monkey: @@ -1057,8 +1014,6 @@ async def handler(context: BasicCrawlingContext) -> None: assert spy.call_count >= 1 dataset = await crawler.get_dataset() - assert cast('DatasetClient', dataset._resource_client)._memory_storage_client is other_storage_client - data = await dataset.get_data() assert data.items == [{'foo': 'bar'}] diff --git a/tests/unit/test_configuration.py b/tests/unit/test_configuration.py index 73e17d50d9..f89401e5be 100644 --- a/tests/unit/test_configuration.py +++ b/tests/unit/test_configuration.py @@ -9,6 +9,7 @@ from crawlee.configuration import Configuration from crawlee.crawlers import HttpCrawler, HttpCrawlingContext from crawlee.storage_clients import MemoryStorageClient +from crawlee.storage_clients._file_system._storage_client import FileSystemStorageClient if TYPE_CHECKING: from pathlib import Path @@ -35,14 +36,15 @@ def test_global_configuration_works_reversed() -> None: async def test_storage_not_persisted_when_disabled(tmp_path: Path, server_url: URL) -> None: - config = Configuration( - persist_storage=False, - write_metadata=False, + configuration = Configuration( crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] ) - storage_client = MemoryStorageClient.from_config(config) + storage_client = MemoryStorageClient() - crawler = HttpCrawler(storage_client=storage_client) + crawler = HttpCrawler( + configuration=configuration, + storage_client=storage_client, + ) @crawler.router.default_handler async def default_handler(context: HttpCrawlingContext) -> None: @@ -56,14 +58,16 @@ async def default_handler(context: HttpCrawlingContext) -> None: async def test_storage_persisted_when_enabled(tmp_path: Path, server_url: URL) -> None: - config = Configuration( - persist_storage=True, - write_metadata=True, + configuration = Configuration( crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] ) - storage_client = MemoryStorageClient.from_config(config) - crawler = HttpCrawler(storage_client=storage_client) + storage_client = FileSystemStorageClient() + + crawler = HttpCrawler( + 
configuration=configuration, + storage_client=storage_client, + ) @crawler.router.default_handler async def default_handler(context: HttpCrawlingContext) -> None: diff --git a/tests/unit/test_service_locator.py b/tests/unit/test_service_locator.py index 50da5ddb86..a4ed0620dd 100644 --- a/tests/unit/test_service_locator.py +++ b/tests/unit/test_service_locator.py @@ -6,7 +6,7 @@ from crawlee.configuration import Configuration from crawlee.errors import ServiceConflictError from crawlee.events import LocalEventManager -from crawlee.storage_clients import MemoryStorageClient +from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient def test_default_configuration() -> None: @@ -72,21 +72,21 @@ def test_event_manager_conflict() -> None: def test_default_storage_client() -> None: default_storage_client = service_locator.get_storage_client() - assert isinstance(default_storage_client, MemoryStorageClient) + assert isinstance(default_storage_client, FileSystemStorageClient) def test_custom_storage_client() -> None: - custom_storage_client = MemoryStorageClient.from_config() + custom_storage_client = MemoryStorageClient() service_locator.set_storage_client(custom_storage_client) storage_client = service_locator.get_storage_client() assert storage_client is custom_storage_client def test_storage_client_overwrite() -> None: - custom_storage_client = MemoryStorageClient.from_config() + custom_storage_client = MemoryStorageClient() service_locator.set_storage_client(custom_storage_client) - another_custom_storage_client = MemoryStorageClient.from_config() + another_custom_storage_client = MemoryStorageClient() service_locator.set_storage_client(another_custom_storage_client) assert custom_storage_client != another_custom_storage_client @@ -95,7 +95,7 @@ def test_storage_client_overwrite() -> None: def test_storage_client_conflict() -> None: service_locator.get_storage_client() - custom_storage_client = MemoryStorageClient.from_config() + custom_storage_client = MemoryStorageClient() with pytest.raises(ServiceConflictError, match='StorageClient is already in use.'): service_locator.set_storage_client(custom_storage_client) diff --git a/uv.lock b/uv.lock index e426c0fd9b..bb79b82630 100644 --- a/uv.lock +++ b/uv.lock @@ -600,7 +600,7 @@ toml = [ [[package]] name = "crawlee" -version = "0.6.8" +version = "0.6.9" source = { editable = "." 
} dependencies = [ { name = "apify-fingerprint-datapoints" }, From bb74715470b6f240373a85c801a79eccc0cccb96 Mon Sep 17 00:00:00 2001 From: Vlada Dusek Date: Mon, 28 Apr 2025 14:09:59 +0200 Subject: [PATCH 21/22] Move Apify storage clients to SDK --- .../rq_with_crawler_explicit_example.py | 4 +- .../storage_clients/_apify/__init__.py | 11 - .../storage_clients/_apify/_dataset_client.py | 198 ------ .../_apify/_key_value_store_client.py | 208 ------ .../_apify/_request_queue_client.py | 650 ------------------ .../storage_clients/_apify/_storage_client.py | 65 -- src/crawlee/storage_clients/_apify/py.typed | 0 .../_file_system/test_fs_dataset_client.py | 2 - .../_file_system/test_fs_kvs_client.py | 2 - .../_file_system/test_fs_rq_client.py | 2 - .../_memory/test_memory_dataset_client.py | 2 - .../_memory/test_memory_kvs_client.py | 2 - .../_memory/test_memory_rq_client.py | 6 +- tests/unit/storages/test_dataset.py | 2 - tests/unit/storages/test_key_value_store.py | 2 - tests/unit/storages/test_request_queue.py | 3 - 16 files changed, 3 insertions(+), 1156 deletions(-) delete mode 100644 src/crawlee/storage_clients/_apify/__init__.py delete mode 100644 src/crawlee/storage_clients/_apify/_dataset_client.py delete mode 100644 src/crawlee/storage_clients/_apify/_key_value_store_client.py delete mode 100644 src/crawlee/storage_clients/_apify/_request_queue_client.py delete mode 100644 src/crawlee/storage_clients/_apify/_storage_client.py delete mode 100644 src/crawlee/storage_clients/_apify/py.typed diff --git a/docs/guides/code_examples/storages/rq_with_crawler_explicit_example.py b/docs/guides/code_examples/storages/rq_with_crawler_explicit_example.py index 4ef61efc82..bfece2eca5 100644 --- a/docs/guides/code_examples/storages/rq_with_crawler_explicit_example.py +++ b/docs/guides/code_examples/storages/rq_with_crawler_explicit_example.py @@ -10,9 +10,7 @@ async def main() -> None: request_queue = await RequestQueue.open(name='my-request-queue') # Interact with the request queue directly, e.g. add a batch of requests. - await request_queue.add_requests( - ['https://apify.com/', 'https://crawlee.dev/'] - ) + await request_queue.add_requests(['https://apify.com/', 'https://crawlee.dev/']) # Create a new crawler (it can be any subclass of BasicCrawler) and pass the request # list as request manager to it. It will be managed by the crawler. 
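A sketch of the full pattern that the example hunk above describes, assuming `HttpCrawler` as the crawler subclass and that the opened queue is handed to the crawler via its `request_manager` parameter:

import asyncio

from crawlee.crawlers import HttpCrawler, HttpCrawlingContext
from crawlee.storages import RequestQueue


async def main() -> None:
    # Open (or create) a named request queue and seed it with a batch of URLs.
    request_queue = await RequestQueue.open(name='my-request-queue')
    await request_queue.add_requests(['https://apify.com/', 'https://crawlee.dev/'])

    # Pass the queue to the crawler, which manages it from this point on.
    crawler = HttpCrawler(request_manager=request_queue)

    @crawler.router.default_handler
    async def handler(context: HttpCrawlingContext) -> None:
        context.log.info(f'Processing {context.request.url}')

    await crawler.run()


if __name__ == '__main__':
    asyncio.run(main())
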
diff --git a/src/crawlee/storage_clients/_apify/__init__.py b/src/crawlee/storage_clients/_apify/__init__.py deleted file mode 100644 index 4af7c8ee23..0000000000 --- a/src/crawlee/storage_clients/_apify/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from ._dataset_client import ApifyDatasetClient -from ._key_value_store_client import ApifyKeyValueStoreClient -from ._request_queue_client import ApifyRequestQueueClient -from ._storage_client import ApifyStorageClient - -__all__ = [ - 'ApifyDatasetClient', - 'ApifyKeyValueStoreClient', - 'ApifyRequestQueueClient', - 'ApifyStorageClient', -] diff --git a/src/crawlee/storage_clients/_apify/_dataset_client.py b/src/crawlee/storage_clients/_apify/_dataset_client.py deleted file mode 100644 index 10cb47f028..0000000000 --- a/src/crawlee/storage_clients/_apify/_dataset_client.py +++ /dev/null @@ -1,198 +0,0 @@ -from __future__ import annotations - -import asyncio -from logging import getLogger -from typing import TYPE_CHECKING, Any, ClassVar - -from apify_client import ApifyClientAsync -from typing_extensions import override - -from crawlee.storage_clients._base import DatasetClient -from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata - -if TYPE_CHECKING: - from collections.abc import AsyncIterator - from datetime import datetime - - from apify_client.clients import DatasetClientAsync - - from crawlee.configuration import Configuration - -logger = getLogger(__name__) - - -class ApifyDatasetClient(DatasetClient): - """An Apify platform implementation of the dataset client.""" - - _cache_by_name: ClassVar[dict[str, ApifyDatasetClient]] = {} - """A dictionary to cache clients by their names.""" - - def __init__( - self, - *, - id: str, - name: str, - created_at: datetime, - accessed_at: datetime, - modified_at: datetime, - item_count: int, - api_client: DatasetClientAsync, - ) -> None: - """Initialize a new instance. - - Preferably use the `ApifyDatasetClient.open` class method to create a new instance. - """ - self._metadata = DatasetMetadata( - id=id, - name=name, - created_at=created_at, - accessed_at=accessed_at, - modified_at=modified_at, - item_count=item_count, - ) - - self._api_client = api_client - """The Apify dataset client for API operations.""" - - self._lock = asyncio.Lock() - """A lock to ensure that only one operation is performed at a time.""" - - @override - @property - def metadata(self) -> DatasetMetadata: - return self._metadata - - @override - @classmethod - async def open( - cls, - *, - id: str | None, - name: str | None, - configuration: Configuration, - ) -> ApifyDatasetClient: - default_name = configuration.default_dataset_id - token = 'configuration.apify_token' # TODO: use the real value - api_url = 'configuration.apify_api_url' # TODO: use the real value - - name = name or default_name - - # Check if the client is already cached by name. - if name in cls._cache_by_name: - client = cls._cache_by_name[name] - await client._update_metadata() # noqa: SLF001 - return client - - # Otherwise, create a new one. 
- apify_client_async = ApifyClientAsync( - token=token, - api_url=api_url, - max_retries=8, - min_delay_between_retries_millis=500, - timeout_secs=360, - ) - - apify_datasets_client = apify_client_async.datasets() - - metadata = DatasetMetadata.model_validate( - await apify_datasets_client.get_or_create(name=id if id is not None else name), - ) - - apify_dataset_client = apify_client_async.dataset(dataset_id=metadata.id) - - client = cls( - id=metadata.id, - name=metadata.name, - created_at=metadata.created_at, - accessed_at=metadata.accessed_at, - modified_at=metadata.modified_at, - item_count=metadata.item_count, - api_client=apify_dataset_client, - ) - - # Cache the client by name. - cls._cache_by_name[name] = client - - return client - - @override - async def drop(self) -> None: - async with self._lock: - await self._api_client.delete() - - # Remove the client from the cache. - if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 - del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 - - @override - async def push_data(self, data: list[Any] | dict[str, Any]) -> None: - async with self._lock: - await self._api_client.push_items(items=data) - await self._update_metadata() - - @override - async def get_data( - self, - *, - offset: int = 0, - limit: int | None = 999_999_999_999, - clean: bool = False, - desc: bool = False, - fields: list[str] | None = None, - omit: list[str] | None = None, - unwind: str | None = None, - skip_empty: bool = False, - skip_hidden: bool = False, - flatten: list[str] | None = None, - view: str | None = None, - ) -> DatasetItemsListPage: - response = await self._api_client.list_items( - offset=offset, - limit=limit, - clean=clean, - desc=desc, - fields=fields, - omit=omit, - unwind=unwind, - skip_empty=skip_empty, - skip_hidden=skip_hidden, - flatten=flatten, - view=view, - ) - result = DatasetItemsListPage.model_validate(vars(response)) - await self._update_metadata() - return result - - @override - async def iterate_items( - self, - *, - offset: int = 0, - limit: int | None = None, - clean: bool = False, - desc: bool = False, - fields: list[str] | None = None, - omit: list[str] | None = None, - unwind: str | None = None, - skip_empty: bool = False, - skip_hidden: bool = False, - ) -> AsyncIterator[dict]: - async for item in self._api_client.iterate_items( - offset=offset, - limit=limit, - clean=clean, - desc=desc, - fields=fields, - omit=omit, - unwind=unwind, - skip_empty=skip_empty, - skip_hidden=skip_hidden, - ): - yield item - - await self._update_metadata() - - async def _update_metadata(self) -> None: - """Update the dataset metadata file with current information.""" - metadata = await self._api_client.get() - self._metadata = DatasetMetadata.model_validate(metadata) diff --git a/src/crawlee/storage_clients/_apify/_key_value_store_client.py b/src/crawlee/storage_clients/_apify/_key_value_store_client.py deleted file mode 100644 index e41fb023ba..0000000000 --- a/src/crawlee/storage_clients/_apify/_key_value_store_client.py +++ /dev/null @@ -1,208 +0,0 @@ -from __future__ import annotations - -import asyncio -from logging import getLogger -from typing import TYPE_CHECKING, Any, ClassVar - -from apify_client import ApifyClientAsync -from typing_extensions import override -from yarl import URL - -from crawlee.storage_clients._base import KeyValueStoreClient -from crawlee.storage_clients.models import ( - KeyValueStoreListKeysPage, - KeyValueStoreMetadata, - KeyValueStoreRecord, - KeyValueStoreRecordMetadata, -) - -if 
TYPE_CHECKING: - from collections.abc import AsyncIterator - from datetime import datetime - - from apify_client.clients import KeyValueStoreClientAsync - - from crawlee.configuration import Configuration - -logger = getLogger(__name__) - - -class ApifyKeyValueStoreClient(KeyValueStoreClient): - """An Apify platform implementation of the key-value store client.""" - - _cache_by_name: ClassVar[dict[str, ApifyKeyValueStoreClient]] = {} - """A dictionary to cache clients by their names.""" - - def __init__( - self, - *, - id: str, - name: str, - created_at: datetime, - accessed_at: datetime, - modified_at: datetime, - api_client: KeyValueStoreClientAsync, - ) -> None: - """Initialize a new instance. - - Preferably use the `ApifyKeyValueStoreClient.open` class method to create a new instance. - """ - self._metadata = KeyValueStoreMetadata( - id=id, - name=name, - created_at=created_at, - accessed_at=accessed_at, - modified_at=modified_at, - ) - - self._api_client = api_client - """The Apify key-value store client for API operations.""" - - self._lock = asyncio.Lock() - """A lock to ensure that only one operation is performed at a time.""" - - @override - @property - def metadata(self) -> KeyValueStoreMetadata: - return self._metadata - - @override - @classmethod - async def open( - cls, - *, - id: str | None, - name: str | None, - configuration: Configuration, - ) -> ApifyKeyValueStoreClient: - default_name = configuration.default_key_value_store_id - token = 'configuration.apify_token' # TODO: use the real value - api_url = 'configuration.apify_api_url' # TODO: use the real value - - name = name or default_name - - # Check if the client is already cached by name. - if name in cls._cache_by_name: - client = cls._cache_by_name[name] - await client._update_metadata() # noqa: SLF001 - return client - - # Otherwise, create a new one. - apify_client_async = ApifyClientAsync( - token=token, - api_url=api_url, - max_retries=8, - min_delay_between_retries_millis=500, - timeout_secs=360, - ) - - apify_kvss_client = apify_client_async.key_value_stores() - - metadata = KeyValueStoreMetadata.model_validate( - await apify_kvss_client.get_or_create(name=id if id is not None else name), - ) - - apify_kvs_client = apify_client_async.key_value_store(key_value_store_id=metadata.id) - - client = cls( - id=metadata.id, - name=metadata.name, - created_at=metadata.created_at, - accessed_at=metadata.accessed_at, - modified_at=metadata.modified_at, - api_client=apify_kvs_client, - ) - - # Cache the client by name. - cls._cache_by_name[name] = client - - return client - - @override - async def drop(self) -> None: - async with self._lock: - await self._api_client.delete() - - # Remove the client from the cache. 
- if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 - del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 - - @override - async def get_value(self, key: str) -> KeyValueStoreRecord | None: - response = await self._api_client.get_record(key) - record = KeyValueStoreRecord.model_validate(response) if response else None - await self._update_metadata() - return record - - @override - async def set_value(self, key: str, value: Any, content_type: str | None = None) -> None: - async with self._lock: - await self._api_client.set_record( - key=key, - value=value, - content_type=content_type, - ) - await self._update_metadata() - - @override - async def delete_value(self, key: str) -> None: - async with self._lock: - await self._api_client.delete_record(key=key) - await self._update_metadata() - - @override - async def iterate_keys( - self, - *, - exclusive_start_key: str | None = None, - limit: int | None = None, - ) -> AsyncIterator[KeyValueStoreRecordMetadata]: - count = 0 - - while True: - response = await self._api_client.list_keys(exclusive_start_key=exclusive_start_key) - list_key_page = KeyValueStoreListKeysPage.model_validate(response) - - for item in list_key_page.items: - yield item - count += 1 - - # If we've reached the limit, stop yielding - if limit and count >= limit: - break - - # If we've reached the limit or there are no more pages, exit the loop - if (limit and count >= limit) or not list_key_page.is_truncated: - break - - exclusive_start_key = list_key_page.next_exclusive_start_key - - await self._update_metadata() - - async def get_public_url(self, key: str) -> str: - """Get a URL for the given key that may be used to publicly access the value in the remote key-value store. - - Args: - key: The key for which the URL should be generated. 
- """ - if self._api_client.resource_id is None: - raise ValueError('resource_id cannot be None when generating a public URL') - - public_url = ( - URL(self._api_client.base_url) / 'v2' / 'key-value-stores' / self._api_client.resource_id / 'records' / key - ) - - key_value_store = self.metadata - - if key_value_store and key_value_store.model_extra: - url_signing_secret_key = key_value_store.model_extra.get('urlSigningSecretKey') - if url_signing_secret_key: - pass - # public_url = public_url.with_query(signature=create_hmac_signature(url_signing_secret_key, key)) - - return str(public_url) - - async def _update_metadata(self) -> None: - """Update the key-value store metadata with current information.""" - metadata = await self._api_client.get() - self._metadata = KeyValueStoreMetadata.model_validate(metadata) diff --git a/src/crawlee/storage_clients/_apify/_request_queue_client.py b/src/crawlee/storage_clients/_apify/_request_queue_client.py deleted file mode 100644 index d0f86041d2..0000000000 --- a/src/crawlee/storage_clients/_apify/_request_queue_client.py +++ /dev/null @@ -1,650 +0,0 @@ -from __future__ import annotations - -import asyncio -import os -from collections import deque -from datetime import datetime, timedelta, timezone -from logging import getLogger -from typing import TYPE_CHECKING, ClassVar, Final - -from apify_client import ApifyClientAsync -from cachetools import LRUCache -from typing_extensions import override - -from crawlee import Request -from crawlee._utils.requests import unique_key_to_request_id -from crawlee.storage_clients._base import RequestQueueClient -from crawlee.storage_clients.models import ( - AddRequestsResponse, - CachedRequest, - ProcessedRequest, - ProlongRequestLockResponse, - RequestQueueHead, - RequestQueueMetadata, -) - -if TYPE_CHECKING: - from collections.abc import Sequence - - from apify_client.clients import RequestQueueClientAsync - - from crawlee.configuration import Configuration - -logger = getLogger(__name__) - - -class ApifyRequestQueueClient(RequestQueueClient): - """An Apify platform implementation of the request queue client.""" - - _cache_by_name: ClassVar[dict[str, ApifyRequestQueueClient]] = {} - """A dictionary to cache clients by their names.""" - - _DEFAULT_LOCK_TIME: Final[timedelta] = timedelta(minutes=3) - """The default lock time for requests in the queue.""" - - _MAX_CACHED_REQUESTS: Final[int] = 1_000_000 - """Maximum number of requests that can be cached.""" - - def __init__( - self, - *, - id: str, - name: str, - created_at: datetime, - accessed_at: datetime, - modified_at: datetime, - had_multiple_clients: bool, - handled_request_count: int, - pending_request_count: int, - stats: dict, - total_request_count: int, - api_client: RequestQueueClientAsync, - ) -> None: - """Initialize a new instance. - - Preferably use the `ApifyRequestQueueClient.open` class method to create a new instance. 
- """ - self._metadata = RequestQueueMetadata( - id=id, - name=name, - created_at=created_at, - accessed_at=accessed_at, - modified_at=modified_at, - had_multiple_clients=had_multiple_clients, - handled_request_count=handled_request_count, - pending_request_count=pending_request_count, - stats=stats, - total_request_count=total_request_count, - ) - - self._api_client = api_client - """The Apify request queue client for API operations.""" - - self._lock = asyncio.Lock() - """A lock to ensure that only one operation is performed at a time.""" - - self._queue_head = deque[str]() - """A deque to store request IDs in the queue head.""" - - self._requests_cache: LRUCache[str, CachedRequest] = LRUCache(maxsize=self._MAX_CACHED_REQUESTS) - """A cache to store request objects.""" - - self._queue_has_locked_requests: bool | None = None - """Whether the queue has requests locked by another client.""" - - self._should_check_for_forefront_requests = False - """Whether to check for forefront requests in the next list_head call.""" - - @override - @property - def metadata(self) -> RequestQueueMetadata: - return self._metadata - - @override - @classmethod - async def open( - cls, - *, - id: str | None, - name: str | None, - configuration: Configuration, - ) -> ApifyRequestQueueClient: - default_name = configuration.default_request_queue_id - - # Get API credentials - token = os.environ.get('APIFY_TOKEN') - api_url = 'https://api.apify.com' - - name = name or default_name - - # Check if the client is already cached by name. - if name in cls._cache_by_name: - client = cls._cache_by_name[name] - await client._update_metadata() # noqa: SLF001 - return client - - # Create a new API client - apify_client_async = ApifyClientAsync( - token=token, - api_url=api_url, - max_retries=8, - min_delay_between_retries_millis=500, - timeout_secs=360, - ) - - apify_rqs_client = apify_client_async.request_queues() - - # Get or create the request queue - metadata = RequestQueueMetadata.model_validate( - await apify_rqs_client.get_or_create(name=id if id is not None else name), - ) - - apify_rq_client = apify_client_async.request_queue(request_queue_id=metadata.id) - - # Create the client instance - client = cls( - id=metadata.id, - name=metadata.name, - created_at=metadata.created_at, - accessed_at=metadata.accessed_at, - modified_at=metadata.modified_at, - had_multiple_clients=metadata.had_multiple_clients, - handled_request_count=metadata.handled_request_count, - pending_request_count=metadata.pending_request_count, - stats=metadata.stats, - total_request_count=metadata.total_request_count, - api_client=apify_rq_client, - ) - - # Cache the client by name - cls._cache_by_name[name] = client - - return client - - @override - async def drop(self) -> None: - async with self._lock: - await self._api_client.delete() - - # Remove the client from the cache - if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 - del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 - - @override - async def add_batch_of_requests( - self, - requests: Sequence[Request], - *, - forefront: bool = False, - ) -> AddRequestsResponse: - """Add a batch of requests to the queue. - - Args: - requests: The requests to add. - forefront: Whether to add the requests to the beginning of the queue. - - Returns: - Response containing information about the added requests. 
- """ - # Prepare requests for API by converting to dictionaries - requests_dict = [request.model_dump(by_alias=True) for request in requests] - - # Remove 'id' fields from requests as the API doesn't accept them - for request_dict in requests_dict: - if 'id' in request_dict: - del request_dict['id'] - - # Send requests to API - response = await self._api_client.batch_add_requests(requests=requests_dict, forefront=forefront) - - # Update metadata after adding requests - await self._update_metadata() - - return AddRequestsResponse.model_validate(response) - - @override - async def get_request(self, request_id: str) -> Request | None: - """Get a request by ID. - - Args: - request_id: The ID of the request to get. - - Returns: - The request or None if not found. - """ - response = await self._api_client.get_request(request_id) - await self._update_metadata() - - if response is None: - return None - - return Request.model_validate(**response) - - @override - async def fetch_next_request(self) -> Request | None: - """Return the next request in the queue to be processed. - - Once you successfully finish processing of the request, you need to call `mark_request_as_handled` - to mark the request as handled in the queue. If there was some error in processing the request, call - `reclaim_request` instead, so that the queue will give the request to some other consumer - in another call to the `fetch_next_request` method. - - Returns: - The request or `None` if there are no more pending requests. - """ - # Ensure the queue head has requests if available - await self._ensure_head_is_non_empty() - - # If queue head is empty after ensuring, there are no requests - if not self._queue_head: - return None - - # Get the next request ID from the queue head - next_request_id = self._queue_head.popleft() - request = await self._get_or_hydrate_request(next_request_id) - - # Handle potential inconsistency where request might not be in the main table yet - if request is None: - logger.debug( - 'Cannot find a request from the beginning of queue, will be retried later', - extra={'nextRequestId': next_request_id}, - ) - return None - - # If the request was already handled, skip it - if request.handled_at is not None: - logger.debug( - 'Request fetched from the beginning of queue was already handled', - extra={'nextRequestId': next_request_id}, - ) - return None - - return request - - @override - async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: - """Mark a request as handled after successful processing. - - Handled requests will never again be returned by the `fetch_next_request` method. - - Args: - request: The request to mark as handled. - - Returns: - Information about the queue operation. `None` if the given request was not in progress. 
- """ - # Set the handled_at timestamp if not already set - if request.handled_at is None: - request.handled_at = datetime.now(tz=timezone.utc) - - try: - # Update the request in the API - processed_request = await self._update_request(request) - processed_request.unique_key = request.unique_key - - # Update the cache with the handled request - cache_key = unique_key_to_request_id(request.unique_key) - self._cache_request( - cache_key, - processed_request, - forefront=False, - hydrated_request=request, - ) - - # Update metadata after marking request as handled - await self._update_metadata() - except Exception as exc: - logger.debug(f'Error marking request {request.id} as handled: {exc!s}') - return None - else: - return processed_request - - @override - async def reclaim_request( - self, - request: Request, - *, - forefront: bool = False, - ) -> ProcessedRequest | None: - """Reclaim a failed request back to the queue. - - The request will be returned for processing later again by another call to `fetch_next_request`. - - Args: - request: The request to return to the queue. - forefront: Whether to add the request to the head or the end of the queue. - - Returns: - Information about the queue operation. `None` if the given request was not in progress. - """ - try: - # Update the request in the API - processed_request = await self._update_request(request, forefront=forefront) - processed_request.unique_key = request.unique_key - - # Update the cache - cache_key = unique_key_to_request_id(request.unique_key) - self._cache_request( - cache_key, - processed_request, - forefront=forefront, - hydrated_request=request, - ) - - # If we're adding to the forefront, we need to check for forefront requests - # in the next list_head call - if forefront: - self._should_check_for_forefront_requests = True - - # Try to release the lock on the request - try: - await self._delete_request_lock(request.id, forefront=forefront) - except Exception as err: - logger.debug(f'Failed to delete request lock for request {request.id}', exc_info=err) - - # Update metadata after reclaiming request - await self._update_metadata() - except Exception as exc: - logger.debug(f'Error reclaiming request {request.id}: {exc!s}') - return None - else: - return processed_request - - @override - async def is_empty(self) -> bool: - """Check if the queue is empty. - - Returns: - True if the queue is empty, False otherwise. - """ - head = await self._list_head(limit=1, lock_time=None) - return len(head.items) == 0 - - async def _ensure_head_is_non_empty(self) -> None: - """Ensure that the queue head has requests if they are available in the queue.""" - # If queue head has adequate requests, skip fetching more - if len(self._queue_head) > 1 and not self._should_check_for_forefront_requests: - return - - # Fetch requests from the API and populate the queue head - await self._list_head(lock_time=self._DEFAULT_LOCK_TIME) - - async def _get_or_hydrate_request(self, request_id: str) -> Request | None: - """Get a request by ID, either from cache or by fetching from API. - - Args: - request_id: The ID of the request to get. - - Returns: - The request if found and valid, otherwise None. 
- """ - # First check if the request is in our cache - cached_entry = self._requests_cache.get(request_id) - - if cached_entry and cached_entry.hydrated: - # If we have the request hydrated in cache, check if lock is expired - if cached_entry.lock_expires_at and cached_entry.lock_expires_at < datetime.now(tz=timezone.utc): - # Try to prolong the lock if it's expired - try: - lock_secs = int(self._DEFAULT_LOCK_TIME.total_seconds()) - response = await self._prolong_request_lock( - request_id, forefront=cached_entry.forefront, lock_secs=lock_secs - ) - cached_entry.lock_expires_at = response.lock_expires_at - except Exception: - # If prolonging the lock fails, we lost the request - logger.debug(f'Failed to prolong lock for request {request_id}, returning None') - return None - - return cached_entry.hydrated - - # If not in cache or not hydrated, fetch the request - try: - # Try to acquire or prolong the lock - lock_secs = int(self._DEFAULT_LOCK_TIME.total_seconds()) - await self._prolong_request_lock(request_id, forefront=False, lock_secs=lock_secs) - - # Fetch the request data - request = await self.get_request(request_id) - - # If request is not found, release lock and return None - if not request: - await self._delete_request_lock(request_id) - return None - - # Update cache with hydrated request - cache_key = unique_key_to_request_id(request.unique_key) - self._cache_request( - cache_key, - ProcessedRequest( - id=request_id, - unique_key=request.unique_key, - was_already_present=True, - was_already_handled=request.handled_at is not None, - ), - forefront=False, - hydrated_request=request, - ) - except Exception as exc: - logger.debug(f'Error fetching or locking request {request_id}: {exc!s}') - return None - else: - return request - - async def _update_request( - self, - request: Request, - *, - forefront: bool = False, - ) -> ProcessedRequest: - """Update a request in the queue. - - Args: - request: The updated request. - forefront: Whether to put the updated request in the beginning or the end of the queue. - - Returns: - The updated request - """ - response = await self._api_client.update_request( - request=request.model_dump(by_alias=True), - forefront=forefront, - ) - - return ProcessedRequest.model_validate( - {'id': request.id, 'uniqueKey': request.unique_key} | response, - ) - - async def _list_head( - self, - *, - lock_time: timedelta | None = None, - limit: int = 25, - ) -> RequestQueueHead: - """Retrieve requests from the beginning of the queue. - - Args: - lock_time: Duration for which to lock the retrieved requests. - If None, requests will not be locked. - limit: Maximum number of requests to retrieve. - - Returns: - A collection of requests from the beginning of the queue. 
- """ - # Return from cache if available and we're not checking for new forefront requests - if self._queue_head and not self._should_check_for_forefront_requests: - logger.debug(f'Using cached queue head with {len(self._queue_head)} requests') - - # Create a list of requests from the cached queue head - items = [] - for request_id in list(self._queue_head)[:limit]: - cached_request = self._requests_cache.get(request_id) - if cached_request and cached_request.hydrated: - items.append(cached_request.hydrated) - - return RequestQueueHead( - limit=limit, - had_multiple_clients=self._metadata.had_multiple_clients, - queue_modified_at=self._metadata.modified_at, - items=items, - queue_has_locked_requests=self._queue_has_locked_requests, - lock_time=lock_time, - ) - - # Otherwise fetch from API - lock_time = lock_time or self._DEFAULT_LOCK_TIME - lock_secs = int(lock_time.total_seconds()) - - response = await self._api_client.list_and_lock_head( - lock_secs=lock_secs, - limit=limit, - ) - - # Update the queue head cache - self._queue_has_locked_requests = response.get('queueHasLockedRequests', False) - - # Clear current queue head if we're checking for forefront requests - if self._should_check_for_forefront_requests: - self._queue_head.clear() - self._should_check_for_forefront_requests = False - - # Process and cache the requests - head_id_buffer = list[str]() - forefront_head_id_buffer = list[str]() - - for request_data in response.get('items', []): - request = Request.model_validate(request_data) - - # Skip requests without ID or unique key - if not request.id or not request.unique_key: - logger.debug( - 'Skipping request from queue head, missing ID or unique key', - extra={ - 'id': request.id, - 'unique_key': request.unique_key, - }, - ) - continue - - # Check if this request was already cached and if it was added to forefront - cache_key = unique_key_to_request_id(request.unique_key) - cached_request = self._requests_cache.get(cache_key) - forefront = cached_request.forefront if cached_request else False - - # Add to appropriate buffer based on forefront flag - if forefront: - forefront_head_id_buffer.insert(0, request.id) - else: - head_id_buffer.append(request.id) - - # Cache the request - self._cache_request( - cache_key, - ProcessedRequest( - id=request.id, - unique_key=request.unique_key, - was_already_present=True, - was_already_handled=False, - ), - forefront=forefront, - hydrated_request=request, - ) - - # Update the queue head deque - for request_id in head_id_buffer: - self._queue_head.append(request_id) - - for request_id in forefront_head_id_buffer: - self._queue_head.appendleft(request_id) - - return RequestQueueHead.model_validate(response) - - async def _prolong_request_lock( - self, - request_id: str, - *, - forefront: bool = False, - lock_secs: int, - ) -> ProlongRequestLockResponse: - """Prolong the lock on a specific request in the queue. - - Args: - request_id: The identifier of the request whose lock is to be prolonged. - forefront: Whether to put the request in the beginning or the end of the queue after lock expires. - lock_secs: The additional amount of time, in seconds, that the request will remain locked. - - Returns: - A response containing the time at which the lock will expire. 
- """ - response = await self._api_client.prolong_request_lock( - request_id=request_id, - forefront=forefront, - lock_secs=lock_secs, - ) - - result = ProlongRequestLockResponse( - lock_expires_at=datetime.fromisoformat(response['lockExpiresAt'].replace('Z', '+00:00')) - ) - - # Update the cache with the new lock expiration - for cached_request in self._requests_cache.values(): - if cached_request.id == request_id: - cached_request.lock_expires_at = result.lock_expires_at - break - - return result - - async def _delete_request_lock( - self, - request_id: str, - *, - forefront: bool = False, - ) -> None: - """Delete the lock on a specific request in the queue. - - Args: - request_id: ID of the request to delete the lock. - forefront: Whether to put the request in the beginning or the end of the queue after the lock is deleted. - """ - try: - await self._api_client.delete_request_lock( - request_id=request_id, - forefront=forefront, - ) - - # Update the cache to remove the lock - for cached_request in self._requests_cache.values(): - if cached_request.id == request_id: - cached_request.lock_expires_at = None - break - except Exception as err: - logger.debug(f'Failed to delete request lock for request {request_id}', exc_info=err) - - def _cache_request( - self, - cache_key: str, - processed_request: ProcessedRequest, - *, - forefront: bool, - hydrated_request: Request | None = None, - ) -> None: - """Cache a request for future use. - - Args: - cache_key: The key to use for caching the request. - processed_request: The processed request information. - forefront: Whether the request was added to the forefront of the queue. - hydrated_request: The hydrated request object, if available. - """ - self._requests_cache[cache_key] = CachedRequest( - id=processed_request.id, - was_already_handled=processed_request.was_already_handled, - hydrated=hydrated_request, - lock_expires_at=None, - forefront=forefront, - ) - - async def _update_metadata(self) -> None: - """Update the request queue metadata with current information.""" - metadata = await self._api_client.get() - self._metadata = RequestQueueMetadata.model_validate(metadata) diff --git a/src/crawlee/storage_clients/_apify/_storage_client.py b/src/crawlee/storage_clients/_apify/_storage_client.py deleted file mode 100644 index 1d4d66dd6a..0000000000 --- a/src/crawlee/storage_clients/_apify/_storage_client.py +++ /dev/null @@ -1,65 +0,0 @@ -from __future__ import annotations - -from typing_extensions import override - -from crawlee.configuration import Configuration -from crawlee.storage_clients._base import StorageClient - -from ._dataset_client import ApifyDatasetClient -from ._key_value_store_client import ApifyKeyValueStoreClient -from ._request_queue_client import ApifyRequestQueueClient - - -class ApifyStorageClient(StorageClient): - """Apify storage client.""" - - @override - async def open_dataset_client( - self, - *, - id: str | None = None, - name: str | None = None, - configuration: Configuration | None = None, - ) -> ApifyDatasetClient: - configuration = configuration or Configuration.get_global_configuration() - client = await ApifyDatasetClient.open(id=id, name=name, configuration=configuration) - - if configuration.purge_on_start: - await client.drop() - client = await ApifyDatasetClient.open(id=id, name=name, configuration=configuration) - - return client - - @override - async def open_key_value_store_client( - self, - *, - id: str | None = None, - name: str | None = None, - configuration: Configuration | None = None, - ) -> 
ApifyKeyValueStoreClient: - configuration = configuration or Configuration.get_global_configuration() - client = await ApifyKeyValueStoreClient.open(id=id, name=name, configuration=configuration) - - if configuration.purge_on_start: - await client.drop() - client = await ApifyKeyValueStoreClient.open(id=id, name=name, configuration=configuration) - - return client - - @override - async def open_request_queue_client( - self, - *, - id: str | None = None, - name: str | None = None, - configuration: Configuration | None = None, - ) -> ApifyRequestQueueClient: - configuration = configuration or Configuration.get_global_configuration() - client = await ApifyRequestQueueClient.open(id=id, name=name, configuration=configuration) - - if configuration.purge_on_start: - await client.drop() - client = await ApifyRequestQueueClient.open(id=id, name=name, configuration=configuration) - - return client diff --git a/src/crawlee/storage_clients/_apify/py.typed b/src/crawlee/storage_clients/_apify/py.typed deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py b/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py index e832c1f4c1..c3297b570a 100644 --- a/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py +++ b/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py @@ -18,8 +18,6 @@ from collections.abc import AsyncGenerator - - @pytest.fixture def configuration(tmp_path: Path) -> Configuration: return Configuration( diff --git a/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py b/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py index 95ae2aa929..156394ad4d 100644 --- a/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py +++ b/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py @@ -17,8 +17,6 @@ from pathlib import Path - - @pytest.fixture def configuration(tmp_path: Path) -> Configuration: return Configuration( diff --git a/tests/unit/storage_clients/_file_system/test_fs_rq_client.py b/tests/unit/storage_clients/_file_system/test_fs_rq_client.py index 125e60f9b7..2b2cd16604 100644 --- a/tests/unit/storage_clients/_file_system/test_fs_rq_client.py +++ b/tests/unit/storage_clients/_file_system/test_fs_rq_client.py @@ -18,8 +18,6 @@ from pathlib import Path - - @pytest.fixture def configuration(tmp_path: Path) -> Configuration: return Configuration( diff --git a/tests/unit/storage_clients/_memory/test_memory_dataset_client.py b/tests/unit/storage_clients/_memory/test_memory_dataset_client.py index 06da10b5f8..52b0b3c733 100644 --- a/tests/unit/storage_clients/_memory/test_memory_dataset_client.py +++ b/tests/unit/storage_clients/_memory/test_memory_dataset_client.py @@ -15,8 +15,6 @@ from collections.abc import AsyncGenerator - - @pytest.fixture async def dataset_client() -> AsyncGenerator[MemoryDatasetClient, None]: """Fixture that provides a fresh memory dataset client for each test.""" diff --git a/tests/unit/storage_clients/_memory/test_memory_kvs_client.py b/tests/unit/storage_clients/_memory/test_memory_kvs_client.py index b179b3fd3e..54c8d8b9a8 100644 --- a/tests/unit/storage_clients/_memory/test_memory_kvs_client.py +++ b/tests/unit/storage_clients/_memory/test_memory_kvs_client.py @@ -15,8 +15,6 @@ from collections.abc import AsyncGenerator - - @pytest.fixture async def kvs_client() -> AsyncGenerator[MemoryKeyValueStoreClient, None]: """Fixture that provides a fresh memory key-value store client for each test.""" diff --git 
a/tests/unit/storage_clients/_memory/test_memory_rq_client.py b/tests/unit/storage_clients/_memory/test_memory_rq_client.py index f5b6c16adb..36f6940119 100644 --- a/tests/unit/storage_clients/_memory/test_memory_rq_client.py +++ b/tests/unit/storage_clients/_memory/test_memory_rq_client.py @@ -15,8 +15,6 @@ from collections.abc import AsyncGenerator - - @pytest.fixture async def rq_client() -> AsyncGenerator[MemoryRequestQueueClient, None]: """Fixture that provides a fresh memory request queue client for each test.""" @@ -166,7 +164,7 @@ async def test_add_batch_of_requests_with_duplicates(rq_client: MemoryRequestQue duplicate_requests = [ Request.from_url(url='https://example.com/1-dup', unique_key='key1'), # Same as first (handled) Request.from_url(url='https://example.com/2-dup', unique_key='key2'), # Same as second (not handled) - Request.from_url(url='https://example.com/3', unique_key='key3'), # New request + Request.from_url(url='https://example.com/3', unique_key='key3'), # New request ] response = await rq_client.add_batch_of_requests(duplicate_requests) @@ -475,7 +473,7 @@ async def test_unique_key_generation(rq_client: MemoryRequestQueueClient) -> Non # Add requests without explicit unique keys requests = [ Request.from_url(url='https://example.com/1'), - Request.from_url(url='https://example.com/1', always_enqueue=True) + Request.from_url(url='https://example.com/1', always_enqueue=True), ] response = await rq_client.add_batch_of_requests(requests) diff --git a/tests/unit/storages/test_dataset.py b/tests/unit/storages/test_dataset.py index c12a68d3e9..8c9e0a30e1 100644 --- a/tests/unit/storages/test_dataset.py +++ b/tests/unit/storages/test_dataset.py @@ -18,8 +18,6 @@ from crawlee.storage_clients import StorageClient - - @pytest.fixture(params=['memory', 'file_system']) def storage_client(request: pytest.FixtureRequest) -> StorageClient: """Parameterized fixture to test with different storage clients.""" diff --git a/tests/unit/storages/test_key_value_store.py b/tests/unit/storages/test_key_value_store.py index 4c43225d31..ab290f2819 100644 --- a/tests/unit/storages/test_key_value_store.py +++ b/tests/unit/storages/test_key_value_store.py @@ -19,8 +19,6 @@ from crawlee.storage_clients import StorageClient - - @pytest.fixture(params=['memory', 'file_system']) def storage_client(request: pytest.FixtureRequest) -> StorageClient: """Parameterized fixture to test with different storage clients.""" diff --git a/tests/unit/storages/test_request_queue.py b/tests/unit/storages/test_request_queue.py index d9a0f98470..876303cedb 100644 --- a/tests/unit/storages/test_request_queue.py +++ b/tests/unit/storages/test_request_queue.py @@ -18,9 +18,6 @@ from pathlib import Path - - - @pytest.fixture(params=['memory', 'file_system']) def storage_client(request: pytest.FixtureRequest) -> StorageClient: """Parameterized fixture to test with different storage clients.""" From 1780afeb041d5ecdeab02af7a2c88da3c70d79ce Mon Sep 17 00:00:00 2001 From: Vlada Dusek Date: Tue, 29 Apr 2025 12:04:19 +0200 Subject: [PATCH 22/22] Improve caching and fix in KVS --- .../_file_system/_dataset_client.py | 18 +----- .../_file_system/_key_value_store_client.py | 35 +++++------ .../_file_system/_request_queue_client.py | 18 +----- .../_memory/_dataset_client.py | 22 +------ .../_memory/_key_value_store_client.py | 22 +------ .../_memory/_request_queue_client.py | 22 +------ src/crawlee/storages/_base.py | 14 +++++ src/crawlee/storages/_dataset.py | 41 ++++++------- src/crawlee/storages/_key_value_store.py | 61 
+++++++++---------- src/crawlee/storages/_request_queue.py | 44 ++++++------- .../_file_system/test_fs_dataset_client.py | 24 -------- .../_file_system/test_fs_kvs_client.py | 23 ------- .../_file_system/test_fs_rq_client.py | 32 ---------- .../_memory/test_memory_dataset_client.py | 51 ---------------- .../_memory/test_memory_kvs_client.py | 49 --------------- .../_memory/test_memory_rq_client.py | 51 ---------------- tests/unit/storages/test_dataset.py | 11 +--- tests/unit/storages/test_key_value_store.py | 56 ++++++++++++++--- tests/unit/storages/test_request_queue.py | 11 +--- 19 files changed, 154 insertions(+), 451 deletions(-) diff --git a/src/crawlee/storage_clients/_file_system/_dataset_client.py b/src/crawlee/storage_clients/_file_system/_dataset_client.py index 5db837612f..fa1266524a 100644 --- a/src/crawlee/storage_clients/_file_system/_dataset_client.py +++ b/src/crawlee/storage_clients/_file_system/_dataset_client.py @@ -6,7 +6,7 @@ from datetime import datetime, timezone from logging import getLogger from pathlib import Path -from typing import TYPE_CHECKING, ClassVar +from typing import TYPE_CHECKING from pydantic import ValidationError from typing_extensions import override @@ -50,9 +50,6 @@ class FileSystemDatasetClient(DatasetClient): _ITEM_FILENAME_DIGITS = 9 """Number of digits used for the dataset item file names (e.g., 000000019.json).""" - _cache_by_name: ClassVar[dict[str, FileSystemDatasetClient]] = {} - """A dictionary to cache clients by their names.""" - def __init__( self, *, @@ -114,12 +111,6 @@ async def open( name = name or configuration.default_dataset_id - # Check if the client is already cached by name. - if name in cls._cache_by_name: - client = cls._cache_by_name[name] - await client._update_metadata(update_accessed_at=True) # noqa: SLF001 - return client - storage_dir = Path(configuration.storage_dir) dataset_path = storage_dir / cls._STORAGE_SUBDIR / name metadata_path = dataset_path / METADATA_FILENAME @@ -166,9 +157,6 @@ async def open( ) await client._update_metadata() - # Cache the client by name. - cls._cache_by_name[name] = client - return client @override @@ -178,10 +166,6 @@ async def drop(self) -> None: async with self._lock: await asyncio.to_thread(shutil.rmtree, self.path_to_dataset) - # Remove the client from the cache. 
- if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 - del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 - @override async def push_data(self, data: list[Any] | dict[str, Any]) -> None: new_item_count = self.metadata.item_count diff --git a/src/crawlee/storage_clients/_file_system/_key_value_store_client.py b/src/crawlee/storage_clients/_file_system/_key_value_store_client.py index f7db025a25..4ed427f797 100644 --- a/src/crawlee/storage_clients/_file_system/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_file_system/_key_value_store_client.py @@ -3,10 +3,11 @@ import asyncio import json import shutil +import urllib.parse from datetime import datetime, timezone from logging import getLogger from pathlib import Path -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any from pydantic import ValidationError from typing_extensions import override @@ -48,9 +49,6 @@ class FileSystemKeyValueStoreClient(KeyValueStoreClient): _STORAGE_SUBDIR = 'key_value_stores' """The name of the subdirectory where key-value stores are stored.""" - _cache_by_name: ClassVar[dict[str, FileSystemKeyValueStoreClient]] = {} - """A dictionary to cache clients by their names.""" - def __init__( self, *, @@ -110,12 +108,6 @@ async def open( name = name or configuration.default_dataset_id - # Check if the client is already cached by name. - if name in cls._cache_by_name: - client = cls._cache_by_name[name] - await client._update_metadata(update_accessed_at=True) # noqa: SLF001 - return client - storage_dir = Path(configuration.storage_dir) kvs_path = storage_dir / cls._STORAGE_SUBDIR / name metadata_path = kvs_path / METADATA_FILENAME @@ -160,9 +152,6 @@ async def open( ) await client._update_metadata() - # Cache the client by name. - cls._cache_by_name[name] = client - return client @override @@ -172,16 +161,12 @@ async def drop(self) -> None: async with self._lock: await asyncio.to_thread(shutil.rmtree, self.path_to_kvs) - # Remove the client from the cache. - if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 - del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 - @override async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: # Update the metadata to record access await self._update_metadata(update_accessed_at=True) - record_path = self.path_to_kvs / key + record_path = self.path_to_kvs / self._encode_key(key) if not record_path.exists(): return None @@ -257,7 +242,7 @@ async def set_value(self, *, key: str, value: Any, content_type: str | None = No # Fallback: attempt to convert to string and encode. 
value_bytes = str(value).encode('utf-8') - record_path = self.path_to_kvs / key + record_path = self.path_to_kvs / self._encode_key(key) # Prepare the metadata size = len(value_bytes) @@ -284,7 +269,7 @@ async def set_value(self, *, key: str, value: Any, content_type: str | None = No @override async def delete_value(self, *, key: str) -> None: - record_path = self.path_to_kvs / key + record_path = self.path_to_kvs / self._encode_key(key) metadata_path = record_path.with_name(f'{record_path.name}.{METADATA_FILENAME}') deleted = False @@ -331,7 +316,7 @@ async def iterate_keys( continue # Extract the base key name from the metadata filename - key_name = file_path.name[: -len(f'.{METADATA_FILENAME}')] + key_name = self._decode_key(file_path.name[: -len(f'.{METADATA_FILENAME}')]) # Apply exclusive_start_key filter if provided if exclusive_start_key is not None and key_name <= exclusive_start_key: @@ -384,3 +369,11 @@ async def _update_metadata( # Dump the serialized metadata to the file. data = await json_dumps(self._metadata.model_dump()) await asyncio.to_thread(self.path_to_metadata.write_text, data, encoding='utf-8') + + def _encode_key(self, key: str) -> str: + """Encode a key to make it safe for use in a file path.""" + return urllib.parse.quote(key, safe='') + + def _decode_key(self, encoded_key: str) -> str: + """Decode a key that was encoded to make it safe for use in a file path.""" + return urllib.parse.unquote(encoded_key) diff --git a/src/crawlee/storage_clients/_file_system/_request_queue_client.py b/src/crawlee/storage_clients/_file_system/_request_queue_client.py index a88168e894..2d170df09b 100644 --- a/src/crawlee/storage_clients/_file_system/_request_queue_client.py +++ b/src/crawlee/storage_clients/_file_system/_request_queue_client.py @@ -6,7 +6,7 @@ from datetime import datetime, timezone from logging import getLogger from pathlib import Path -from typing import TYPE_CHECKING, ClassVar +from typing import TYPE_CHECKING from pydantic import ValidationError from typing_extensions import override @@ -48,9 +48,6 @@ class FileSystemRequestQueueClient(RequestQueueClient): _STORAGE_SUBDIR = 'request_queues' """The name of the subdirectory where request queues are stored.""" - _cache_by_name: ClassVar[dict[str, FileSystemRequestQueueClient]] = {} - """A dictionary to cache clients by their names.""" - def __init__( self, *, @@ -126,12 +123,6 @@ async def open( name = name or configuration.default_request_queue_id - # Check if the client is already cached by name. - if name in cls._cache_by_name: - client = cls._cache_by_name[name] - await client._update_metadata(update_accessed_at=True) # noqa: SLF001 - return client - storage_dir = Path(configuration.storage_dir) rq_path = storage_dir / cls._STORAGE_SUBDIR / name metadata_path = rq_path / METADATA_FILENAME @@ -216,9 +207,6 @@ async def open( ) await client._update_metadata() - # Cache the client by name. - cls._cache_by_name[name] = client - return client @override @@ -228,10 +216,6 @@ async def drop(self) -> None: async with self._lock: await asyncio.to_thread(shutil.rmtree, self.path_to_rq) - # Remove the client from the cache. 
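The new `_encode_key` / `_decode_key` helpers above are the "fix in KVS" part of this commit: record keys are percent-encoded before being used as file names, so keys containing `/`, `?`, spaces or other unsafe characters can no longer produce invalid or nested paths inside the store directory. A standalone sketch of the same scheme (the module-level helper names are illustrative only):

import urllib.parse

def encode_key(key: str) -> str:
    # safe='' means even '/' is percent-encoded, so any key maps to a single
    # file-name component inside the store directory.
    return urllib.parse.quote(key, safe='')

def decode_key(encoded_key: str) -> str:
    # Exact inverse of encode_key; percent-escapes are expanded back.
    return urllib.parse.unquote(encoded_key)

key = 'reports/2024?lang=en'
encoded = encode_key(key)           # 'reports%2F2024%3Flang%3Den'
assert decode_key(encoded) == key   # The encoding round-trips losslessly.

Because the transformation is reversible, the updated `iterate_keys` above can decode the stored file names back into the original keys.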
- if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 - del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 - @override async def add_batch_of_requests( self, diff --git a/src/crawlee/storage_clients/_memory/_dataset_client.py b/src/crawlee/storage_clients/_memory/_dataset_client.py index 0d75b50f9f..63e75eabb0 100644 --- a/src/crawlee/storage_clients/_memory/_dataset_client.py +++ b/src/crawlee/storage_clients/_memory/_dataset_client.py @@ -2,7 +2,7 @@ from datetime import datetime, timezone from logging import getLogger -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any from typing_extensions import override @@ -31,9 +31,6 @@ class MemoryDatasetClient(DatasetClient): sorting, filtering, and pagination, but performs them entirely in memory. """ - _cache_by_name: ClassVar[dict[str, MemoryDatasetClient]] = {} - """A dictionary to cache clients by their names.""" - def __init__( self, *, @@ -76,16 +73,10 @@ async def open( ) -> MemoryDatasetClient: name = name or configuration.default_dataset_id - # Check if the client is already cached by name. - if name in cls._cache_by_name: - client = cls._cache_by_name[name] - await client._update_metadata(update_accessed_at=True) # noqa: SLF001 - return client - dataset_id = id or crypto_random_object_id() now = datetime.now(timezone.utc) - client = cls( + return cls( id=dataset_id, name=name, created_at=now, @@ -94,20 +85,11 @@ async def open( item_count=0, ) - # Cache the client by name - cls._cache_by_name[name] = client - - return client - @override async def drop(self) -> None: self._records.clear() self._metadata.item_count = 0 - # Remove the client from the cache - if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 - del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 - @override async def push_data(self, data: list[Any] | dict[str, Any]) -> None: new_item_count = self.metadata.item_count diff --git a/src/crawlee/storage_clients/_memory/_key_value_store_client.py b/src/crawlee/storage_clients/_memory/_key_value_store_client.py index 9b70419142..e9b91702c7 100644 --- a/src/crawlee/storage_clients/_memory/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_memory/_key_value_store_client.py @@ -3,7 +3,7 @@ import sys from datetime import datetime, timezone from logging import getLogger -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any from typing_extensions import override @@ -32,9 +32,6 @@ class MemoryKeyValueStoreClient(KeyValueStoreClient): does not support data sharing across different processes. 
""" - _cache_by_name: ClassVar[dict[str, MemoryKeyValueStoreClient]] = {} - """A dictionary to cache clients by their names.""" - def __init__( self, *, @@ -75,17 +72,11 @@ async def open( ) -> MemoryKeyValueStoreClient: name = name or configuration.default_key_value_store_id - # Check if the client is already cached by name - if name in cls._cache_by_name: - client = cls._cache_by_name[name] - await client._update_metadata(update_accessed_at=True) # noqa: SLF001 - return client - # If specific id is provided, use it; otherwise, generate a new one id = id or crypto_random_object_id() now = datetime.now(timezone.utc) - client = cls( + return cls( id=id, name=name, created_at=now, @@ -93,20 +84,11 @@ async def open( modified_at=now, ) - # Cache the client by name - cls._cache_by_name[name] = client - - return client - @override async def drop(self) -> None: # Clear all data self._records.clear() - # Remove from cache - if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 - del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 - @override async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: await self._update_metadata(update_accessed_at=True) diff --git a/src/crawlee/storage_clients/_memory/_request_queue_client.py b/src/crawlee/storage_clients/_memory/_request_queue_client.py index cfe674fb3d..d360695884 100644 --- a/src/crawlee/storage_clients/_memory/_request_queue_client.py +++ b/src/crawlee/storage_clients/_memory/_request_queue_client.py @@ -2,7 +2,7 @@ from datetime import datetime, timezone from logging import getLogger -from typing import TYPE_CHECKING, ClassVar +from typing import TYPE_CHECKING from typing_extensions import override @@ -35,9 +35,6 @@ class MemoryRequestQueueClient(RequestQueueClient): does not support data sharing across different processes. 
""" - _cache_by_name: ClassVar[dict[str, MemoryRequestQueueClient]] = {} - """A dictionary to cache clients by their names.""" - def __init__( self, *, @@ -91,17 +88,11 @@ async def open( ) -> MemoryRequestQueueClient: name = name or configuration.default_request_queue_id - # Check if the client is already cached by name - if name in cls._cache_by_name: - client = cls._cache_by_name[name] - await client._update_metadata(update_accessed_at=True) # noqa: SLF001 - return client - # If specific id is provided, use it; otherwise, generate a new one id = id or crypto_random_object_id() now = datetime.now(timezone.utc) - client = cls( + return cls( id=crypto_random_object_id(), name=name, created_at=now, @@ -114,21 +105,12 @@ async def open( total_request_count=0, ) - # Cache the client by name - cls._cache_by_name[name] = client - - return client - @override async def drop(self) -> None: # Clear all data self._records.clear() self._in_progress.clear() - # Remove from cache - if self.metadata.name in self.__class__._cache_by_name: # noqa: SLF001 - del self.__class__._cache_by_name[self.metadata.name] # noqa: SLF001 - @override async def add_batch_of_requests( self, diff --git a/src/crawlee/storages/_base.py b/src/crawlee/storages/_base.py index 8e73326041..9216bf4569 100644 --- a/src/crawlee/storages/_base.py +++ b/src/crawlee/storages/_base.py @@ -50,3 +50,17 @@ async def open( @abstractmethod async def drop(self) -> None: """Drop the storage, removing it from the underlying storage client and clearing the cache.""" + + @classmethod + def compute_cache_key( + cls, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + storage_client: StorageClient | None = None, + ) -> str: + """Compute the cache key for the storage. + + The cache key computed based on the storage ID, name, configuration fields, and storage client class. + """ + return f'{id}|{name}|{configuration}|{storage_client.__class__}' diff --git a/src/crawlee/storages/_dataset.py b/src/crawlee/storages/_dataset.py index 413290faab..928679937d 100644 --- a/src/crawlee/storages/_dataset.py +++ b/src/crawlee/storages/_dataset.py @@ -66,21 +66,20 @@ class Dataset(Storage): ``` """ - _cache_by_id: ClassVar[dict[str, Dataset]] = {} - """A dictionary to cache datasets by their IDs.""" + _cache: ClassVar[dict[str, Dataset]] = {} + """A dictionary to cache datasets.""" - _cache_by_name: ClassVar[dict[str, Dataset]] = {} - """A dictionary to cache datasets by their names.""" - - def __init__(self, client: DatasetClient) -> None: + def __init__(self, client: DatasetClient, cache_key: str) -> None: """Initialize a new instance. Preferably use the `Dataset.open` constructor to create a new instance. Args: client: An instance of a dataset client. + cache_key: A unique key to identify the dataset in the cache. 
""" self._client = client + self._cache_key = cache_key @override @property @@ -110,38 +109,34 @@ async def open( if id and name: raise ValueError('Only one of "id" or "name" can be specified, not both.') - # Check if dataset is already cached by id or name - if id and id in cls._cache_by_id: - return cls._cache_by_id[id] - if name and name in cls._cache_by_name: - return cls._cache_by_name[name] - configuration = service_locator.get_configuration() if configuration is None else configuration storage_client = service_locator.get_storage_client() if storage_client is None else storage_client - client = await storage_client.open_dataset_client( + cache_key = cls.compute_cache_key( id=id, name=name, configuration=configuration, + storage_client=storage_client, ) - dataset = cls(client) + if cache_key in cls._cache: + return cls._cache[cache_key] - # Cache the dataset by id and name if available - if dataset.id: - cls._cache_by_id[dataset.id] = dataset - if dataset.name: - cls._cache_by_name[dataset.name] = dataset + client = await storage_client.open_dataset_client( + id=id, + name=name, + configuration=configuration, + ) + dataset = cls(client, cache_key) + cls._cache[cache_key] = dataset return dataset @override async def drop(self) -> None: # Remove from cache before dropping - if self.id in self._cache_by_id: - del self._cache_by_id[self.id] - if self.name and self.name in self._cache_by_name: - del self._cache_by_name[self.name] + if self._cache_key in self._cache: + del self._cache[self._cache_key] await self._client.drop() diff --git a/src/crawlee/storages/_key_value_store.py b/src/crawlee/storages/_key_value_store.py index 6584f8fee7..753e7c5242 100644 --- a/src/crawlee/storages/_key_value_store.py +++ b/src/crawlee/storages/_key_value_store.py @@ -54,24 +54,23 @@ class KeyValueStore(Storage): ``` """ - _cache_by_id: ClassVar[dict[str, KeyValueStore]] = {} - """A dictionary to cache key-value stores by their IDs.""" - - _cache_by_name: ClassVar[dict[str, KeyValueStore]] = {} - """A dictionary to cache key-value stores by their names.""" + _cache: ClassVar[dict[str, KeyValueStore]] = {} + """A dictionary to cache key-value stores.""" _autosave_cache: ClassVar[dict[str, dict[str, dict[str, JsonSerializable]]]] = {} """A dictionary to cache auto-saved values.""" - def __init__(self, client: KeyValueStoreClient) -> None: + def __init__(self, client: KeyValueStoreClient, cache_key: str) -> None: """Initialize a new instance. Preferably use the `KeyValueStore.open` constructor to create a new instance. Args: client: An instance of a key-value store client. + cache_key: A unique key to identify the key-value store in the cache. 
""" self._client = client + self._cache_key = cache_key self._autosave_lock = asyncio.Lock() self._persist_state_event_started = False @@ -103,38 +102,34 @@ async def open( if id and name: raise ValueError('Only one of "id" or "name" can be specified, not both.') - # Check if key-value store is already cached by id or name - if id and id in cls._cache_by_id: - return cls._cache_by_id[id] - if name and name in cls._cache_by_name: - return cls._cache_by_name[name] - configuration = service_locator.get_configuration() if configuration is None else configuration storage_client = service_locator.get_storage_client() if storage_client is None else storage_client - client = await storage_client.open_key_value_store_client( + cache_key = cls.compute_cache_key( id=id, name=name, configuration=configuration, + storage_client=storage_client, ) - kvs = cls(client) + if cache_key in cls._cache: + return cls._cache[cache_key] - # Cache the key-value store by id and name if available - if kvs.id: - cls._cache_by_id[kvs.id] = kvs - if kvs.name: - cls._cache_by_name[kvs.name] = kvs + client = await storage_client.open_key_value_store_client( + id=id, + name=name, + configuration=configuration, + ) + kvs = cls(client, cache_key) + cls._cache[cache_key] = kvs return kvs @override async def drop(self) -> None: # Remove from cache before dropping - if self.id in self._cache_by_id: - del self._cache_by_id[self.id] - if self.name and self.name in self._cache_by_name: - del self._cache_by_name[self.name] + if self._cache_key in self._cache: + del self._cache[self._cache_key] # Clear cache with persistent values self._clear_cache() @@ -248,8 +243,8 @@ async def get_auto_saved_value( default_value = {} if default_value is None else default_value async with self._autosave_lock: - if key in self._cache: - return self._cache[key] + if key in self._autosave_cache: + return self._autosave_cache[key] value = await self.get_value(key, default_value) @@ -258,7 +253,7 @@ async def get_auto_saved_value( f'Expected dictionary for persist state value at key "{key}, but got {type(value).__name__}' ) - self._cache[key] = value + self._autosave_cache[key] = value self._ensure_persist_event() return value @@ -280,15 +275,15 @@ async def get_public_url(self, key: str) -> str: return await self._client.get_public_url(key=key) @property - def _cache(self) -> dict[str, dict[str, JsonSerializable]]: - """Cache dictionary for storing auto-saved values indexed by store ID.""" - if self.id not in self._autosave_cache: - self._autosave_cache[self.id] = {} - return self._autosave_cache[self.id] + def _autosave_cache_instance(self) -> dict[str, dict[str, JsonSerializable]]: + """Cache dictionary for storing auto-saved values indexed by store cache key.""" + if self._cache_key not in self._autosave_cache: + self._autosave_cache[self._cache_key] = {} + return self._autosave_cache[self._cache_key] async def _persist_save(self, _event_data: EventPersistStateData | None = None) -> None: """Save cache with persistent values. 
Can be used in Event Manager.""" - for key, value in self._cache.items(): + for key, value in self._autosave_cache_instance.items(): await self.set_value(key, value) def _ensure_persist_event(self) -> None: @@ -302,7 +297,7 @@ def _ensure_persist_event(self) -> None: def _clear_cache(self) -> None: """Clear cache with persistent values.""" - self._cache.clear() + self._autosave_cache_instance.clear() def _drop_persist_state_event(self) -> None: """Off event manager listener and drop event status.""" diff --git a/src/crawlee/storages/_request_queue.py b/src/crawlee/storages/_request_queue.py index 169b4454a7..e152c5943a 100644 --- a/src/crawlee/storages/_request_queue.py +++ b/src/crawlee/storages/_request_queue.py @@ -70,24 +70,20 @@ class RequestQueue(Storage, RequestManager): ``` """ - _cache_by_id: ClassVar[dict[str, RequestQueue]] = {} - """A dictionary to cache request queues by their IDs.""" + _cache: ClassVar[dict[str, RequestQueue]] = {} + """A dictionary to cache request queues.""" - _cache_by_name: ClassVar[dict[str, RequestQueue]] = {} - """A dictionary to cache request queues by their names.""" - - _MAX_CACHED_REQUESTS = 1_000_000 - """Maximum number of requests that can be cached.""" - - def __init__(self, client: RequestQueueClient) -> None: + def __init__(self, client: RequestQueueClient, cache_key: str) -> None: """Initialize a new instance. Preferably use the `RequestQueue.open` constructor to create a new instance. Args: client: An instance of a request queue client. + cache_key: A unique key to identify the request queue in the cache. """ self._client = client + self._cache_key = cache_key self._add_requests_tasks = list[asyncio.Task]() """A list of tasks for adding requests to the queue.""" @@ -130,38 +126,34 @@ async def open( if id and name: raise ValueError('Only one of "id" or "name" can be specified, not both.') - # Check if request queue is already cached by id or name - if id and id in cls._cache_by_id: - return cls._cache_by_id[id] - if name and name in cls._cache_by_name: - return cls._cache_by_name[name] - configuration = service_locator.get_configuration() if configuration is None else configuration storage_client = service_locator.get_storage_client() if storage_client is None else storage_client - client = await storage_client.open_request_queue_client( + cache_key = cls.compute_cache_key( id=id, name=name, configuration=configuration, + storage_client=storage_client, ) - rq = cls(client) + if cache_key in cls._cache: + return cls._cache[cache_key] - # Cache the request queue by id and name if available - if rq.id: - cls._cache_by_id[rq.id] = rq - if rq.name: - cls._cache_by_name[rq.name] = rq + client = await storage_client.open_request_queue_client( + id=id, + name=name, + configuration=configuration, + ) + rq = cls(client, cache_key) + cls._cache[cache_key] = rq return rq @override async def drop(self) -> None: # Remove from cache before dropping - if self.id in self._cache_by_id: - del self._cache_by_id[self.id] - if self.name and self.name in self._cache_by_name: - del self._cache_by_name[self.name] + if self._cache_key in self._cache: + del self._cache[self._cache_key] await self._client.drop() diff --git a/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py b/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py index c3297b570a..45f583c878 100644 --- a/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py +++ b/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py @@ -64,28 +64,6 @@ async def 
test_open_creates_new_dataset(configuration: Configuration) -> None:
     assert metadata['item_count'] == 0
 
 
-async def test_open_existing_dataset(
-    dataset_client: FileSystemDatasetClient,
-    configuration: Configuration,
-) -> None:
-    """Test that open() loads an existing dataset correctly."""
-    configuration.purge_on_start = False
-
-    # Open the same dataset again
-    reopened_client = await FileSystemStorageClient().open_dataset_client(
-        name=dataset_client.metadata.name,
-        configuration=configuration,
-    )
-
-    # Verify client properties
-    assert dataset_client.metadata.id == reopened_client.metadata.id
-    assert dataset_client.metadata.name == reopened_client.metadata.name
-    assert dataset_client.metadata.item_count == reopened_client.metadata.item_count
-
-    # Verify clients (python) ids
-    assert id(dataset_client) == id(reopened_client)
-
-
 async def test_dataset_client_purge_on_start(configuration: Configuration) -> None:
     """Test that purge_on_start=True clears existing data in the dataset."""
     configuration.purge_on_start = True
@@ -307,13 +285,11 @@ async def test_drop(dataset_client: FileSystemDatasetClient) -> None:
     """Test dropping a dataset removes the entire dataset directory from disk."""
     await dataset_client.push_data({'test': 'data'})
 
-    assert dataset_client.metadata.name in FileSystemDatasetClient._cache_by_name
     assert dataset_client.path_to_dataset.exists()
 
     # Drop the dataset
     await dataset_client.drop()
 
-    assert dataset_client.metadata.name not in FileSystemDatasetClient._cache_by_name
     assert not dataset_client.path_to_dataset.exists()
 
 
diff --git a/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py b/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py
index 156394ad4d..1c1d4d3de6 100644
--- a/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py
+++ b/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py
@@ -61,27 +61,6 @@ async def test_open_creates_new_kvs(configuration: Configuration) -> None:
     assert metadata['name'] == 'new_kvs'
 
 
-async def test_open_existing_kvs(
-    kvs_client: FileSystemKeyValueStoreClient,
-    configuration: Configuration,
-) -> None:
-    """Test that open() loads an existing key-value store with matching properties."""
-    configuration.purge_on_start = False
-
-    # Open the same key-value store again
-    reopened_client = await FileSystemStorageClient().open_key_value_store_client(
-        name=kvs_client.metadata.name,
-        configuration=configuration,
-    )
-
-    # Verify client properties
-    assert kvs_client.metadata.id == reopened_client.metadata.id
-    assert kvs_client.metadata.name == reopened_client.metadata.name
-
-    # Verify clients (python) ids - should be the same object due to caching
-    assert id(kvs_client) == id(reopened_client)
-
-
 async def test_kvs_client_purge_on_start(configuration: Configuration) -> None:
     """Test that purge_on_start=True clears existing data in the key-value store."""
     configuration.purge_on_start = True
@@ -323,13 +302,11 @@ async def test_drop(kvs_client: FileSystemKeyValueStoreClient) -> None:
     """Test that drop removes the entire store directory from disk."""
     await kvs_client.set_value(key='test', value='test-value')
 
-    assert kvs_client.metadata.name in FileSystemKeyValueStoreClient._cache_by_name
     assert kvs_client.path_to_kvs.exists()
 
     # Drop the store
     await kvs_client.drop()
 
-    assert kvs_client.metadata.name not in FileSystemKeyValueStoreClient._cache_by_name
     assert not kvs_client.path_to_kvs.exists()
 
 
diff --git a/tests/unit/storage_clients/_file_system/test_fs_rq_client.py b/tests/unit/storage_clients/_file_system/test_fs_rq_client.py
index 2b2cd16604..92991bc9b3 100644
--- a/tests/unit/storage_clients/_file_system/test_fs_rq_client.py
+++ b/tests/unit/storage_clients/_file_system/test_fs_rq_client.py
@@ -65,32 +65,6 @@ async def test_open_creates_new_rq(configuration: Configuration) -> None:
     assert metadata['name'] == 'new_request_queue'
 
 
-async def test_open_existing_rq(
-    rq_client: FileSystemRequestQueueClient,
-    configuration: Configuration,
-) -> None:
-    """Test that open() loads an existing request queue correctly."""
-    configuration.purge_on_start = False
-
-    # Add a request to the original client
-    await rq_client.add_batch_of_requests([Request.from_url('https://example.com')])
-
-    # Open the same request queue again
-    reopened_client = await FileSystemStorageClient().open_request_queue_client(
-        name=rq_client.metadata.name,
-        configuration=configuration,
-    )
-
-    # Verify client properties
-    assert rq_client.metadata.id == reopened_client.metadata.id
-    assert rq_client.metadata.name == reopened_client.metadata.name
-    assert rq_client.metadata.total_request_count == 1
-    assert rq_client.metadata.pending_request_count == 1
-
-    # Verify clients (python) ids - should be the same object due to caching
-    assert id(rq_client) == id(reopened_client)
-
-
 async def test_rq_client_purge_on_start(configuration: Configuration) -> None:
     """Test that purge_on_start=True clears existing data in the request queue."""
     configuration.purge_on_start = True
@@ -408,9 +382,6 @@ async def test_drop(configuration: Configuration) -> None:
     # Verify the directory was removed
     assert not rq_path.exists()
 
-    # Verify the client was removed from the cache
-    assert client.metadata.name not in FileSystemRequestQueueClient._cache_by_name
-
 
 async def test_file_persistence(configuration: Configuration) -> None:
     """Test that requests are persisted to files and can be recovered after a 'restart'."""
@@ -443,9 +414,6 @@ async def test_file_persistence(configuration: Configuration) -> None:
     request_files = list(storage_path.glob('*.json'))
     assert len(request_files) > 0, 'Request files should exist'
 
-    # Clear cache to simulate process restart
-    FileSystemRequestQueueClient._cache_by_name.clear()
-
     # Create a new client with same name (which will load from files)
     client2 = await FileSystemStorageClient().open_request_queue_client(
         name='persistence_test',
diff --git a/tests/unit/storage_clients/_memory/test_memory_dataset_client.py b/tests/unit/storage_clients/_memory/test_memory_dataset_client.py
index 52b0b3c733..c25074e5c0 100644
--- a/tests/unit/storage_clients/_memory/test_memory_dataset_client.py
+++ b/tests/unit/storage_clients/_memory/test_memory_dataset_client.py
@@ -36,28 +36,6 @@ async def test_open_creates_new_dataset() -> None:
     assert isinstance(client.metadata.accessed_at, datetime)
     assert isinstance(client.metadata.modified_at, datetime)
 
-    # Verify the client was cached
-    assert 'new_dataset' in MemoryDatasetClient._cache_by_name
-
-
-async def test_open_existing_dataset(dataset_client: MemoryDatasetClient) -> None:
-    """Test that open() loads an existing dataset with matching properties."""
-    configuration = Configuration(purge_on_start=False)
-
-    # Open the same dataset again
-    reopened_client = await MemoryStorageClient().open_dataset_client(
-        name=dataset_client.metadata.name,
-        configuration=configuration,
-    )
-
-    # Verify client properties
-    assert dataset_client.metadata.id == reopened_client.metadata.id
-    assert dataset_client.metadata.name == reopened_client.metadata.name
-    assert dataset_client.metadata.item_count == reopened_client.metadata.item_count
-
-    # Verify clients (python) ids
-    assert id(dataset_client) == id(reopened_client)
-
 
 async def test_dataset_client_purge_on_start() -> None:
     """Test that purge_on_start=True clears existing data in the dataset."""
@@ -85,29 +63,6 @@ async def test_dataset_client_purge_on_start() -> None:
     assert len(items.items) == 0
 
 
-async def test_dataset_client_no_purge_on_start() -> None:
-    """Test that purge_on_start=False keeps existing data in the dataset."""
-    configuration = Configuration(purge_on_start=False)
-
-    # Create dataset and add data
-    dataset_client1 = await MemoryStorageClient().open_dataset_client(
-        name='test_no_purge_dataset',
-        configuration=configuration,
-    )
-    await dataset_client1.push_data({'item': 'preserved data'})
-
-    # Reopen
-    dataset_client2 = await MemoryStorageClient().open_dataset_client(
-        name='test_no_purge_dataset',
-        configuration=configuration,
-    )
-
-    # Verify data was preserved
-    items = await dataset_client2.get_data()
-    assert len(items.items) == 1
-    assert items.items[0]['item'] == 'preserved data'
-
-
 async def test_open_with_id_and_name() -> None:
     """Test that open() can be used with both id and name parameters."""
     client = await MemoryStorageClient().open_dataset_client(
@@ -283,15 +238,9 @@ async def test_drop(dataset_client: MemoryDatasetClient) -> None:
     """Test that drop removes the dataset from cache and resets its state."""
     await dataset_client.push_data({'test': 'data'})
 
-    # Verify the dataset exists in the cache
-    assert dataset_client.metadata.name in MemoryDatasetClient._cache_by_name
-
     # Drop the dataset
     await dataset_client.drop()
 
-    # Verify the dataset was removed from the cache
-    assert dataset_client.metadata.name not in MemoryDatasetClient._cache_by_name
-
     # Verify the dataset is empty
     assert dataset_client.metadata.item_count == 0
     result = await dataset_client.get_data()
diff --git a/tests/unit/storage_clients/_memory/test_memory_kvs_client.py b/tests/unit/storage_clients/_memory/test_memory_kvs_client.py
index 54c8d8b9a8..5d8789f6c3 100644
--- a/tests/unit/storage_clients/_memory/test_memory_kvs_client.py
+++ b/tests/unit/storage_clients/_memory/test_memory_kvs_client.py
@@ -35,26 +35,6 @@ async def test_open_creates_new_kvs() -> None:
     assert isinstance(client.metadata.accessed_at, datetime)
     assert isinstance(client.metadata.modified_at, datetime)
 
-    # Verify the client was cached
-    assert 'new_kvs' in MemoryKeyValueStoreClient._cache_by_name
-
-
-async def test_open_existing_kvs(kvs_client: MemoryKeyValueStoreClient) -> None:
-    """Test that open() loads an existing key-value store with matching properties."""
-    configuration = Configuration(purge_on_start=False)
-    # Open the same key-value store again
-    reopened_client = await MemoryStorageClient().open_key_value_store_client(
-        name=kvs_client.metadata.name,
-        configuration=configuration,
-    )
-
-    # Verify client properties
-    assert kvs_client.metadata.id == reopened_client.metadata.id
-    assert kvs_client.metadata.name == reopened_client.metadata.name
-
-    # Verify clients (python) ids
-    assert id(kvs_client) == id(reopened_client)
-
 
 async def test_kvs_client_purge_on_start() -> None:
     """Test that purge_on_start=True clears existing data in the KVS."""
@@ -83,29 +63,6 @@ async def test_kvs_client_purge_on_start() -> None:
     assert record is None
 
 
-async def test_kvs_client_no_purge_on_start() -> None:
-    """Test that purge_on_start=False keeps existing data in the KVS."""
-    configuration = Configuration(purge_on_start=False)
-
-    # Create KVS and add data
-    kvs_client1 = await MemoryStorageClient().open_key_value_store_client(
-        name='test_no_purge_kvs',
-        configuration=configuration,
-    )
-    await kvs_client1.set_value(key='test-key', value='preserved value')
-
-    # Reopen
-    kvs_client2 = await MemoryStorageClient().open_key_value_store_client(
-        name='test_no_purge_kvs',
-        configuration=configuration,
-    )
-
-    # Verify value was preserved
-    record = await kvs_client2.get_value(key='test-key')
-    assert record is not None
-    assert record.value == 'preserved value'
-
-
 async def test_open_with_id_and_name() -> None:
     """Test that open() can be used with both id and name parameters."""
     client = await MemoryStorageClient().open_key_value_store_client(
@@ -240,15 +197,9 @@ async def test_drop(kvs_client: MemoryKeyValueStoreClient) -> None:
     # Add some values to the store
     await kvs_client.set_value(key='test', value='data')
 
-    # Verify the store exists in the cache
-    assert kvs_client.metadata.name in MemoryKeyValueStoreClient._cache_by_name
-
     # Drop the store
     await kvs_client.drop()
 
-    # Verify the store was removed from the cache
-    assert kvs_client.metadata.name not in MemoryKeyValueStoreClient._cache_by_name
-
     # Verify the store is empty
     record = await kvs_client.get_value(key='test')
     assert record is None
diff --git a/tests/unit/storage_clients/_memory/test_memory_rq_client.py b/tests/unit/storage_clients/_memory/test_memory_rq_client.py
index 36f6940119..028c53ccd2 100644
--- a/tests/unit/storage_clients/_memory/test_memory_rq_client.py
+++ b/tests/unit/storage_clients/_memory/test_memory_rq_client.py
@@ -39,26 +39,6 @@ async def test_open_creates_new_rq() -> None:
     assert client.metadata.total_request_count == 0
     assert client.metadata.had_multiple_clients is False
 
-    # Verify the client was cached
-    assert 'new_rq' in MemoryRequestQueueClient._cache_by_name
-
-
-async def test_open_existing_rq(rq_client: MemoryRequestQueueClient) -> None:
-    """Test that open() loads an existing request queue with matching properties."""
-    configuration = Configuration(purge_on_start=False)
-    # Open the same request queue again
-    reopened_client = await MemoryStorageClient().open_request_queue_client(
-        name=rq_client.metadata.name,
-        configuration=configuration,
-    )
-
-    # Verify client properties
-    assert rq_client.metadata.id == reopened_client.metadata.id
-    assert rq_client.metadata.name == reopened_client.metadata.name
-
-    # Verify clients (python) ids
-    assert id(rq_client) == id(reopened_client)
-
 
 async def test_rq_client_purge_on_start() -> None:
     """Test that purge_on_start=True clears existing data in the RQ."""
@@ -85,31 +65,6 @@ async def test_rq_client_purge_on_start() -> None:
     assert await rq_client2.is_empty() is True
 
 
-async def test_rq_client_no_purge_on_start() -> None:
-    """Test that purge_on_start=False keeps existing data in the RQ."""
-    configuration = Configuration(purge_on_start=False)
-
-    # Create RQ and add data
-    rq_client1 = await MemoryStorageClient().open_request_queue_client(
-        name='test_no_purge_rq',
-        configuration=configuration,
-    )
-    request = Request.from_url(url='https://example.com/preserved')
-    await rq_client1.add_batch_of_requests([request])
-
-    # Reopen
-    rq_client2 = await MemoryStorageClient().open_request_queue_client(
-        name='test_no_purge_rq',
-        configuration=configuration,
-    )
-
-    # Verify request was preserved
-    assert await rq_client2.is_empty() is False
-    next_request = await rq_client2.fetch_next_request()
-    assert next_request is not None
-    assert next_request.url == 'https://example.com/preserved'
-
-
 async def test_open_with_id_and_name() -> None:
     """Test that open() can be used with both id and name parameters."""
     client = await MemoryStorageClient().open_request_queue_client(
@@ -418,15 +373,9 @@ async def test_drop(rq_client: MemoryRequestQueueClient) -> None:
     request = Request.from_url(url='https://example.com/test')
     await rq_client.add_batch_of_requests([request])
 
-    # Verify the queue exists in the cache
-    assert rq_client.metadata.name in MemoryRequestQueueClient._cache_by_name
-
     # Drop the queue
     await rq_client.drop()
 
-    # Verify the queue was removed from the cache
-    assert rq_client.metadata.name not in MemoryRequestQueueClient._cache_by_name
-
     # Verify the queue is empty
     assert await rq_client.is_empty() is True
 
diff --git a/tests/unit/storages/test_dataset.py b/tests/unit/storages/test_dataset.py
index 8c9e0a30e1..a2159aac86 100644
--- a/tests/unit/storages/test_dataset.py
+++ b/tests/unit/storages/test_dataset.py
@@ -39,8 +39,7 @@ async def dataset(
     configuration: Configuration,
 ) -> AsyncGenerator[Dataset, None]:
     """Fixture that provides a dataset instance for each test."""
-    Dataset._cache_by_id.clear()
-    Dataset._cache_by_name.clear()
+    Dataset._cache.clear()
 
     dataset = await Dataset.open(
         name='test_dataset',
@@ -326,17 +325,13 @@ async def test_drop(
     await dataset.push_data({'test': 'data'})
 
     # Verify dataset exists in cache
-    assert dataset.id in Dataset._cache_by_id
-    if dataset.name:
-        assert dataset.name in Dataset._cache_by_name
+    assert dataset._cache_key in Dataset._cache
 
     # Drop the dataset
     await dataset.drop()
 
     # Verify dataset was removed from cache
-    assert dataset.id not in Dataset._cache_by_id
-    if dataset.name:
-        assert dataset.name not in Dataset._cache_by_name
+    assert dataset._cache_key not in Dataset._cache
 
     # Verify dataset is empty (by creating a new one with the same name)
     new_dataset = await Dataset.open(
diff --git a/tests/unit/storages/test_key_value_store.py b/tests/unit/storages/test_key_value_store.py
index ab290f2819..b08f5e0924 100644
--- a/tests/unit/storages/test_key_value_store.py
+++ b/tests/unit/storages/test_key_value_store.py
@@ -40,8 +40,7 @@ async def kvs(
     configuration: Configuration,
 ) -> AsyncGenerator[KeyValueStore, None]:
     """Fixture that provides a key-value store instance for each test."""
-    KeyValueStore._cache_by_id.clear()
-    KeyValueStore._cache_by_name.clear()
+    KeyValueStore._cache.clear()
 
     kvs = await KeyValueStore.open(
         name='test_kvs',
@@ -264,17 +263,13 @@ async def test_drop(
     await kvs.set_value('test', 'data')
 
     # Verify key-value store exists in cache
-    assert kvs.id in KeyValueStore._cache_by_id
-    if kvs.name:
-        assert kvs.name in KeyValueStore._cache_by_name
+    assert kvs._cache_key in KeyValueStore._cache
 
     # Drop the key-value store
     await kvs.drop()
 
     # Verify key-value store was removed from cache
-    assert kvs.id not in KeyValueStore._cache_by_id
-    if kvs.name:
-        assert kvs.name not in KeyValueStore._cache_by_name
+    assert kvs._cache_key not in KeyValueStore._cache
 
     # Verify key-value store is empty (by creating a new one with the same name)
     new_kvs = await KeyValueStore.open(
@@ -324,3 +319,48 @@ async def test_string_data(kvs: KeyValueStore) -> None:
     await kvs.set_value('json_string', json_string)
     result = await kvs.get_value('json_string')
     assert result == json_string
+
+
+async def test_key_with_special_characters(kvs: KeyValueStore) -> None:
+    """Test storing and retrieving values with keys containing special characters."""
+    # Key with spaces, slashes, and special characters
+    special_key = 'key with spaces/and/slashes!@#$%^&*()'
+    test_value = 'Special key value'
+
+    # Store the value with the special key
+    await kvs.set_value(key=special_key, value=test_value)
+
+    # Retrieve the value and verify it matches
+    result = await kvs.get_value(key=special_key)
+    assert result is not None
+    assert result == test_value
+
+    # Make sure the key is properly listed
+    keys = await kvs.list_keys()
+    key_names = [k.key for k in keys]
+    assert special_key in key_names
+
+    # Test key deletion
+    await kvs.delete_value(key=special_key)
+    assert await kvs.get_value(key=special_key) is None
+
+
+async def test_data_persistence_on_reopen(configuration: Configuration) -> None:
+    """Test that data persists when reopening a KeyValueStore."""
+    kvs1 = await KeyValueStore.open(configuration=configuration)
+
+    await kvs1.set_value('key_123', 'value_123')
+
+    result1 = await kvs1.get_value('key_123')
+    assert result1 == 'value_123'
+
+    kvs2 = await KeyValueStore.open(configuration=configuration)
+
+    result2 = await kvs2.get_value('key_123')
+    assert result2 == 'value_123'
+    assert await kvs1.list_keys() == await kvs2.list_keys()
+
+    await kvs2.set_value('key_456', 'value_456')
+
+    result1 = await kvs1.get_value('key_456')
+    assert result1 == 'value_456'
diff --git a/tests/unit/storages/test_request_queue.py b/tests/unit/storages/test_request_queue.py
index 876303cedb..81c588f95e 100644
--- a/tests/unit/storages/test_request_queue.py
+++ b/tests/unit/storages/test_request_queue.py
@@ -39,8 +39,7 @@ async def rq(
     configuration: Configuration,
 ) -> AsyncGenerator[RequestQueue, None]:
     """Fixture that provides a request queue instance for each test."""
-    RequestQueue._cache_by_id.clear()
-    RequestQueue._cache_by_name.clear()
+    RequestQueue._cache.clear()
 
     rq = await RequestQueue.open(
         name='test_request_queue',
@@ -471,17 +470,13 @@ async def test_drop(
     await rq.add_request('https://example.com')
 
     # Verify request queue exists in cache
-    assert rq.id in RequestQueue._cache_by_id
-    if rq.name:
-        assert rq.name in RequestQueue._cache_by_name
+    assert rq._cache_key in RequestQueue._cache
 
     # Drop the request queue
     await rq.drop()
 
     # Verify request queue was removed from cache
-    assert rq.id not in RequestQueue._cache_by_id
-    if rq.name:
-        assert rq.name not in RequestQueue._cache_by_name
+    assert rq._cache_key not in RequestQueue._cache
 
     # Verify request queue is empty (by creating a new one with the same name)
     new_rq = await RequestQueue.open(