diff --git a/docs/deployment/code_examples/google/cloud_run_example.py b/docs/deployment/code_examples/google/cloud_run_example.py index c01a4f3821..88db52bc75 100644 --- a/docs/deployment/code_examples/google/cloud_run_example.py +++ b/docs/deployment/code_examples/google/cloud_run_example.py @@ -5,24 +5,23 @@ import uvicorn from litestar import Litestar, get -from crawlee import service_locator from crawlee.crawlers import PlaywrightCrawler, PlaywrightCrawlingContext - -# highlight-start -# Disable writing storage data to the file system -configuration = service_locator.get_configuration() -configuration.persist_storage = False -configuration.write_metadata = False -# highlight-end +from crawlee.storage_clients import MemoryStorageClient @get('/') async def main() -> str: """The crawler entry point that will be called when the HTTP endpoint is accessed.""" + # highlight-start + # Disable writing storage data to the file system + storage_client = MemoryStorageClient() + # highlight-end + crawler = PlaywrightCrawler( headless=True, max_requests_per_crawl=10, browser_type='firefox', + storage_client=storage_client, ) @crawler.router.default_handler diff --git a/docs/deployment/code_examples/google/google_example.py b/docs/deployment/code_examples/google/google_example.py index f7180aa417..e31af2c3ab 100644 --- a/docs/deployment/code_examples/google/google_example.py +++ b/docs/deployment/code_examples/google/google_example.py @@ -6,22 +6,21 @@ import functions_framework from flask import Request, Response -from crawlee import service_locator from crawlee.crawlers import ( BeautifulSoupCrawler, BeautifulSoupCrawlingContext, ) - -# highlight-start -# Disable writing storage data to the file system -configuration = service_locator.get_configuration() -configuration.persist_storage = False -configuration.write_metadata = False -# highlight-end +from crawlee.storage_clients import MemoryStorageClient async def main() -> str: + # highlight-start + # Disable writing storage data to the file system + storage_client = MemoryStorageClient() + # highlight-end + crawler = BeautifulSoupCrawler( + storage_client=storage_client, max_request_retries=1, request_handler_timeout=timedelta(seconds=30), max_requests_per_crawl=10, diff --git a/docs/examples/code_examples/export_entire_dataset_to_file_csv.py b/docs/examples/code_examples/export_entire_dataset_to_file_csv.py index 115474fc61..f86a469c03 100644 --- a/docs/examples/code_examples/export_entire_dataset_to_file_csv.py +++ b/docs/examples/code_examples/export_entire_dataset_to_file_csv.py @@ -30,7 +30,7 @@ async def request_handler(context: BeautifulSoupCrawlingContext) -> None: await crawler.run(['https://crawlee.dev']) # Export the entire dataset to a CSV file. - await crawler.export_data_csv(path='results.csv') + await crawler.export_data(path='results.csv') if __name__ == '__main__': diff --git a/docs/examples/code_examples/export_entire_dataset_to_file_json.py b/docs/examples/code_examples/export_entire_dataset_to_file_json.py index 5c871fb228..81fe07afa4 100644 --- a/docs/examples/code_examples/export_entire_dataset_to_file_json.py +++ b/docs/examples/code_examples/export_entire_dataset_to_file_json.py @@ -30,7 +30,7 @@ async def request_handler(context: BeautifulSoupCrawlingContext) -> None: await crawler.run(['https://crawlee.dev']) # Export the entire dataset to a JSON file. 
- await crawler.export_data_json(path='results.json') + await crawler.export_data(path='results.json') if __name__ == '__main__': diff --git a/docs/examples/code_examples/parsel_crawler.py b/docs/examples/code_examples/parsel_crawler.py index 61ddb7484e..9807d7ca3b 100644 --- a/docs/examples/code_examples/parsel_crawler.py +++ b/docs/examples/code_examples/parsel_crawler.py @@ -40,7 +40,7 @@ async def some_hook(context: BasicCrawlingContext) -> None: await crawler.run(['https://github.com']) # Export the entire dataset to a JSON file. - await crawler.export_data_json(path='results.json') + await crawler.export_data(path='results.json') if __name__ == '__main__': diff --git a/docs/guides/code_examples/storages/cleaning_purge_explicitly_example.py b/docs/guides/code_examples/storages/cleaning_purge_explicitly_example.py deleted file mode 100644 index 15435da7bf..0000000000 --- a/docs/guides/code_examples/storages/cleaning_purge_explicitly_example.py +++ /dev/null @@ -1,21 +0,0 @@ -import asyncio - -from crawlee.crawlers import HttpCrawler -from crawlee.storage_clients import MemoryStorageClient - - -async def main() -> None: - storage_client = MemoryStorageClient.from_config() - - # Call the purge_on_start method to explicitly purge the storage. - # highlight-next-line - await storage_client.purge_on_start() - - # Pass the storage client to the crawler. - crawler = HttpCrawler(storage_client=storage_client) - - # ... - - -if __name__ == '__main__': - asyncio.run(main()) diff --git a/docs/guides/code_examples/storages/rq_basic_example.py b/docs/guides/code_examples/storages/rq_basic_example.py index 9e983bb9fe..388c184fc6 100644 --- a/docs/guides/code_examples/storages/rq_basic_example.py +++ b/docs/guides/code_examples/storages/rq_basic_example.py @@ -12,7 +12,7 @@ async def main() -> None: await request_queue.add_request('https://apify.com/') # Add multiple requests as a batch. - await request_queue.add_requests_batched( + await request_queue.add_requests( ['https://crawlee.dev/', 'https://crawlee.dev/python/'] ) diff --git a/docs/guides/code_examples/storages/rq_with_crawler_explicit_example.py b/docs/guides/code_examples/storages/rq_with_crawler_explicit_example.py index 21bedad0b9..bfece2eca5 100644 --- a/docs/guides/code_examples/storages/rq_with_crawler_explicit_example.py +++ b/docs/guides/code_examples/storages/rq_with_crawler_explicit_example.py @@ -10,9 +10,7 @@ async def main() -> None: request_queue = await RequestQueue.open(name='my-request-queue') # Interact with the request queue directly, e.g. add a batch of requests. - await request_queue.add_requests_batched( - ['https://apify.com/', 'https://crawlee.dev/'] - ) + await request_queue.add_requests(['https://apify.com/', 'https://crawlee.dev/']) # Create a new crawler (it can be any subclass of BasicCrawler) and pass the request # list as request manager to it. It will be managed by the crawler. 
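The hunks above rename the batch-request and export helpers (`add_requests_batched` → `add_requests`, and `export_data_json`/`export_data_csv` → a single `export_data` that infers the format from the file extension). A minimal sketch of the renamed calls working together, assuming the signatures introduced in this diff; the handler body and URLs are illustrative:

```python
import asyncio

from crawlee.crawlers import HttpCrawler, HttpCrawlingContext


async def main() -> None:
    crawler = HttpCrawler(max_requests_per_crawl=10)

    @crawler.router.default_handler
    async def handler(context: HttpCrawlingContext) -> None:
        # Store the URL of every crawled page.
        await context.push_data({'url': context.request.url})

    # `add_requests` replaces the former `add_requests_batched` helper.
    await crawler.add_requests(['https://crawlee.dev/', 'https://crawlee.dev/python/'])
    await crawler.run()

    # `export_data` now picks JSON or CSV based on the path suffix.
    await crawler.export_data(path='results.json')


if __name__ == '__main__':
    asyncio.run(main())
```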
diff --git a/docs/guides/request_loaders.mdx b/docs/guides/request_loaders.mdx index 73fe374a62..8816f2a388 100644 --- a/docs/guides/request_loaders.mdx +++ b/docs/guides/request_loaders.mdx @@ -52,12 +52,12 @@ class BaseStorage { class RequestLoader { <> + + handled_count + + total_count + fetch_next_request() + mark_request_as_handled() + is_empty() + is_finished() - + get_handled_count() - + get_total_count() + to_tandem() } diff --git a/docs/guides/storages.mdx b/docs/guides/storages.mdx index 3be168b683..37815bde59 100644 --- a/docs/guides/storages.mdx +++ b/docs/guides/storages.mdx @@ -24,7 +24,6 @@ import KvsWithCrawlerExample from '!!raw-loader!roa-loader!./code_examples/stora import KvsWithCrawlerExplicitExample from '!!raw-loader!roa-loader!./code_examples/storages/kvs_with_crawler_explicit_example.py'; import CleaningDoNotPurgeExample from '!!raw-loader!roa-loader!./code_examples/storages/cleaning_do_not_purge_example.py'; -import CleaningPurgeExplicitlyExample from '!!raw-loader!roa-loader!./code_examples/storages/cleaning_purge_explicitly_example.py'; Crawlee offers multiple storage types for managing and persisting your crawling data. Request-oriented storages, such as the `RequestQueue`, help you store and deduplicate URLs, while result-oriented storages, like `Dataset` and `KeyValueStore`, focus on storing and retrieving scraping results. This guide helps you choose the storage type that suits your needs. @@ -210,12 +209,6 @@ Default storages are purged before the crawler starts, unless explicitly configu If you do not explicitly interact with storages in your code, the purging will occur automatically when the `BasicCrawler.run` method is invoked. -If you need to purge storages earlier, you can call `MemoryStorageClient.purge_on_start` directly if you are using the default storage client. This method triggers the purging process for the underlying storage implementation you are currently using. - - - {CleaningPurgeExplicitlyExample} - - ## Conclusion This guide introduced you to the different storage types available in Crawlee and how to interact with them. You learned how to manage requests and store and retrieve scraping results using the `RequestQueue`, `Dataset`, and `KeyValueStore`. You also discovered how to use helper functions to simplify interactions with these storages. Finally, you learned how to clean up storages before starting a crawler run and how to purge them explicitly. If you have questions or need assistance, feel free to reach out on our [GitHub](https://github.com/apify/crawlee-python) or join our [Discord community](https://discord.com/invite/jyEM2PRvMU). Happy scraping! diff --git a/pyproject.toml b/pyproject.toml index 78a966558a..660bf28d83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,8 +93,8 @@ crawlee = "crawlee._cli:cli" [dependency-groups] dev = [ - "apify_client", # For e2e tests. - "build~=1.2.2", # For e2e tests. + "apify-client", # For e2e tests. + "build~=1.2.2", # For e2e tests. 
"mypy~=1.15.0", "pre-commit~=4.2.0", "proxy-py~=2.4.0", @@ -144,7 +144,6 @@ ignore = [ "PLR0911", # Too many return statements "PLR0913", # Too many arguments in function definition "PLR0915", # Too many statements - "PTH", # flake8-use-pathlib "PYI034", # `__aenter__` methods in classes like `{name}` usually return `self` at runtime "PYI036", # The second argument in `__aexit__` should be annotated with `object` or `BaseException | None` "S102", # Use of `exec` detected @@ -166,6 +165,7 @@ indent-style = "space" "F401", # Unused imports ] "**/{tests}/*" = [ + "ASYNC230", # Async functions should not open files with blocking methods like `open` "D", # Everything from the pydocstyle "INP001", # File {filename} is part of an implicit namespace package, add an __init__.py "PLR2004", # Magic value used in comparison, consider replacing {value} with a constant variable diff --git a/src/crawlee/_cli.py b/src/crawlee/_cli.py index 689cb7b182..d8b295b5ed 100644 --- a/src/crawlee/_cli.py +++ b/src/crawlee/_cli.py @@ -22,7 +22,7 @@ cli = typer.Typer(no_args_is_help=True) template_directory = importlib.resources.files('crawlee') / 'project_template' -with open(str(template_directory / 'cookiecutter.json')) as f: +with (template_directory / 'cookiecutter.json').open() as f: cookiecutter_json = json.load(f) crawler_choices = cookiecutter_json['crawler_type'] diff --git a/src/crawlee/_service_locator.py b/src/crawlee/_service_locator.py index 31bc36c63c..2cb8f8302a 100644 --- a/src/crawlee/_service_locator.py +++ b/src/crawlee/_service_locator.py @@ -3,8 +3,8 @@ from crawlee._utils.docs import docs_group from crawlee.configuration import Configuration from crawlee.errors import ServiceConflictError -from crawlee.events import EventManager -from crawlee.storage_clients import StorageClient +from crawlee.events import EventManager, LocalEventManager +from crawlee.storage_clients import FileSystemStorageClient, StorageClient @docs_group('Classes') @@ -49,8 +49,6 @@ def set_configuration(self, configuration: Configuration) -> None: def get_event_manager(self) -> EventManager: """Get the event manager.""" if self._event_manager is None: - from crawlee.events import LocalEventManager - self._event_manager = ( LocalEventManager().from_config(config=self._configuration) if self._configuration @@ -77,13 +75,7 @@ def set_event_manager(self, event_manager: EventManager) -> None: def get_storage_client(self) -> StorageClient: """Get the storage client.""" if self._storage_client is None: - from crawlee.storage_clients import MemoryStorageClient - - self._storage_client = ( - MemoryStorageClient.from_config(config=self._configuration) - if self._configuration - else MemoryStorageClient.from_config() - ) + self._storage_client = FileSystemStorageClient() self._storage_client_was_retrieved = True return self._storage_client diff --git a/src/crawlee/_types.py b/src/crawlee/_types.py index c68ae63df9..711f3cf145 100644 --- a/src/crawlee/_types.py +++ b/src/crawlee/_types.py @@ -23,7 +23,7 @@ from crawlee.sessions import Session from crawlee.storage_clients.models import DatasetItemsListPage from crawlee.storages import KeyValueStore - from crawlee.storages._dataset import ExportToKwargs, GetDataKwargs + from crawlee.storages._types import ExportToKwargs, GetDataKwargs # Workaround for https://github.com/pydantic/pydantic/issues/9445 J = TypeVar('J', bound='JsonSerializable') @@ -275,10 +275,6 @@ async def push_data( **kwargs: Unpack[PushDataKwargs], ) -> None: """Track a call to the `push_data` context helper.""" - from 
crawlee.storages._dataset import Dataset - - await Dataset.check_and_serialize(data) - self.push_data_calls.append( PushDataFunctionCall( data=data, diff --git a/src/crawlee/_utils/file.py b/src/crawlee/_utils/file.py index 022d0604ef..4de6804490 100644 --- a/src/crawlee/_utils/file.py +++ b/src/crawlee/_utils/file.py @@ -2,18 +2,26 @@ import asyncio import contextlib -import io +import csv import json import mimetypes import os import re import shutil from enum import Enum +from logging import getLogger from typing import TYPE_CHECKING if TYPE_CHECKING: + from collections.abc import AsyncIterator from pathlib import Path - from typing import Any + from typing import Any, TextIO + + from typing_extensions import Unpack + + from crawlee.storages._types import ExportDataCsvKwargs, ExportDataJsonKwargs + +logger = getLogger(__name__) class ContentType(Enum): @@ -83,28 +91,67 @@ def determine_file_extension(content_type: str) -> str | None: return ext[1:] if ext is not None else ext -def is_file_or_bytes(value: Any) -> bool: - """Determine if the input value is a file-like object or bytes. - - This function checks whether the provided value is an instance of bytes, bytearray, or io.IOBase (file-like). - The method is simplified for common use cases and may not cover all edge cases. +async def json_dumps(obj: Any) -> str: + """Serialize an object to a JSON-formatted string with specific settings. Args: - value: The value to be checked. + obj: The object to serialize. Returns: - True if the value is either a file-like object or bytes, False otherwise. + A string containing the JSON representation of the input object. """ - return isinstance(value, (bytes, bytearray, io.IOBase)) + return await asyncio.to_thread(json.dumps, obj, ensure_ascii=False, indent=2, default=str) -async def json_dumps(obj: Any) -> str: - """Serialize an object to a JSON-formatted string with specific settings. +def infer_mime_type(value: Any) -> str: + """Infer the MIME content type from the value. Args: - obj: The object to serialize. + value: The value to infer the content type from. Returns: - A string containing the JSON representation of the input object. + The inferred MIME content type. """ - return await asyncio.to_thread(json.dumps, obj, ensure_ascii=False, indent=2, default=str) + # If the value is bytes (or bytearray), return binary content type. + if isinstance(value, (bytes, bytearray)): + return 'application/octet-stream' + + # If the value is a dict or list, assume JSON. + if isinstance(value, (dict, list)): + return 'application/json; charset=utf-8' + + # If the value is a string, assume plain text. + if isinstance(value, str): + return 'text/plain; charset=utf-8' + + # Default fallback. + return 'application/octet-stream' + + +async def export_json_to_stream( + iterator: AsyncIterator[dict], + dst: TextIO, + **kwargs: Unpack[ExportDataJsonKwargs], +) -> None: + items = [item async for item in iterator] + json.dump(items, dst, **kwargs) + + +async def export_csv_to_stream( + iterator: AsyncIterator[dict], + dst: TextIO, + **kwargs: Unpack[ExportDataCsvKwargs], +) -> None: + writer = csv.writer(dst, **kwargs) + write_header = True + + # Iterate over the dataset and write to CSV. 
+ async for item in iterator: + if not item: + continue + + if write_header: + writer.writerow(item.keys()) + write_header = False + + writer.writerow(item.values()) diff --git a/src/crawlee/configuration.py b/src/crawlee/configuration.py index de22118816..e3ef39f486 100644 --- a/src/crawlee/configuration.py +++ b/src/crawlee/configuration.py @@ -118,21 +118,7 @@ class Configuration(BaseSettings): ) ), ] = True - """Whether to purge the storage on the start. This option is utilized by the `MemoryStorageClient`.""" - - write_metadata: Annotated[bool, Field(alias='crawlee_write_metadata')] = True - """Whether to write the storage metadata. This option is utilized by the `MemoryStorageClient`.""" - - persist_storage: Annotated[ - bool, - Field( - validation_alias=AliasChoices( - 'apify_persist_storage', - 'crawlee_persist_storage', - ) - ), - ] = True - """Whether to persist the storage. This option is utilized by the `MemoryStorageClient`.""" + """Whether to purge the storage on the start. This option is utilized by the storage clients.""" persist_state_interval: Annotated[ timedelta_ms, @@ -239,7 +225,7 @@ class Configuration(BaseSettings): ), ), ] = './storage' - """The path to the storage directory. This option is utilized by the `MemoryStorageClient`.""" + """The path to the storage directory. This option is utilized by the storage clients.""" headless: Annotated[ bool, diff --git a/src/crawlee/crawlers/_basic/_basic_crawler.py b/src/crawlee/crawlers/_basic/_basic_crawler.py index 7e07c87f16..78fc6df5f7 100644 --- a/src/crawlee/crawlers/_basic/_basic_crawler.py +++ b/src/crawlee/crawlers/_basic/_basic_crawler.py @@ -35,6 +35,7 @@ SendRequestFunction, ) from crawlee._utils.docs import docs_group +from crawlee._utils.file import export_csv_to_stream, export_json_to_stream from crawlee._utils.robots import RobotsTxtFile from crawlee._utils.urls import convert_to_absolute_url, is_url_absolute from crawlee._utils.wait import wait_for @@ -65,7 +66,7 @@ import re from contextlib import AbstractAsyncContextManager - from crawlee._types import ConcurrencySettings, HttpMethod, JsonSerializable + from crawlee._types import ConcurrencySettings, HttpMethod, JsonSerializable, PushDataKwargs from crawlee.configuration import Configuration from crawlee.events import EventManager from crawlee.http_clients import HttpClient, HttpResponse @@ -75,7 +76,7 @@ from crawlee.statistics import FinalStatistics from crawlee.storage_clients import StorageClient from crawlee.storage_clients.models import DatasetItemsListPage - from crawlee.storages._dataset import ExportDataCsvKwargs, ExportDataJsonKwargs, GetDataKwargs, PushDataKwargs + from crawlee.storages._types import GetDataKwargs TCrawlingContext = TypeVar('TCrawlingContext', bound=BasicCrawlingContext, default=BasicCrawlingContext) TStatisticsState = TypeVar('TStatisticsState', bound=StatisticsState, default=StatisticsState) @@ -651,6 +652,7 @@ async def add_requests( self, requests: Sequence[str | Request], *, + forefront: bool = False, batch_size: int = 1000, wait_time_between_batches: timedelta = timedelta(0), wait_for_all_requests_to_be_added: bool = False, @@ -660,6 +662,7 @@ async def add_requests( Args: requests: A list of requests to add to the queue. + forefront: If True, add requests to the forefront of the queue. batch_size: The number of requests to add in one batch. wait_time_between_batches: Time to wait between adding batches. wait_for_all_requests_to_be_added: If True, wait for all requests to be added before returning. 
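The preceding hunks move the dataset export logic into two standalone helpers, `export_json_to_stream` and `export_csv_to_stream`, which consume any async iterator of dicts and write to an already-open text stream. They live in the internal `crawlee._utils.file` module, so the direct calls below are only a sketch of how `BasicCrawler.export_data` wires them up; the sample items and file names are illustrative:

```python
import asyncio
from collections.abc import AsyncIterator
from pathlib import Path

from crawlee._utils.file import export_csv_to_stream, export_json_to_stream


async def items() -> AsyncIterator[dict]:
    # Stand-in for Dataset.iterate_items().
    for item in [{'url': 'https://crawlee.dev/', 'title': 'Crawlee'}]:
        yield item


async def main() -> None:
    # CSV: the first non-empty item supplies the header row.
    with Path('results.csv').open('w', newline='') as dst:
        await export_csv_to_stream(items(), dst)

    # JSON: extra keyword arguments are forwarded to json.dump.
    with Path('results.json').open('w') as dst:
        await export_json_to_stream(items(), dst, indent=2)


if __name__ == '__main__':
    asyncio.run(main())
```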
@@ -682,17 +685,21 @@ async def add_requests( request_manager = await self.get_request_manager() - await request_manager.add_requests_batched( + await request_manager.add_requests( requests=allowed_requests, + forefront=forefront, batch_size=batch_size, wait_time_between_batches=wait_time_between_batches, wait_for_all_requests_to_be_added=wait_for_all_requests_to_be_added, wait_for_all_requests_to_be_added_timeout=wait_for_all_requests_to_be_added_timeout, ) - async def _use_state(self, default_value: dict[str, JsonSerializable] | None = None) -> dict[str, JsonSerializable]: - store = await self.get_key_value_store() - return await store.get_auto_saved_value(self._CRAWLEE_STATE_KEY, default_value) + async def _use_state( + self, + default_value: dict[str, JsonSerializable] | None = None, + ) -> dict[str, JsonSerializable]: + kvs = await self.get_key_value_store() + return await kvs.get_auto_saved_value(self._CRAWLEE_STATE_KEY, default_value) async def _save_crawler_state(self) -> None: store = await self.get_key_value_store() @@ -726,78 +733,29 @@ async def export_data( dataset_id: str | None = None, dataset_name: str | None = None, ) -> None: - """Export data from a `Dataset`. + """Export all items from a Dataset to a JSON or CSV file. - This helper method simplifies the process of exporting data from a `Dataset`. It opens the specified - one and then exports the data based on the provided parameters. If you need to pass options - specific to the output format, use the `export_data_csv` or `export_data_json` method instead. + This method simplifies the process of exporting data collected during crawling. It automatically + determines the export format based on the file extension (`.json` or `.csv`) and handles + the conversion of `Dataset` items to the appropriate format. Args: - path: The destination path. - dataset_id: The ID of the `Dataset`. - dataset_name: The name of the `Dataset`. + path: The destination file path. Must end with '.json' or '.csv'. + dataset_id: The ID of the Dataset to export from. If None, uses `name` parameter instead. + dataset_name: The name of the Dataset to export from. If None, uses `id` parameter instead. """ dataset = await self.get_dataset(id=dataset_id, name=dataset_name) path = path if isinstance(path, Path) else Path(path) - destination = path.open('w', newline='') + dst = path.open('w', newline='') if path.suffix == '.csv': - await dataset.write_to_csv(destination) + await export_csv_to_stream(dataset.iterate_items(), dst) elif path.suffix == '.json': - await dataset.write_to_json(destination) + await export_json_to_stream(dataset.iterate_items(), dst) else: raise ValueError(f'Unsupported file extension: {path.suffix}') - async def export_data_csv( - self, - path: str | Path, - *, - dataset_id: str | None = None, - dataset_name: str | None = None, - **kwargs: Unpack[ExportDataCsvKwargs], - ) -> None: - """Export data from a `Dataset` to a CSV file. - - This helper method simplifies the process of exporting data from a `Dataset` in csv format. It opens - the specified one and then exports the data based on the provided parameters. - - Args: - path: The destination path. - content_type: The output format. - dataset_id: The ID of the `Dataset`. - dataset_name: The name of the `Dataset`. - kwargs: Extra configurations for dumping/writing in csv format. 
- """ - dataset = await self.get_dataset(id=dataset_id, name=dataset_name) - path = path if isinstance(path, Path) else Path(path) - - return await dataset.write_to_csv(path.open('w', newline=''), **kwargs) - - async def export_data_json( - self, - path: str | Path, - *, - dataset_id: str | None = None, - dataset_name: str | None = None, - **kwargs: Unpack[ExportDataJsonKwargs], - ) -> None: - """Export data from a `Dataset` to a JSON file. - - This helper method simplifies the process of exporting data from a `Dataset` in json format. It opens the - specified one and then exports the data based on the provided parameters. - - Args: - path: The destination path - dataset_id: The ID of the `Dataset`. - dataset_name: The name of the `Dataset`. - kwargs: Extra configurations for dumping/writing in json format. - """ - dataset = await self.get_dataset(id=dataset_id, name=dataset_name) - path = path if isinstance(path, Path) else Path(path) - - return await dataset.write_to_json(path.open('w', newline=''), **kwargs) - async def _push_data( self, data: JsonSerializable, @@ -1089,7 +1047,7 @@ async def _commit_request_handler_result(self, context: BasicCrawlingContext) -> ): requests.append(dst_request) - await request_manager.add_requests_batched(requests) + await request_manager.add_requests(requests) for push_data_call in result.push_data_calls: await self._push_data(**push_data_call) diff --git a/src/crawlee/fingerprint_suite/_browserforge_adapter.py b/src/crawlee/fingerprint_suite/_browserforge_adapter.py index d64ddd59f0..11f9f82d79 100644 --- a/src/crawlee/fingerprint_suite/_browserforge_adapter.py +++ b/src/crawlee/fingerprint_suite/_browserforge_adapter.py @@ -1,10 +1,10 @@ from __future__ import annotations -import os.path from collections.abc import Iterable from copy import deepcopy from functools import reduce from operator import or_ +from pathlib import Path from typing import TYPE_CHECKING, Any, Literal from browserforge.bayesian_network import extract_json @@ -253,9 +253,9 @@ def generate(self, browser_type: SupportedBrowserType = 'chromium') -> dict[str, def get_available_header_network() -> dict: """Get header network that contains possible header values.""" - if os.path.isfile(DATA_DIR / 'header-network.zip'): + if Path(DATA_DIR / 'header-network.zip').is_file(): return extract_json(DATA_DIR / 'header-network.zip') - if os.path.isfile(DATA_DIR / 'header-network-definition.zip'): + if Path(DATA_DIR / 'header-network-definition.zip').is_file(): return extract_json(DATA_DIR / 'header-network-definition.zip') raise FileNotFoundError('Missing header-network file.') diff --git a/src/crawlee/project_template/hooks/post_gen_project.py b/src/crawlee/project_template/hooks/post_gen_project.py index e076ff9308..c0495a724d 100644 --- a/src/crawlee/project_template/hooks/post_gen_project.py +++ b/src/crawlee/project_template/hooks/post_gen_project.py @@ -2,7 +2,6 @@ import subprocess from pathlib import Path - # % if cookiecutter.package_manager in ['poetry', 'uv'] Path('requirements.txt').unlink() @@ -32,8 +31,9 @@ # Install requirements and generate requirements.txt as an impromptu lockfile subprocess.check_call([str(path / 'pip'), 'install', '-r', 'requirements.txt']) -with open('requirements.txt', 'w') as requirements_txt: - subprocess.check_call([str(path / 'pip'), 'freeze'], stdout=requirements_txt) +Path('requirements.txt').write_text( + subprocess.check_output([str(path / 'pip'), 'freeze']).decode() +) # % if cookiecutter.crawler_type == 'playwright' subprocess.check_call([str(path 
/ 'playwright'), 'install']) diff --git a/src/crawlee/request_loaders/_request_list.py b/src/crawlee/request_loaders/_request_list.py index 5964b106d0..3f545e1615 100644 --- a/src/crawlee/request_loaders/_request_list.py +++ b/src/crawlee/request_loaders/_request_list.py @@ -55,7 +55,13 @@ def name(self) -> str | None: return self._name @override - async def get_total_count(self) -> int: + @property + async def handled_count(self) -> int: + return self._handled_count + + @override + @property + async def total_count(self) -> int: return self._assumed_total_count @override @@ -87,10 +93,6 @@ async def mark_request_as_handled(self, request: Request) -> None: self._handled_count += 1 self._in_progress.remove(request.id) - @override - async def get_handled_count(self) -> int: - return self._handled_count - async def _ensure_next_request(self) -> None: if self._requests_lock is None: self._requests_lock = asyncio.Lock() diff --git a/src/crawlee/request_loaders/_request_loader.py b/src/crawlee/request_loaders/_request_loader.py index e358306a45..0a2e96e02f 100644 --- a/src/crawlee/request_loaders/_request_loader.py +++ b/src/crawlee/request_loaders/_request_loader.py @@ -25,9 +25,15 @@ class RequestLoader(ABC): - Managing state information such as the total and handled request counts. """ + @property @abstractmethod - async def get_total_count(self) -> int: - """Return an offline approximation of the total number of requests in the source (i.e. pending + handled).""" + async def handled_count(self) -> int: + """The number of requests that have been handled.""" + + @property + @abstractmethod + async def total_count(self) -> int: + """The total number of requests in the loader.""" @abstractmethod async def is_empty(self) -> bool: @@ -45,10 +51,6 @@ async def fetch_next_request(self) -> Request | None: async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: """Mark a request as handled after a successful processing (or after giving up retrying).""" - @abstractmethod - async def get_handled_count(self) -> int: - """Return the number of handled requests.""" - async def to_tandem(self, request_manager: RequestManager | None = None) -> RequestManagerTandem: """Combine the loader with a request manager to support adding and reclaiming requests. diff --git a/src/crawlee/request_loaders/_request_manager.py b/src/crawlee/request_loaders/_request_manager.py index f63f962cb9..5a8427c2cb 100644 --- a/src/crawlee/request_loaders/_request_manager.py +++ b/src/crawlee/request_loaders/_request_manager.py @@ -6,12 +6,12 @@ from crawlee._utils.docs import docs_group from crawlee.request_loaders._request_loader import RequestLoader +from crawlee.storage_clients.models import ProcessedRequest if TYPE_CHECKING: from collections.abc import Sequence from crawlee._request import Request - from crawlee.storage_clients.models import ProcessedRequest @docs_group('Abstract classes') @@ -40,10 +40,11 @@ async def add_request( Information about the request addition to the manager. """ - async def add_requests_batched( + async def add_requests( self, requests: Sequence[str | Request], *, + forefront: bool = False, batch_size: int = 1000, # noqa: ARG002 wait_time_between_batches: timedelta = timedelta(seconds=1), # noqa: ARG002 wait_for_all_requests_to_be_added: bool = False, # noqa: ARG002 @@ -53,14 +54,17 @@ async def add_requests_batched( Args: requests: Requests to enqueue. + forefront: If True, add requests to the beginning of the queue. batch_size: The number of requests to add in one batch. 
wait_time_between_batches: Time to wait between adding batches. wait_for_all_requests_to_be_added: If True, wait for all requests to be added before returning. wait_for_all_requests_to_be_added_timeout: Timeout for waiting for all requests to be added. """ # Default and dumb implementation. + processed_requests = list[ProcessedRequest]() for request in requests: - await self.add_request(request) + processed_request = await self.add_request(request, forefront=forefront) + processed_requests.append(processed_request) @abstractmethod async def reclaim_request(self, request: Request, *, forefront: bool = False) -> ProcessedRequest | None: diff --git a/src/crawlee/request_loaders/_request_manager_tandem.py b/src/crawlee/request_loaders/_request_manager_tandem.py index 9f0b8cefe8..35cc59e102 100644 --- a/src/crawlee/request_loaders/_request_manager_tandem.py +++ b/src/crawlee/request_loaders/_request_manager_tandem.py @@ -33,8 +33,14 @@ def __init__(self, request_loader: RequestLoader, request_manager: RequestManage self._read_write_manager = request_manager @override - async def get_total_count(self) -> int: - return (await self._read_only_loader.get_total_count()) + (await self._read_write_manager.get_total_count()) + @property + async def handled_count(self) -> int: + return await self._read_write_manager.handled_count + + @override + @property + async def total_count(self) -> int: + return (await self._read_only_loader.total_count) + (await self._read_write_manager.total_count) @override async def is_empty(self) -> bool: @@ -49,17 +55,19 @@ async def add_request(self, request: str | Request, *, forefront: bool = False) return await self._read_write_manager.add_request(request, forefront=forefront) @override - async def add_requests_batched( + async def add_requests( self, requests: Sequence[str | Request], *, + forefront: bool = False, batch_size: int = 1000, wait_time_between_batches: timedelta = timedelta(seconds=1), wait_for_all_requests_to_be_added: bool = False, wait_for_all_requests_to_be_added_timeout: timedelta | None = None, ) -> None: - return await self._read_write_manager.add_requests_batched( + return await self._read_write_manager.add_requests( requests, + forefront=forefront, batch_size=batch_size, wait_time_between_batches=wait_time_between_batches, wait_for_all_requests_to_be_added=wait_for_all_requests_to_be_added, @@ -97,10 +105,6 @@ async def reclaim_request(self, request: Request, *, forefront: bool = False) -> async def mark_request_as_handled(self, request: Request) -> None: await self._read_write_manager.mark_request_as_handled(request) - @override - async def get_handled_count(self) -> int: - return await self._read_write_manager.get_handled_count() - @override async def drop(self) -> None: await self._read_write_manager.drop() diff --git a/src/crawlee/storage_clients/__init__.py b/src/crawlee/storage_clients/__init__.py index 66d352d7a7..ce8c713ca9 100644 --- a/src/crawlee/storage_clients/__init__.py +++ b/src/crawlee/storage_clients/__init__.py @@ -1,4 +1,9 @@ from ._base import StorageClient +from ._file_system import FileSystemStorageClient from ._memory import MemoryStorageClient -__all__ = ['MemoryStorageClient', 'StorageClient'] +__all__ = [ + 'FileSystemStorageClient', + 'MemoryStorageClient', + 'StorageClient', +] diff --git a/src/crawlee/storage_clients/_base/__init__.py b/src/crawlee/storage_clients/_base/__init__.py index 5194da8768..73298560da 100644 --- a/src/crawlee/storage_clients/_base/__init__.py +++ b/src/crawlee/storage_clients/_base/__init__.py 
@@ -1,20 +1,11 @@ from ._dataset_client import DatasetClient -from ._dataset_collection_client import DatasetCollectionClient from ._key_value_store_client import KeyValueStoreClient -from ._key_value_store_collection_client import KeyValueStoreCollectionClient from ._request_queue_client import RequestQueueClient -from ._request_queue_collection_client import RequestQueueCollectionClient from ._storage_client import StorageClient -from ._types import ResourceClient, ResourceCollectionClient __all__ = [ 'DatasetClient', - 'DatasetCollectionClient', 'KeyValueStoreClient', - 'KeyValueStoreCollectionClient', 'RequestQueueClient', - 'RequestQueueCollectionClient', - 'ResourceClient', - 'ResourceCollectionClient', 'StorageClient', ] diff --git a/src/crawlee/storage_clients/_base/_dataset_client.py b/src/crawlee/storage_clients/_base/_dataset_client.py index d8495b2dd0..c73eb6f51f 100644 --- a/src/crawlee/storage_clients/_base/_dataset_client.py +++ b/src/crawlee/storage_clients/_base/_dataset_client.py @@ -7,58 +7,76 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator - from contextlib import AbstractAsyncContextManager + from typing import Any - from httpx import Response - - from crawlee._types import JsonSerializable + from crawlee.configuration import Configuration from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata @docs_group('Abstract classes') class DatasetClient(ABC): - """An abstract class for dataset resource clients. + """An abstract class for dataset storage clients. - These clients are specific to the type of resource they manage and operate under a designated storage - client, like a memory storage client. - """ + Dataset clients provide an interface for accessing and manipulating dataset storage. They handle + operations like adding and getting dataset items across different storage backends. - _LIST_ITEMS_LIMIT = 999_999_999_999 - """This is what API returns in the x-apify-pagination-limit header when no limit query parameter is used.""" + Storage clients are specific to the type of storage they manage (`Dataset`, `KeyValueStore`, + `RequestQueue`), and can operate with various storage systems including memory, file system, + databases, and cloud storage solutions. - @abstractmethod - async def get(self) -> DatasetMetadata | None: - """Get metadata about the dataset being managed by this client. + This abstract class defines the interface that all specific dataset clients must implement. + """ - Returns: - An object containing the dataset's details, or None if the dataset does not exist. - """ + @property + @abstractmethod + def metadata(self) -> DatasetMetadata: + """The metadata of the dataset.""" + @classmethod @abstractmethod - async def update( - self, + async def open( + cls, *, - name: str | None = None, - ) -> DatasetMetadata: - """Update the dataset metadata. + id: str | None, + name: str | None, + configuration: Configuration, + ) -> DatasetClient: + """Open existing or create a new dataset client. + + If a dataset with the given name or ID already exists, the appropriate dataset client is returned. + Otherwise, a new dataset is created and client for it is returned. + + The backend method for the `Dataset.open` call. Args: - name: New new name for the dataset. + id: The ID of the dataset. If not provided, an ID may be generated. + name: The name of the dataset. If not provided a default name may be used. + configuration: The configuration object. Returns: - An object reflecting the updated dataset metadata. 
+ A dataset client instance. + """ + + @abstractmethod + async def drop(self) -> None: + """Drop the whole dataset and remove all its items. + + The backend method for the `Dataset.drop` call. """ @abstractmethod - async def delete(self) -> None: - """Permanently delete the dataset managed by this client.""" + async def push_data(self, data: list[Any] | dict[str, Any]) -> None: + """Push data to the dataset. + + The backend method for the `Dataset.push_data` call. + """ @abstractmethod - async def list_items( + async def get_data( self, *, - offset: int | None = 0, - limit: int | None = _LIST_ITEMS_LIMIT, + offset: int = 0, + limit: int | None = 999_999_999_999, clean: bool = False, desc: bool = False, fields: list[str] | None = None, @@ -69,27 +87,9 @@ async def list_items( flatten: list[str] | None = None, view: str | None = None, ) -> DatasetItemsListPage: - """Retrieve a paginated list of items from a dataset based on various filtering parameters. - - This method provides the flexibility to filter, sort, and modify the appearance of dataset items - when listed. Each parameter modifies the result set according to its purpose. The method also - supports pagination through 'offset' and 'limit' parameters. - - Args: - offset: The number of initial items to skip. - limit: The maximum number of items to return. - clean: If True, removes empty items and hidden fields, equivalent to 'skip_hidden' and 'skip_empty'. - desc: If True, items are returned in descending order, i.e., newest first. - fields: Specifies a subset of fields to include in each item. - omit: Specifies a subset of fields to exclude from each item. - unwind: Specifies a field that should be unwound. If it's an array, each element becomes a separate record. - skip_empty: If True, omits items that are empty after other filters have been applied. - skip_hidden: If True, omits fields starting with the '#' character. - flatten: A list of fields to flatten in each item. - view: The specific view of the dataset to use when retrieving items. + """Get data from the dataset with various filtering options. - Returns: - An object with filtered, sorted, and paginated dataset items plus pagination details. + The backend method for the `Dataset.get_data` call. """ @abstractmethod @@ -106,126 +106,12 @@ async def iterate_items( skip_empty: bool = False, skip_hidden: bool = False, ) -> AsyncIterator[dict]: - """Iterate over items in the dataset according to specified filters and sorting. - - This method allows for asynchronously iterating through dataset items while applying various filters such as - skipping empty items, hiding specific fields, and sorting. It supports pagination via `offset` and `limit` - parameters, and can modify the appearance of dataset items using `fields`, `omit`, `unwind`, `skip_empty`, and - `skip_hidden` parameters. + """Iterate over the dataset items with filtering options. - Args: - offset: The number of initial items to skip. - limit: The maximum number of items to iterate over. None means no limit. - clean: If True, removes empty items and hidden fields, equivalent to 'skip_hidden' and 'skip_empty'. - desc: If set to True, items are returned in descending order, i.e., newest first. - fields: Specifies a subset of fields to include in each item. - omit: Specifies a subset of fields to exclude from each item. - unwind: Specifies a field that should be unwound into separate items. - skip_empty: If set to True, omits items that are empty after other filters have been applied. 
- skip_hidden: If set to True, omits fields starting with the '#' character from the output. - - Yields: - An asynchronous iterator of dictionary objects, each representing a dataset item after applying - the specified filters and transformations. + The backend method for the `Dataset.iterate_items` call. """ # This syntax is to make mypy properly work with abstract AsyncIterator. # https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators raise NotImplementedError if False: # type: ignore[unreachable] yield 0 - - @abstractmethod - async def get_items_as_bytes( - self, - *, - item_format: str = 'json', - offset: int | None = None, - limit: int | None = None, - desc: bool = False, - clean: bool = False, - bom: bool = False, - delimiter: str | None = None, - fields: list[str] | None = None, - omit: list[str] | None = None, - unwind: str | None = None, - skip_empty: bool = False, - skip_header_row: bool = False, - skip_hidden: bool = False, - xml_root: str | None = None, - xml_row: str | None = None, - flatten: list[str] | None = None, - ) -> bytes: - """Retrieve dataset items as bytes. - - Args: - item_format: Output format (e.g., 'json', 'csv'); default is 'json'. - offset: Number of items to skip; default is 0. - limit: Max number of items to return; no default limit. - desc: If True, results are returned in descending order. - clean: If True, filters out empty items and hidden fields. - bom: Include or exclude UTF-8 BOM; default behavior varies by format. - delimiter: Delimiter character for CSV; default is ','. - fields: List of fields to include in the results. - omit: List of fields to omit from the results. - unwind: Unwinds a field into separate records. - skip_empty: If True, skips empty items in the output. - skip_header_row: If True, skips the header row in CSV. - skip_hidden: If True, skips hidden fields in the output. - xml_root: Root element name for XML output; default is 'items'. - xml_row: Element name for each item in XML output; default is 'item'. - flatten: List of fields to flatten. - - Returns: - The dataset items as raw bytes. - """ - - @abstractmethod - async def stream_items( - self, - *, - item_format: str = 'json', - offset: int | None = None, - limit: int | None = None, - desc: bool = False, - clean: bool = False, - bom: bool = False, - delimiter: str | None = None, - fields: list[str] | None = None, - omit: list[str] | None = None, - unwind: str | None = None, - skip_empty: bool = False, - skip_header_row: bool = False, - skip_hidden: bool = False, - xml_root: str | None = None, - xml_row: str | None = None, - ) -> AbstractAsyncContextManager[Response | None]: - """Retrieve dataset items as a streaming response. - - Args: - item_format: Output format, options include json, jsonl, csv, html, xlsx, xml, rss; default is json. - offset: Number of items to skip at the start; default is 0. - limit: Maximum number of items to return; no default limit. - desc: If True, reverses the order of results. - clean: If True, filters out empty items and hidden fields. - bom: Include or exclude UTF-8 BOM; varies by format. - delimiter: Delimiter for CSV files; default is ','. - fields: List of fields to include in the output. - omit: List of fields to omit from the output. - unwind: Unwinds a field into separate records. - skip_empty: If True, empty items are omitted. - skip_header_row: If True, skips the header row in CSV. - skip_hidden: If True, hides fields starting with the # character. - xml_root: Custom root element name for XML output; default is 'items'. 
- xml_row: Custom element name for each item in XML; default is 'item'. - - Yields: - The dataset items in a streaming response. - """ - - @abstractmethod - async def push_items(self, items: JsonSerializable) -> None: - """Push items to the dataset. - - Args: - items: The items which to push in the dataset. They must be JSON serializable. - """ diff --git a/src/crawlee/storage_clients/_base/_dataset_collection_client.py b/src/crawlee/storage_clients/_base/_dataset_collection_client.py deleted file mode 100644 index 8530655c8c..0000000000 --- a/src/crawlee/storage_clients/_base/_dataset_collection_client.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -from crawlee._utils.docs import docs_group - -if TYPE_CHECKING: - from crawlee.storage_clients.models import DatasetListPage, DatasetMetadata - - -@docs_group('Abstract classes') -class DatasetCollectionClient(ABC): - """An abstract class for dataset collection clients. - - This collection client handles operations that involve multiple instances of a given resource type. - """ - - @abstractmethod - async def get_or_create( - self, - *, - id: str | None = None, - name: str | None = None, - schema: dict | None = None, - ) -> DatasetMetadata: - """Retrieve an existing dataset by its name or ID, or create a new one if it does not exist. - - Args: - id: Optional ID of the dataset to retrieve or create. If provided, the method will attempt - to find a dataset with the ID. - name: Optional name of the dataset resource to retrieve or create. If provided, the method will - attempt to find a dataset with this name. - schema: Optional schema for the dataset resource to be created. - - Returns: - Metadata object containing the information of the retrieved or created dataset. - """ - - @abstractmethod - async def list( - self, - *, - unnamed: bool = False, - limit: int | None = None, - offset: int | None = None, - desc: bool = False, - ) -> DatasetListPage: - """List the available datasets. - - Args: - unnamed: Whether to list only the unnamed datasets. - limit: Maximum number of datasets to return. - offset: Number of datasets to skip from the beginning of the list. - desc: Whether to sort the datasets in descending order. - - Returns: - The list of available datasets matching the specified filters. - """ diff --git a/src/crawlee/storage_clients/_base/_key_value_store_client.py b/src/crawlee/storage_clients/_base/_key_value_store_client.py index 6a5d141be6..957b53db0e 100644 --- a/src/crawlee/storage_clients/_base/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_base/_key_value_store_client.py @@ -6,126 +6,105 @@ from crawlee._utils.docs import docs_group if TYPE_CHECKING: - from contextlib import AbstractAsyncContextManager + from collections.abc import AsyncIterator - from httpx import Response - - from crawlee.storage_clients.models import KeyValueStoreListKeysPage, KeyValueStoreMetadata, KeyValueStoreRecord + from crawlee.configuration import Configuration + from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecord, KeyValueStoreRecordMetadata @docs_group('Abstract classes') class KeyValueStoreClient(ABC): - """An abstract class for key-value store resource clients. + """An abstract class for key-value store (KVS) storage clients. + + Key-value stores clients provide an interface for accessing and manipulating KVS storage. 
They handle + operations like getting, setting, deleting KVS values across different storage backends. + + Storage clients are specific to the type of storage they manage (`Dataset`, `KeyValueStore`, + `RequestQueue`), and can operate with various storage systems including memory, file system, + databases, and cloud storage solutions. - These clients are specific to the type of resource they manage and operate under a designated storage - client, like a memory storage client. + This abstract class defines the interface that all specific KVS clients must implement. """ + @property @abstractmethod - async def get(self) -> KeyValueStoreMetadata | None: - """Get metadata about the key-value store being managed by this client. - - Returns: - An object containing the key-value store's details, or None if the key-value store does not exist. - """ + def metadata(self) -> KeyValueStoreMetadata: + """The metadata of the key-value store.""" + @classmethod @abstractmethod - async def update( - self, + async def open( + cls, *, - name: str | None = None, - ) -> KeyValueStoreMetadata: - """Update the key-value store metadata. + id: str | None, + name: str | None, + configuration: Configuration, + ) -> KeyValueStoreClient: + """Open existing or create a new key-value store client. - Args: - name: New new name for the key-value store. + If a key-value store with the given name or ID already exists, the appropriate + key-value store client is returned. Otherwise, a new key-value store is created + and a client for it is returned. - Returns: - An object reflecting the updated key-value store metadata. - """ - - @abstractmethod - async def delete(self) -> None: - """Permanently delete the key-value store managed by this client.""" - - @abstractmethod - async def list_keys( - self, - *, - limit: int = 1000, - exclusive_start_key: str | None = None, - ) -> KeyValueStoreListKeysPage: - """List the keys in the key-value store. + The backend method for the `KeyValueStoreClient.open` call. Args: - limit: Number of keys to be returned. Maximum value is 1000. - exclusive_start_key: All keys up to this one (including) are skipped from the result. + id: The ID of the key-value store. If not provided, an ID may be generated. + name: The name of the key-value store. If not provided a default name may be used. + configuration: The configuration object. Returns: - The list of keys in the key-value store matching the given arguments. + A key-value store client instance. """ @abstractmethod - async def get_record(self, key: str) -> KeyValueStoreRecord | None: - """Retrieve the given record from the key-value store. + async def drop(self) -> None: + """Drop the whole key-value store and remove all its values. - Args: - key: Key of the record to retrieve. - - Returns: - The requested record, or None, if the record does not exist + The backend method for the `KeyValueStore.drop` call. """ @abstractmethod - async def get_record_as_bytes(self, key: str) -> KeyValueStoreRecord[bytes] | None: - """Retrieve the given record from the key-value store, without parsing it. - - Args: - key: Key of the record to retrieve. + async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: + """Retrieve the given record from the key-value store. - Returns: - The requested record, or None, if the record does not exist + The backend method for the `KeyValueStore.get_value` call. 
""" @abstractmethod - async def stream_record(self, key: str) -> AbstractAsyncContextManager[KeyValueStoreRecord[Response] | None]: - """Retrieve the given record from the key-value store, as a stream. - - Args: - key: Key of the record to retrieve. + async def set_value(self, *, key: str, value: Any, content_type: str | None = None) -> None: + """Set a value in the key-value store by its key. - Returns: - The requested record as a context-managed streaming Response, or None, if the record does not exist + The backend method for the `KeyValueStore.set_value` call. """ @abstractmethod - async def set_record(self, key: str, value: Any, content_type: str | None = None) -> None: - """Set a value to the given record in the key-value store. + async def delete_value(self, *, key: str) -> None: + """Delete a value from the key-value store by its key. - Args: - key: The key of the record to save the value to. - value: The value to save into the record. - content_type: The content type of the saved value. + The backend method for the `KeyValueStore.delete_value` call. """ @abstractmethod - async def delete_record(self, key: str) -> None: - """Delete the specified record from the key-value store. + async def iterate_keys( + self, + *, + exclusive_start_key: str | None = None, + limit: int | None = None, + ) -> AsyncIterator[KeyValueStoreRecordMetadata]: + """Iterate over all the existing keys in the key-value store. - Args: - key: The key of the record which to delete. + The backend method for the `KeyValueStore.iterate_keys` call. """ + # This syntax is to make mypy properly work with abstract AsyncIterator. + # https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators + raise NotImplementedError + if False: # type: ignore[unreachable] + yield 0 @abstractmethod - async def get_public_url(self, key: str) -> str: + async def get_public_url(self, *, key: str) -> str: """Get the public URL for the given key. - Args: - key: Key of the record for which URL is required. - - Returns: - The public URL for the given key. - - Raises: - ValueError: If the key does not exist. + The backend method for the `KeyValueStore.get_public_url` call. """ diff --git a/src/crawlee/storage_clients/_base/_key_value_store_collection_client.py b/src/crawlee/storage_clients/_base/_key_value_store_collection_client.py deleted file mode 100644 index b447cf49b1..0000000000 --- a/src/crawlee/storage_clients/_base/_key_value_store_collection_client.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -from crawlee._utils.docs import docs_group - -if TYPE_CHECKING: - from crawlee.storage_clients.models import KeyValueStoreListPage, KeyValueStoreMetadata - - -@docs_group('Abstract classes') -class KeyValueStoreCollectionClient(ABC): - """An abstract class for key-value store collection clients. - - This collection client handles operations that involve multiple instances of a given resource type. - """ - - @abstractmethod - async def get_or_create( - self, - *, - id: str | None = None, - name: str | None = None, - schema: dict | None = None, - ) -> KeyValueStoreMetadata: - """Retrieve an existing key-value store by its name or ID, or create a new one if it does not exist. - - Args: - id: Optional ID of the key-value store to retrieve or create. If provided, the method will attempt - to find a key-value store with the ID. - name: Optional name of the key-value store resource to retrieve or create. 
If provided, the method will - attempt to find a key-value store with this name. - schema: Optional schema for the key-value store resource to be created. - - Returns: - Metadata object containing the information of the retrieved or created key-value store. - """ - - @abstractmethod - async def list( - self, - *, - unnamed: bool = False, - limit: int | None = None, - offset: int | None = None, - desc: bool = False, - ) -> KeyValueStoreListPage: - """List the available key-value stores. - - Args: - unnamed: Whether to list only the unnamed key-value stores. - limit: Maximum number of key-value stores to return. - offset: Number of key-value stores to skip from the beginning of the list. - desc: Whether to sort the key-value stores in descending order. - - Returns: - The list of available key-value stores matching the specified filters. - """ diff --git a/src/crawlee/storage_clients/_base/_request_queue_client.py b/src/crawlee/storage_clients/_base/_request_queue_client.py index 06b180801a..7f2cdc11f1 100644 --- a/src/crawlee/storage_clients/_base/_request_queue_client.py +++ b/src/crawlee/storage_clients/_base/_request_queue_client.py @@ -8,13 +8,11 @@ if TYPE_CHECKING: from collections.abc import Sequence + from crawlee.configuration import Configuration from crawlee.storage_clients.models import ( - BatchRequestsOperationResponse, + AddRequestsResponse, ProcessedRequest, - ProlongRequestLockResponse, Request, - RequestQueueHead, - RequestQueueHeadWithLocks, RequestQueueMetadata, ) @@ -27,91 +25,63 @@ class RequestQueueClient(ABC): client, like a memory storage client. """ + @property @abstractmethod - async def get(self) -> RequestQueueMetadata | None: - """Get metadata about the request queue being managed by this client. - - Returns: - An object containing the request queue's details, or None if the request queue does not exist. - """ + def metadata(self) -> RequestQueueMetadata: + """The metadata of the request queue.""" + @classmethod @abstractmethod - async def update( - self, + async def open( + cls, *, - name: str | None = None, - ) -> RequestQueueMetadata: - """Update the request queue metadata. - - Args: - name: New new name for the request queue. - - Returns: - An object reflecting the updated request queue metadata. - """ - - @abstractmethod - async def delete(self) -> None: - """Permanently delete the request queue managed by this client.""" - - @abstractmethod - async def list_head(self, *, limit: int | None = None) -> RequestQueueHead: - """Retrieve a given number of requests from the beginning of the queue. + id: str | None, + name: str | None, + configuration: Configuration, + ) -> RequestQueueClient: + """Open a request queue client. Args: - limit: How many requests to retrieve. + id: ID of the queue to open. If not provided, a new queue will be created with a random ID. + name: Name of the queue to open. If not provided, the queue will be unnamed. + configuration: The configuration object. Returns: - The desired number of requests from the beginning of the queue. + A request queue client. """ @abstractmethod - async def list_and_lock_head(self, *, lock_secs: int, limit: int | None = None) -> RequestQueueHeadWithLocks: - """Fetch and lock a specified number of requests from the start of the queue. - - Retrieve and locks the first few requests of a queue for the specified duration. This prevents the requests - from being fetched by another client until the lock expires. - - Args: - lock_secs: Duration for which the requests are locked, in seconds. 
- limit: Maximum number of requests to retrieve and lock. - - Returns: - The desired number of locked requests from the beginning of the queue. - """ - - @abstractmethod - async def add_request( - self, - request: Request, - *, - forefront: bool = False, - ) -> ProcessedRequest: - """Add a request to the queue. - - Args: - request: The request to add to the queue. - forefront: Whether to add the request to the head or the end of the queue. + async def drop(self) -> None: + """Drop the whole request queue and remove all its values. - Returns: - Request queue operation information. + The backend method for the `RequestQueue.drop` call. """ @abstractmethod - async def batch_add_requests( + async def add_batch_of_requests( self, requests: Sequence[Request], *, forefront: bool = False, - ) -> BatchRequestsOperationResponse: - """Add a batch of requests to the queue. + ) -> AddRequestsResponse: + """Add batch of requests to the queue. + + This method adds a batch of requests to the queue. Each request is processed based on its uniqueness + (determined by `unique_key`). Duplicates will be identified but not re-added to the queue. Args: - requests: The requests to add to the queue. - forefront: Whether to add the requests to the head or the end of the queue. + requests: The collection of requests to add to the queue. + forefront: Whether to put the added requests at the beginning (True) or the end (False) of the queue. + When True, the requests will be processed sooner than previously added requests. + batch_size: The maximum number of requests to add in a single batch. + wait_time_between_batches: The time to wait between adding batches of requests. + wait_for_all_requests_to_be_added: If True, the method will wait until all requests are added + to the queue before returning. + wait_for_all_requests_to_be_added_timeout: The maximum time to wait for all requests to be added. Returns: - Request queue batch operation information. + A response object containing information about which requests were successfully + processed and which failed (if any). """ @abstractmethod @@ -126,64 +96,58 @@ async def get_request(self, request_id: str) -> Request | None: """ @abstractmethod - async def update_request( - self, - request: Request, - *, - forefront: bool = False, - ) -> ProcessedRequest: - """Update a request in the queue. + async def fetch_next_request(self) -> Request | None: + """Return the next request in the queue to be processed. - Args: - request: The updated request. - forefront: Whether to put the updated request in the beginning or the end of the queue. + Once you successfully finish processing of the request, you need to call `RequestQueue.mark_request_as_handled` + to mark the request as handled in the queue. If there was some error in processing the request, call + `RequestQueue.reclaim_request` instead, so that the queue will give the request to some other consumer + in another call to the `fetch_next_request` method. + + Note that the `None` return value does not mean the queue processing finished, it means there are currently + no pending requests. To check whether all requests in queue were finished, use `RequestQueue.is_finished` + instead. Returns: - The updated request + The request or `None` if there are no more pending requests. """ @abstractmethod - async def delete_request(self, request_id: str) -> None: - """Delete a request from the queue. - - Args: - request_id: ID of the request to delete. 
- """ + async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: + """Mark a request as handled after successful processing. - @abstractmethod - async def prolong_request_lock( - self, - request_id: str, - *, - forefront: bool = False, - lock_secs: int, - ) -> ProlongRequestLockResponse: - """Prolong the lock on a specific request in the queue. + Handled requests will never again be returned by the `RequestQueue.fetch_next_request` method. Args: - request_id: The identifier of the request whose lock is to be prolonged. - forefront: Whether to put the request in the beginning or the end of the queue after lock expires. - lock_secs: The additional amount of time, in seconds, that the request will remain locked. + request: The request to mark as handled. + + Returns: + Information about the queue operation. `None` if the given request was not in progress. """ @abstractmethod - async def delete_request_lock( + async def reclaim_request( self, - request_id: str, + request: Request, *, forefront: bool = False, - ) -> None: - """Delete the lock on a specific request in the queue. + ) -> ProcessedRequest | None: + """Reclaim a failed request back to the queue. + + The request will be returned for processing later again by another call to `RequestQueue.fetch_next_request`. Args: - request_id: ID of the request to delete the lock. - forefront: Whether to put the request in the beginning or the end of the queue after the lock is deleted. + request: The request to return to the queue. + forefront: Whether to add the request to the head or the end of the queue. + + Returns: + Information about the queue operation. `None` if the given request was not in progress. """ @abstractmethod - async def batch_delete_requests(self, requests: list[Request]) -> BatchRequestsOperationResponse: - """Delete given requests from the queue. + async def is_empty(self) -> bool: + """Check if the request queue is empty. - Args: - requests: The requests to delete from the queue. + Returns: + True if the request queue is empty, False otherwise. """ diff --git a/src/crawlee/storage_clients/_base/_request_queue_collection_client.py b/src/crawlee/storage_clients/_base/_request_queue_collection_client.py deleted file mode 100644 index 7de876c344..0000000000 --- a/src/crawlee/storage_clients/_base/_request_queue_collection_client.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -from crawlee._utils.docs import docs_group - -if TYPE_CHECKING: - from crawlee.storage_clients.models import RequestQueueListPage, RequestQueueMetadata - - -@docs_group('Abstract classes') -class RequestQueueCollectionClient(ABC): - """An abstract class for request queue collection clients. - - This collection client handles operations that involve multiple instances of a given resource type. - """ - - @abstractmethod - async def get_or_create( - self, - *, - id: str | None = None, - name: str | None = None, - schema: dict | None = None, - ) -> RequestQueueMetadata: - """Retrieve an existing request queue by its name or ID, or create a new one if it does not exist. - - Args: - id: Optional ID of the request queue to retrieve or create. If provided, the method will attempt - to find a request queue with the ID. - name: Optional name of the request queue resource to retrieve or create. If provided, the method will - attempt to find a request queue with this name. - schema: Optional schema for the request queue resource to be created. 
- - Returns: - Metadata object containing the information of the retrieved or created request queue. - """ - - @abstractmethod - async def list( - self, - *, - unnamed: bool = False, - limit: int | None = None, - offset: int | None = None, - desc: bool = False, - ) -> RequestQueueListPage: - """List the available request queues. - - Args: - unnamed: Whether to list only the unnamed request queues. - limit: Maximum number of request queues to return. - offset: Number of request queues to skip from the beginning of the list. - desc: Whether to sort the request queues in descending order. - - Returns: - The list of available request queues matching the specified filters. - """ diff --git a/src/crawlee/storage_clients/_base/_storage_client.py b/src/crawlee/storage_clients/_base/_storage_client.py index 4f022cf30a..36f9cb7567 100644 --- a/src/crawlee/storage_clients/_base/_storage_client.py +++ b/src/crawlee/storage_clients/_base/_storage_client.py @@ -1,61 +1,48 @@ -# Inspiration: https://github.com/apify/crawlee/blob/v3.8.2/packages/types/src/storages.ts#L314:L328 - from __future__ import annotations from abc import ABC, abstractmethod from typing import TYPE_CHECKING -from crawlee._utils.docs import docs_group - if TYPE_CHECKING: + from crawlee.configuration import Configuration + from ._dataset_client import DatasetClient - from ._dataset_collection_client import DatasetCollectionClient from ._key_value_store_client import KeyValueStoreClient - from ._key_value_store_collection_client import KeyValueStoreCollectionClient from ._request_queue_client import RequestQueueClient - from ._request_queue_collection_client import RequestQueueCollectionClient -@docs_group('Abstract classes') class StorageClient(ABC): - """Defines an abstract base for storage clients. - - It offers interfaces to get subclients for interacting with storage resources like datasets, key-value stores, - and request queues. - """ + """Base class for storage clients.""" @abstractmethod - def dataset(self, id: str) -> DatasetClient: - """Get a subclient for a specific dataset by its ID.""" + async def open_dataset_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> DatasetClient: + """Open a dataset client.""" @abstractmethod - def datasets(self) -> DatasetCollectionClient: - """Get a subclient for dataset collection operations.""" + async def open_key_value_store_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> KeyValueStoreClient: + """Open a key-value store client.""" @abstractmethod - def key_value_store(self, id: str) -> KeyValueStoreClient: - """Get a subclient for a specific key-value store by its ID.""" - - @abstractmethod - def key_value_stores(self) -> KeyValueStoreCollectionClient: - """Get a subclient for key-value store collection operations.""" - - @abstractmethod - def request_queue(self, id: str) -> RequestQueueClient: - """Get a subclient for a specific request queue by its ID.""" - - @abstractmethod - def request_queues(self) -> RequestQueueCollectionClient: - """Get a subclient for request queue collection operations.""" - - @abstractmethod - async def purge_on_start(self) -> None: - """Perform a purge of the default storages. - - This method ensures that the purge is executed only once during the lifetime of the instance. - It is primarily used to clean up residual data from previous runs to maintain a clean state. 
- If the storage client does not support purging, leave it empty. - """ + async def open_request_queue_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> RequestQueueClient: + """Open a request queue client.""" def get_rate_limit_errors(self) -> dict[int, int]: """Return statistics about rate limit errors encountered by the HTTP client in storage client.""" diff --git a/src/crawlee/storage_clients/_base/_types.py b/src/crawlee/storage_clients/_base/_types.py deleted file mode 100644 index a5cf1325f5..0000000000 --- a/src/crawlee/storage_clients/_base/_types.py +++ /dev/null @@ -1,22 +0,0 @@ -from __future__ import annotations - -from typing import Union - -from ._dataset_client import DatasetClient -from ._dataset_collection_client import DatasetCollectionClient -from ._key_value_store_client import KeyValueStoreClient -from ._key_value_store_collection_client import KeyValueStoreCollectionClient -from ._request_queue_client import RequestQueueClient -from ._request_queue_collection_client import RequestQueueCollectionClient - -ResourceClient = Union[ - DatasetClient, - KeyValueStoreClient, - RequestQueueClient, -] - -ResourceCollectionClient = Union[ - DatasetCollectionClient, - KeyValueStoreCollectionClient, - RequestQueueCollectionClient, -] diff --git a/src/crawlee/storage_clients/_file_system/__init__.py b/src/crawlee/storage_clients/_file_system/__init__.py new file mode 100644 index 0000000000..2169896d86 --- /dev/null +++ b/src/crawlee/storage_clients/_file_system/__init__.py @@ -0,0 +1,11 @@ +from ._dataset_client import FileSystemDatasetClient +from ._key_value_store_client import FileSystemKeyValueStoreClient +from ._request_queue_client import FileSystemRequestQueueClient +from ._storage_client import FileSystemStorageClient + +__all__ = [ + 'FileSystemDatasetClient', + 'FileSystemKeyValueStoreClient', + 'FileSystemRequestQueueClient', + 'FileSystemStorageClient', +] diff --git a/src/crawlee/storage_clients/_file_system/_dataset_client.py b/src/crawlee/storage_clients/_file_system/_dataset_client.py new file mode 100644 index 0000000000..fa1266524a --- /dev/null +++ b/src/crawlee/storage_clients/_file_system/_dataset_client.py @@ -0,0 +1,412 @@ +from __future__ import annotations + +import asyncio +import json +import shutil +from datetime import datetime, timezone +from logging import getLogger +from pathlib import Path +from typing import TYPE_CHECKING + +from pydantic import ValidationError +from typing_extensions import override + +from crawlee._utils.crypto import crypto_random_object_id +from crawlee.storage_clients._base import DatasetClient +from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata + +from ._utils import METADATA_FILENAME, json_dumps + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + from typing import Any + + from crawlee.configuration import Configuration + +logger = getLogger(__name__) + + +class FileSystemDatasetClient(DatasetClient): + """File system implementation of the dataset client. + + This client persists dataset items to the file system as individual JSON files within a structured + directory hierarchy following the pattern: + + ``` + {STORAGE_DIR}/datasets/{DATASET_ID}/{ITEM_ID}.json + ``` + + Each item is stored as a separate file, which allows for durability and the ability to + recover after process termination. 
Dataset operations like filtering, sorting, and pagination are + implemented by processing the stored files according to the requested parameters. + + This implementation is ideal for long-running crawlers where data persistence is important, + and for development environments where you want to easily inspect the collected data between runs. + """ + + _STORAGE_SUBDIR = 'datasets' + """The name of the subdirectory where datasets are stored.""" + + _ITEM_FILENAME_DIGITS = 9 + """Number of digits used for the dataset item file names (e.g., 000000019.json).""" + + def __init__( + self, + *, + id: str, + name: str, + created_at: datetime, + accessed_at: datetime, + modified_at: datetime, + item_count: int, + storage_dir: Path, + ) -> None: + """Initialize a new instance. + + Preferably use the `FileSystemDatasetClient.open` class method to create a new instance. + """ + self._metadata = DatasetMetadata( + id=id, + name=name, + created_at=created_at, + accessed_at=accessed_at, + modified_at=modified_at, + item_count=item_count, + ) + + self._storage_dir = storage_dir + + # Internal attributes + self._lock = asyncio.Lock() + """A lock to ensure that only one operation is performed at a time.""" + + @override + @property + def metadata(self) -> DatasetMetadata: + return self._metadata + + @property + def path_to_dataset(self) -> Path: + """The full path to the dataset directory.""" + return self._storage_dir / self._STORAGE_SUBDIR / self.metadata.name + + @property + def path_to_metadata(self) -> Path: + """The full path to the dataset metadata file.""" + return self.path_to_dataset / METADATA_FILENAME + + @override + @classmethod + async def open( + cls, + *, + id: str | None, + name: str | None, + configuration: Configuration, + ) -> FileSystemDatasetClient: + if id: + raise ValueError( + 'Opening a dataset by "id" is not supported for file system storage client, use "name" instead.' + ) + + name = name or configuration.default_dataset_id + + storage_dir = Path(configuration.storage_dir) + dataset_path = storage_dir / cls._STORAGE_SUBDIR / name + metadata_path = dataset_path / METADATA_FILENAME + + # If the dataset directory exists, reconstruct the client from the metadata file. + if dataset_path.exists(): + # If metadata file is missing, raise an error. + if not metadata_path.exists(): + raise ValueError(f'Metadata file not found for dataset "{name}"') + + file = await asyncio.to_thread(open, metadata_path) + try: + file_content = json.load(file) + finally: + await asyncio.to_thread(file.close) + try: + metadata = DatasetMetadata(**file_content) + except ValidationError as exc: + raise ValueError(f'Invalid metadata file for dataset "{name}"') from exc + + client = cls( + id=metadata.id, + name=name, + created_at=metadata.created_at, + accessed_at=metadata.accessed_at, + modified_at=metadata.modified_at, + item_count=metadata.item_count, + storage_dir=storage_dir, + ) + + await client._update_metadata(update_accessed_at=True) + + # Otherwise, create a new dataset client. + else: + now = datetime.now(timezone.utc) + client = cls( + id=crypto_random_object_id(), + name=name, + created_at=now, + accessed_at=now, + modified_at=now, + item_count=0, + storage_dir=storage_dir, + ) + await client._update_metadata() + + return client + + @override + async def drop(self) -> None: + # If the client directory exists, remove it recursively. 
+ if self.path_to_dataset.exists(): + async with self._lock: + await asyncio.to_thread(shutil.rmtree, self.path_to_dataset) + + @override + async def push_data(self, data: list[Any] | dict[str, Any]) -> None: + new_item_count = self.metadata.item_count + + # If data is a list, push each item individually. + if isinstance(data, list): + for item in data: + new_item_count += 1 + await self._push_item(item, new_item_count) + else: + new_item_count += 1 + await self._push_item(data, new_item_count) + + await self._update_metadata( + update_accessed_at=True, + update_modified_at=True, + new_item_count=new_item_count, + ) + + @override + async def get_data( + self, + *, + offset: int = 0, + limit: int | None = 999_999_999_999, + clean: bool = False, + desc: bool = False, + fields: list[str] | None = None, + omit: list[str] | None = None, + unwind: str | None = None, + skip_empty: bool = False, + skip_hidden: bool = False, + flatten: list[str] | None = None, + view: str | None = None, + ) -> DatasetItemsListPage: + # Check for unsupported arguments and log a warning if found. + unsupported_args = { + 'clean': clean, + 'fields': fields, + 'omit': omit, + 'unwind': unwind, + 'skip_hidden': skip_hidden, + 'flatten': flatten, + 'view': view, + } + unsupported = {k: v for k, v in unsupported_args.items() if v not in (False, None)} + + if unsupported: + logger.warning( + f'The arguments {list(unsupported.keys())} of get_data are not supported by the ' + f'{self.__class__.__name__} client.' + ) + + # If the dataset directory does not exist, log a warning and return an empty page. + if not self.path_to_dataset.exists(): + logger.warning(f'Dataset directory not found: {self.path_to_dataset}') + return DatasetItemsListPage( + count=0, + offset=offset, + limit=limit or 0, + total=0, + desc=desc, + items=[], + ) + + # Get the list of sorted data files. + data_files = await self._get_sorted_data_files() + total = len(data_files) + + # Reverse the order if descending order is requested. + if desc: + data_files.reverse() + + # Apply offset and limit slicing. + selected_files = data_files[offset:] + if limit is not None: + selected_files = selected_files[:limit] + + # Read and parse each data file. + items = [] + for file_path in selected_files: + try: + file_content = await asyncio.to_thread(file_path.read_text, encoding='utf-8') + item = json.loads(file_content) + except Exception: + logger.exception(f'Error reading {file_path}, skipping the item.') + continue + + # Skip empty items if requested. + if skip_empty and not item: + continue + + items.append(item) + + await self._update_metadata(update_accessed_at=True) + + # Return a paginated list page of dataset items. + return DatasetItemsListPage( + count=len(items), + offset=offset, + limit=limit or total - offset, + total=total, + desc=desc, + items=items, + ) + + @override + async def iterate_items( + self, + *, + offset: int = 0, + limit: int | None = None, + clean: bool = False, + desc: bool = False, + fields: list[str] | None = None, + omit: list[str] | None = None, + unwind: str | None = None, + skip_empty: bool = False, + skip_hidden: bool = False, + ) -> AsyncIterator[dict]: + # Check for unsupported arguments and log a warning if found. 
+ unsupported_args = { + 'clean': clean, + 'fields': fields, + 'omit': omit, + 'unwind': unwind, + 'skip_hidden': skip_hidden, + } + unsupported = {k: v for k, v in unsupported_args.items() if v not in (False, None)} + + if unsupported: + logger.warning( + f'The arguments {list(unsupported.keys())} of iterate are not supported ' + f'by the {self.__class__.__name__} client.' + ) + + # If the dataset directory does not exist, log a warning and return immediately. + if not self.path_to_dataset.exists(): + logger.warning(f'Dataset directory not found: {self.path_to_dataset}') + return + + # Get the list of sorted data files. + data_files = await self._get_sorted_data_files() + + # Reverse the order if descending order is requested. + if desc: + data_files.reverse() + + # Apply offset and limit slicing. + selected_files = data_files[offset:] + if limit is not None: + selected_files = selected_files[:limit] + + # Iterate over each data file, reading and yielding its parsed content. + for file_path in selected_files: + try: + file_content = await asyncio.to_thread(file_path.read_text, encoding='utf-8') + item = json.loads(file_content) + except Exception: + logger.exception(f'Error reading {file_path}, skipping the item.') + continue + + # Skip empty items if requested. + if skip_empty and not item: + continue + + yield item + + await self._update_metadata(update_accessed_at=True) + + async def _update_metadata( + self, + *, + new_item_count: int | None = None, + update_accessed_at: bool = False, + update_modified_at: bool = False, + ) -> None: + """Update the dataset metadata file with current information. + + Args: + new_item_count: If provided, update the item count to this value. + update_accessed_at: If True, update the `accessed_at` timestamp to the current time. + update_modified_at: If True, update the `modified_at` timestamp to the current time. + """ + now = datetime.now(timezone.utc) + + if update_accessed_at: + self._metadata.accessed_at = now + if update_modified_at: + self._metadata.modified_at = now + if new_item_count is not None: + self._metadata.item_count = new_item_count + + # Ensure the parent directory for the metadata file exists. + await asyncio.to_thread(self.path_to_metadata.parent.mkdir, parents=True, exist_ok=True) + + # Dump the serialized metadata to the file. + data = await json_dumps(self._metadata.model_dump()) + await asyncio.to_thread(self.path_to_metadata.write_text, data, encoding='utf-8') + + async def _push_item(self, item: dict[str, Any], item_id: int) -> None: + """Push a single item to the dataset. + + This method writes the item as a JSON file with a zero-padded numeric filename + that reflects its position in the dataset sequence. + + Args: + item: The data item to add to the dataset. + item_id: The sequential ID to use for this item's filename. + """ + # Acquire the lock to perform file operations safely. + async with self._lock: + # Generate the filename for the new item using zero-padded numbering. + filename = f'{str(item_id).zfill(self._ITEM_FILENAME_DIGITS)}.json' + file_path = self.path_to_dataset / filename + + # Ensure the dataset directory exists. + await asyncio.to_thread(self.path_to_dataset.mkdir, parents=True, exist_ok=True) + + # Dump the serialized item to the file. + data = await json_dumps(item) + await asyncio.to_thread(file_path.write_text, data, encoding='utf-8') + + async def _get_sorted_data_files(self) -> list[Path]: + """Retrieve and return a sorted list of data files in the dataset directory. 
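For orientation, a minimal usage sketch of this dataset client follows. It is illustrative only, not part of the change itself: the dataset name and URLs are placeholders, and the import goes through the private `_file_system` subpackage added by this diff, so the eventual public import path may differ.

```python
import asyncio

from crawlee.storage_clients._file_system import FileSystemStorageClient


async def main() -> None:
    # Open (or create) a named dataset under {storage_dir}/datasets/my-dataset.
    dataset_client = await FileSystemStorageClient().open_dataset_client(name='my-dataset')

    # Each pushed item is written as its own zero-padded JSON file (000000001.json, ...).
    await dataset_client.push_data([{'url': 'https://crawlee.dev/'}, {'url': 'https://apify.com/'}])

    # Read a page of items back.
    page = await dataset_client.get_data(offset=0, limit=10)
    print(page.total, page.count, page.items)

    # Or stream the items lazily, here in reverse insertion order.
    async for item in dataset_client.iterate_items(desc=True):
        print(item)


if __name__ == '__main__':
    asyncio.run(main())
```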
+ + The files are sorted numerically based on the filename (without extension), + which corresponds to the order items were added to the dataset. + + Returns: + A list of `Path` objects pointing to data files, sorted by numeric filename. + """ + # Retrieve and sort all JSON files in the dataset directory numerically. + files = await asyncio.to_thread( + sorted, + self.path_to_dataset.glob('*.json'), + key=lambda f: int(f.stem) if f.stem.isdigit() else 0, + ) + + # Remove the metadata file from the list if present. + if self.path_to_metadata in files: + files.remove(self.path_to_metadata) + + return files diff --git a/src/crawlee/storage_clients/_file_system/_key_value_store_client.py b/src/crawlee/storage_clients/_file_system/_key_value_store_client.py new file mode 100644 index 0000000000..4ed427f797 --- /dev/null +++ b/src/crawlee/storage_clients/_file_system/_key_value_store_client.py @@ -0,0 +1,379 @@ +from __future__ import annotations + +import asyncio +import json +import shutil +import urllib.parse +from datetime import datetime, timezone +from logging import getLogger +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from pydantic import ValidationError +from typing_extensions import override + +from crawlee._utils.crypto import crypto_random_object_id +from crawlee._utils.file import infer_mime_type +from crawlee.storage_clients._base import KeyValueStoreClient +from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecord, KeyValueStoreRecordMetadata + +from ._utils import METADATA_FILENAME, json_dumps + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + + from crawlee.configuration import Configuration + + +logger = getLogger(__name__) + + +class FileSystemKeyValueStoreClient(KeyValueStoreClient): + """File system implementation of the key-value store client. + + This client persists data to the file system, making it suitable for scenarios where data needs to + survive process restarts. Keys are mapped to file paths in a directory structure following the pattern: + + ``` + {STORAGE_DIR}/key_value_stores/{STORE_ID}/{KEY} + ``` + + Binary data is stored as-is, while JSON and text data are stored in human-readable format. + The implementation automatically handles serialization based on the content type and + maintains metadata about each record. + + This implementation is ideal for long-running crawlers where persistence is important and + for development environments where you want to easily inspect the stored data between runs. + """ + + _STORAGE_SUBDIR = 'key_value_stores' + """The name of the subdirectory where key-value stores are stored.""" + + def __init__( + self, + *, + id: str, + name: str, + created_at: datetime, + accessed_at: datetime, + modified_at: datetime, + storage_dir: Path, + ) -> None: + """Initialize a new instance. + + Preferably use the `FileSystemKeyValueStoreClient.open` class method to create a new instance. 
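As with the dataset client, here is a small, hypothetical sketch of driving this key-value store client through `FileSystemStorageClient`; the store name, key, and value are placeholders, and the private import path may change once the class is re-exported publicly.

```python
import asyncio

from crawlee.storage_clients._file_system import FileSystemStorageClient


async def main() -> None:
    # Open (or create) a named store under {storage_dir}/key_value_stores/my-store.
    kvs_client = await FileSystemStorageClient().open_key_value_store_client(name='my-store')

    # JSON-serializable values are written in human-readable form next to a metadata file.
    await kvs_client.set_value(key='state', value={'page': 1})

    record = await kvs_client.get_value(key='state')
    if record is not None:
        print(record.value, record.content_type, record.size)

    # Keys can be listed lazily, optionally resuming after a given key.
    async for key_metadata in kvs_client.iterate_keys():
        print(key_metadata.key)

    await kvs_client.delete_value(key='state')


if __name__ == '__main__':
    asyncio.run(main())
```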
+ """ + self._metadata = KeyValueStoreMetadata( + id=id, + name=name, + created_at=created_at, + accessed_at=accessed_at, + modified_at=modified_at, + ) + + self._storage_dir = storage_dir + + # Internal attributes + self._lock = asyncio.Lock() + """A lock to ensure that only one operation is performed at a time.""" + + @override + @property + def metadata(self) -> KeyValueStoreMetadata: + return self._metadata + + @property + def path_to_kvs(self) -> Path: + """The full path to the key-value store directory.""" + return self._storage_dir / self._STORAGE_SUBDIR / self.metadata.name + + @property + def path_to_metadata(self) -> Path: + """The full path to the key-value store metadata file.""" + return self.path_to_kvs / METADATA_FILENAME + + @override + @classmethod + async def open( + cls, + *, + id: str | None, + name: str | None, + configuration: Configuration, + ) -> FileSystemKeyValueStoreClient: + if id: + raise ValueError( + 'Opening a key-value store by "id" is not supported for file system storage client, use "name" instead.' + ) + + name = name or configuration.default_dataset_id + + storage_dir = Path(configuration.storage_dir) + kvs_path = storage_dir / cls._STORAGE_SUBDIR / name + metadata_path = kvs_path / METADATA_FILENAME + + # If the key-value store directory exists, reconstruct the client from the metadata file. + if kvs_path.exists(): + # If metadata file is missing, raise an error. + if not metadata_path.exists(): + raise ValueError(f'Metadata file not found for key-value store "{name}"') + + file = await asyncio.to_thread(open, metadata_path) + try: + file_content = json.load(file) + finally: + await asyncio.to_thread(file.close) + try: + metadata = KeyValueStoreMetadata(**file_content) + except ValidationError as exc: + raise ValueError(f'Invalid metadata file for key-value store "{name}"') from exc + + client = cls( + id=metadata.id, + name=name, + created_at=metadata.created_at, + accessed_at=metadata.accessed_at, + modified_at=metadata.modified_at, + storage_dir=storage_dir, + ) + + await client._update_metadata(update_accessed_at=True) + + # Otherwise, create a new key-value store client. + else: + now = datetime.now(timezone.utc) + client = cls( + id=crypto_random_object_id(), + name=name, + created_at=now, + accessed_at=now, + modified_at=now, + storage_dir=storage_dir, + ) + await client._update_metadata() + + return client + + @override + async def drop(self) -> None: + # If the client directory exists, remove it recursively. 
+ if self.path_to_kvs.exists(): + async with self._lock: + await asyncio.to_thread(shutil.rmtree, self.path_to_kvs) + + @override + async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: + # Update the metadata to record access + await self._update_metadata(update_accessed_at=True) + + record_path = self.path_to_kvs / self._encode_key(key) + + if not record_path.exists(): + return None + + # Found a file for this key, now look for its metadata + record_metadata_filepath = record_path.with_name(f'{record_path.name}.{METADATA_FILENAME}') + if not record_metadata_filepath.exists(): + logger.warning(f'Found value file for key "{key}" but no metadata file.') + return None + + # Read the metadata file + async with self._lock: + file = await asyncio.to_thread(open, record_metadata_filepath) + try: + metadata_content = json.load(file) + except json.JSONDecodeError: + logger.warning(f'Invalid metadata file for key "{key}"') + return None + finally: + await asyncio.to_thread(file.close) + + try: + metadata = KeyValueStoreRecordMetadata(**metadata_content) + except ValidationError: + logger.warning(f'Invalid metadata schema for key "{key}"') + return None + + # Read the actual value + value_bytes = await asyncio.to_thread(record_path.read_bytes) + + # Handle JSON values + if 'application/json' in metadata.content_type: + try: + value = json.loads(value_bytes.decode('utf-8')) + except (json.JSONDecodeError, UnicodeDecodeError): + logger.warning(f'Failed to decode JSON value for key "{key}"') + return None + + # Handle text values + elif metadata.content_type.startswith('text/'): + try: + value = value_bytes.decode('utf-8') + except UnicodeDecodeError: + logger.warning(f'Failed to decode text value for key "{key}"') + return None + + # Handle binary values + else: + value = value_bytes + + # Calculate the size of the value in bytes + size = len(value_bytes) + + return KeyValueStoreRecord( + key=metadata.key, + value=value, + content_type=metadata.content_type, + size=size, + ) + + @override + async def set_value(self, *, key: str, value: Any, content_type: str | None = None) -> None: + content_type = content_type or infer_mime_type(value) + + # Serialize the value to bytes. + if 'application/json' in content_type: + value_bytes = (await json_dumps(value)).encode('utf-8') + elif isinstance(value, str): + value_bytes = value.encode('utf-8') + elif isinstance(value, (bytes, bytearray)): + value_bytes = value + else: + # Fallback: attempt to convert to string and encode. + value_bytes = str(value).encode('utf-8') + + record_path = self.path_to_kvs / self._encode_key(key) + + # Prepare the metadata + size = len(value_bytes) + record_metadata = KeyValueStoreRecordMetadata(key=key, content_type=content_type, size=size) + record_metadata_filepath = record_path.with_name(f'{record_path.name}.{METADATA_FILENAME}') + record_metadata_content = await json_dumps(record_metadata.model_dump()) + + async with self._lock: + # Ensure the key-value store directory exists. + await asyncio.to_thread(self.path_to_kvs.mkdir, parents=True, exist_ok=True) + + # Write the value to the file. + await asyncio.to_thread(record_path.write_bytes, value_bytes) + + # Write the record metadata to the file. + await asyncio.to_thread( + record_metadata_filepath.write_text, + record_metadata_content, + encoding='utf-8', + ) + + # Update the KVS metadata to record the access and modification. 
+ await self._update_metadata(update_accessed_at=True, update_modified_at=True) + + @override + async def delete_value(self, *, key: str) -> None: + record_path = self.path_to_kvs / self._encode_key(key) + metadata_path = record_path.with_name(f'{record_path.name}.{METADATA_FILENAME}') + deleted = False + + async with self._lock: + # Delete the value file and its metadata if found + if record_path.exists(): + await asyncio.to_thread(record_path.unlink) + + # Delete the metadata file if it exists + if metadata_path.exists(): + await asyncio.to_thread(metadata_path.unlink) + else: + logger.warning(f'Found value file for key "{key}" but no metadata file when trying to delete it.') + + deleted = True + + # If we deleted something, update the KVS metadata + if deleted: + await self._update_metadata(update_accessed_at=True, update_modified_at=True) + + @override + async def iterate_keys( + self, + *, + exclusive_start_key: str | None = None, + limit: int | None = None, + ) -> AsyncIterator[KeyValueStoreRecordMetadata]: + # Check if the KVS directory exists + if not self.path_to_kvs.exists(): + return + + count = 0 + async with self._lock: + # Get all files in the KVS directory, sorted alphabetically + files = sorted(await asyncio.to_thread(list, self.path_to_kvs.glob('*'))) + + for file_path in files: + # Skip the main metadata file + if file_path.name == METADATA_FILENAME: + continue + + # Only process metadata files for records + if not file_path.name.endswith(f'.{METADATA_FILENAME}'): + continue + + # Extract the base key name from the metadata filename + key_name = self._decode_key(file_path.name[: -len(f'.{METADATA_FILENAME}')]) + + # Apply exclusive_start_key filter if provided + if exclusive_start_key is not None and key_name <= exclusive_start_key: + continue + + # Try to read and parse the metadata file + try: + metadata_content = await asyncio.to_thread(file_path.read_text, encoding='utf-8') + metadata_dict = json.loads(metadata_content) + record_metadata = KeyValueStoreRecordMetadata(**metadata_dict) + + yield record_metadata + + count += 1 + if limit and count >= limit: + break + + except (json.JSONDecodeError, ValidationError) as e: + logger.warning(f'Failed to parse metadata file {file_path}: {e}') + + # Update accessed_at timestamp + await self._update_metadata(update_accessed_at=True) + + @override + async def get_public_url(self, *, key: str) -> str: + raise NotImplementedError('Public URLs are not supported for file system key-value stores.') + + async def _update_metadata( + self, + *, + update_accessed_at: bool = False, + update_modified_at: bool = False, + ) -> None: + """Update the KVS metadata file with current information. + + Args: + update_accessed_at: If True, update the `accessed_at` timestamp to the current time. + update_modified_at: If True, update the `modified_at` timestamp to the current time. + """ + now = datetime.now(timezone.utc) + + if update_accessed_at: + self._metadata.accessed_at = now + if update_modified_at: + self._metadata.modified_at = now + + # Ensure the parent directory for the metadata file exists. + await asyncio.to_thread(self.path_to_metadata.parent.mkdir, parents=True, exist_ok=True) + + # Dump the serialized metadata to the file. 
+ data = await json_dumps(self._metadata.model_dump()) + await asyncio.to_thread(self.path_to_metadata.write_text, data, encoding='utf-8') + + def _encode_key(self, key: str) -> str: + """Encode a key to make it safe for use in a file path.""" + return urllib.parse.quote(key, safe='') + + def _decode_key(self, encoded_key: str) -> str: + """Decode a key that was encoded to make it safe for use in a file path.""" + return urllib.parse.unquote(encoded_key) diff --git a/src/crawlee/storage_clients/_file_system/_request_queue_client.py b/src/crawlee/storage_clients/_file_system/_request_queue_client.py new file mode 100644 index 0000000000..2d170df09b --- /dev/null +++ b/src/crawlee/storage_clients/_file_system/_request_queue_client.py @@ -0,0 +1,656 @@ +from __future__ import annotations + +import asyncio +import json +import shutil +from datetime import datetime, timezone +from logging import getLogger +from pathlib import Path +from typing import TYPE_CHECKING + +from pydantic import ValidationError +from typing_extensions import override + +from crawlee import Request +from crawlee._utils.crypto import crypto_random_object_id +from crawlee.storage_clients._base import RequestQueueClient +from crawlee.storage_clients.models import AddRequestsResponse, ProcessedRequest, RequestQueueMetadata + +from ._utils import METADATA_FILENAME, json_dumps + +if TYPE_CHECKING: + from collections.abc import Sequence + + from crawlee.configuration import Configuration + +logger = getLogger(__name__) + + +class FileSystemRequestQueueClient(RequestQueueClient): + """A file system implementation of the request queue client. + + This client persists requests to the file system as individual JSON files, making it suitable for scenarios + where data needs to survive process restarts. Each request is stored as a separate file in a directory + structure following the pattern: + + ``` + {STORAGE_DIR}/request_queues/{QUEUE_ID}/{REQUEST_ID}.json + ``` + + The implementation uses file timestamps for FIFO ordering of regular requests and maintains in-memory sets + for tracking in-progress and forefront requests. File system storage provides durability at the cost of + slower I/O operations compared to memory-based storage. + + This implementation is ideal for long-running crawlers where persistence is important and for situations + where you need to resume crawling after process termination. + """ + + _STORAGE_SUBDIR = 'request_queues' + """The name of the subdirectory where request queues are stored.""" + + def __init__( + self, + *, + id: str, + name: str, + created_at: datetime, + accessed_at: datetime, + modified_at: datetime, + had_multiple_clients: bool, + handled_request_count: int, + pending_request_count: int, + stats: dict, + total_request_count: int, + storage_dir: Path, + ) -> None: + """Initialize a new instance. + + Preferably use the `FileSystemRequestQueueClient.open` class method to create a new instance. 
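The docstrings of this client describe a fetch/handle/reclaim cycle; the sketch below shows one way that cycle could look in practice. It is an assumption-laden example: the queue name and URLs are placeholders, and the import again uses the private `_file_system` subpackage introduced by this diff.

```python
import asyncio

from crawlee import Request
from crawlee.storage_clients._file_system import FileSystemStorageClient


async def main() -> None:
    rq_client = await FileSystemStorageClient().open_request_queue_client(name='my-queue')

    # Enqueue a batch; duplicates are detected by `unique_key` and are not re-added.
    await rq_client.add_batch_of_requests(
        [Request.from_url('https://crawlee.dev/'), Request.from_url('https://apify.com/')]
    )

    # Drain the queue in FIFO order.
    while not await rq_client.is_empty():
        request = await rq_client.fetch_next_request()
        if request is None:
            break  # No pending requests right now.
        try:
            ...  # Process the request here.
        except Exception:
            # Return the request to the queue so it can be fetched again later.
            await rq_client.reclaim_request(request)
        else:
            await rq_client.mark_request_as_handled(request)


if __name__ == '__main__':
    asyncio.run(main())
```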
+ """ + self._metadata = RequestQueueMetadata( + id=id, + name=name, + created_at=created_at, + accessed_at=accessed_at, + modified_at=modified_at, + had_multiple_clients=had_multiple_clients, + handled_request_count=handled_request_count, + pending_request_count=pending_request_count, + stats=stats, + total_request_count=total_request_count, + ) + + self._storage_dir = storage_dir + + # Internal attributes + self._lock = asyncio.Lock() + """A lock to ensure that only one operation is performed at a time.""" + + self._in_progress = set[str]() + """A set of request IDs that are currently being processed.""" + + self._forefront_requests = set[str]() + """A set of request IDs that should be prioritized (added with forefront=True).""" + + @override + @property + def metadata(self) -> RequestQueueMetadata: + return self._metadata + + @property + def path_to_rq(self) -> Path: + """The full path to the request queue directory.""" + return self._storage_dir / self._STORAGE_SUBDIR / self.metadata.name + + @property + def path_to_metadata(self) -> Path: + """The full path to the request queue metadata file.""" + return self.path_to_rq / METADATA_FILENAME + + @override + @classmethod + async def open( + cls, + *, + id: str | None, + name: str | None, + configuration: Configuration, + ) -> FileSystemRequestQueueClient: + if id: + raise ValueError( + 'Opening a request queue by "id" is not supported for file system storage client, use "name" instead.' + ) + + name = name or configuration.default_request_queue_id + + storage_dir = Path(configuration.storage_dir) + rq_path = storage_dir / cls._STORAGE_SUBDIR / name + metadata_path = rq_path / METADATA_FILENAME + + # If the RQ directory exists, reconstruct the client from the metadata file. + if rq_path.exists() and not configuration.purge_on_start: + # If metadata file is missing, raise an error. 
+ if not metadata_path.exists(): + raise ValueError(f'Metadata file not found for request queue "{name}"') + + file = await asyncio.to_thread(open, metadata_path) + try: + file_content = json.load(file) + finally: + await asyncio.to_thread(file.close) + try: + metadata = RequestQueueMetadata(**file_content) + except ValidationError as exc: + raise ValueError(f'Invalid metadata file for request queue "{name}"') from exc + + client = cls( + id=metadata.id, + name=name, + created_at=metadata.created_at, + accessed_at=metadata.accessed_at, + modified_at=metadata.modified_at, + had_multiple_clients=metadata.had_multiple_clients, + handled_request_count=metadata.handled_request_count, + pending_request_count=metadata.pending_request_count, + stats=metadata.stats, + total_request_count=metadata.total_request_count, + storage_dir=storage_dir, + ) + + # Recalculate request counts from actual files to ensure consistency + handled_count = 0 + pending_count = 0 + request_files = await asyncio.to_thread(list, rq_path.glob('*.json')) + for request_file in request_files: + if request_file.name == METADATA_FILENAME: + continue + + try: + file = await asyncio.to_thread(open, request_file) + try: + data = json.load(file) + if data.get('handled_at') is not None: + handled_count += 1 + else: + pending_count += 1 + finally: + await asyncio.to_thread(file.close) + except (json.JSONDecodeError, ValidationError): + logger.warning(f'Failed to parse request file: {request_file}') + + await client._update_metadata( + update_accessed_at=True, + new_handled_request_count=handled_count, + new_pending_request_count=pending_count, + new_total_request_count=handled_count + pending_count, + ) + + # Otherwise, create a new dataset client. + else: + # If purge_on_start is true and the directory exists, remove it + if configuration.purge_on_start and rq_path.exists(): + await asyncio.to_thread(shutil.rmtree, rq_path) + + now = datetime.now(timezone.utc) + client = cls( + id=crypto_random_object_id(), + name=name, + created_at=now, + accessed_at=now, + modified_at=now, + had_multiple_clients=False, + handled_request_count=0, + pending_request_count=0, + stats={}, + total_request_count=0, + storage_dir=storage_dir, + ) + await client._update_metadata() + + return client + + @override + async def drop(self) -> None: + # If the client directory exists, remove it recursively. + if self.path_to_rq.exists(): + async with self._lock: + await asyncio.to_thread(shutil.rmtree, self.path_to_rq) + + @override + async def add_batch_of_requests( + self, + requests: Sequence[Request], + *, + forefront: bool = False, + ) -> AddRequestsResponse: + """Add a batch of requests to the queue. + + Args: + requests: The requests to add. + forefront: Whether to add the requests to the beginning of the queue. + + Returns: + Response containing information about the added requests. 
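To make the deduplication and `forefront` semantics concrete, here is a small hypothetical example: the same URL is submitted twice, so the second occurrence is reported via the `was_already_present` flag rather than being enqueued again (the queue name and URLs are placeholders).

```python
import asyncio

from crawlee import Request
from crawlee.storage_clients._file_system import FileSystemStorageClient


async def main() -> None:
    rq_client = await FileSystemStorageClient().open_request_queue_client(name='dedup-demo')

    response = await rq_client.add_batch_of_requests(
        [
            Request.from_url('https://crawlee.dev/'),
            Request.from_url('https://crawlee.dev/'),  # same unique_key as above
            Request.from_url('https://crawlee.dev/python/'),
        ],
        forefront=True,  # process these before previously enqueued requests
    )

    for processed in response.processed_requests:
        print(processed.unique_key, processed.was_already_present, processed.was_already_handled)


if __name__ == '__main__':
    asyncio.run(main())
```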
+ """ + async with self._lock: + new_total_request_count = self._metadata.total_request_count + new_pending_request_count = self._metadata.pending_request_count + + processed_requests = [] + + # Create the requests directory if it doesn't exist + await asyncio.to_thread(self.path_to_rq.mkdir, parents=True, exist_ok=True) + + for request in requests: + # Ensure the request has an ID + if not request.id: + request.id = crypto_random_object_id() + + # Check if the request is already in the queue by unique_key + existing_request = None + + # List all request files and check for matching unique_key + request_files = await asyncio.to_thread(list, self.path_to_rq.glob('*.json')) + for request_file in request_files: + # Skip metadata file + if request_file.name == METADATA_FILENAME: + continue + + file = await asyncio.to_thread(open, request_file) + try: + file_content = json.load(file) + if file_content.get('unique_key') == request.unique_key: + existing_request = Request(**file_content) + break + except (json.JSONDecodeError, ValidationError): + logger.warning(f'Failed to parse request file: {request_file}') + finally: + await asyncio.to_thread(file.close) + + was_already_present = existing_request is not None + was_already_handled = ( + was_already_present and existing_request and existing_request.handled_at is not None + ) + + # If the request is already in the queue and handled, don't add it again + if was_already_handled and existing_request: + processed_requests.append( + ProcessedRequest( + id=existing_request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=True, + ) + ) + continue + + # If forefront and existing request is not handled, mark it as forefront + if forefront and was_already_present and not was_already_handled and existing_request: + self._forefront_requests.add(existing_request.id) + processed_requests.append( + ProcessedRequest( + id=existing_request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=False, + ) + ) + continue + + # If the request is already in the queue but not handled, update it + if was_already_present and existing_request: + # Update the existing request file + request_path = self.path_to_rq / f'{existing_request.id}.json' + request_data = await json_dumps(existing_request.model_dump()) + await asyncio.to_thread(request_path.write_text, request_data, encoding='utf-8') + + processed_requests.append( + ProcessedRequest( + id=existing_request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=False, + ) + ) + continue + + # Add the new request to the queue + request_path = self.path_to_rq / f'{request.id}.json' + + # Create a data dictionary from the request and remove handled_at if it's None + request_dict = request.model_dump() + if request_dict.get('handled_at') is None: + request_dict.pop('handled_at', None) + + request_data = await json_dumps(request_dict) + await asyncio.to_thread(request_path.write_text, request_data, encoding='utf-8') + + # Update metadata counts + new_total_request_count += 1 + new_pending_request_count += 1 + + # If forefront, add to the forefront set + if forefront: + self._forefront_requests.add(request.id) + + processed_requests.append( + ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=False, + was_already_handled=False, + ) + ) + + await self._update_metadata( + update_modified_at=True, + update_accessed_at=True, + new_total_request_count=new_total_request_count, + 
new_pending_request_count=new_pending_request_count, + ) + + return AddRequestsResponse( + processed_requests=processed_requests, + unprocessed_requests=[], + ) + + @override + async def get_request(self, request_id: str) -> Request | None: + """Retrieve a request from the queue. + + Args: + request_id: ID of the request to retrieve. + + Returns: + The retrieved request, or None, if it did not exist. + """ + request_path = self.path_to_rq / f'{request_id}.json' + + if await asyncio.to_thread(request_path.exists): + file = await asyncio.to_thread(open, request_path) + try: + file_content = json.load(file) + return Request(**file_content) + except (json.JSONDecodeError, ValidationError) as exc: + logger.warning(f'Failed to parse request file {request_path}: {exc!s}') + finally: + await asyncio.to_thread(file.close) + + return None + + @override + async def fetch_next_request(self) -> Request | None: + """Return the next request in the queue to be processed. + + Once you successfully finish processing of the request, you need to call `RequestQueue.mark_request_as_handled` + to mark the request as handled in the queue. If there was some error in processing the request, call + `RequestQueue.reclaim_request` instead, so that the queue will give the request to some other consumer + in another call to the `fetch_next_request` method. + + Returns: + The request or `None` if there are no more pending requests. + """ + async with self._lock: + # Create the requests directory if it doesn't exist + await asyncio.to_thread(self.path_to_rq.mkdir, parents=True, exist_ok=True) + + # List all request files + request_files = await asyncio.to_thread(list, self.path_to_rq.glob('*.json')) + + # First check for forefront requests + forefront_requests = [] + regular_requests = [] + + # Get file creation times for sorting regular requests in FIFO order + request_file_times = {} + + # Separate requests into forefront and regular + for request_file in request_files: + # Skip metadata file + if request_file.name == METADATA_FILENAME: + continue + + # Extract request ID from filename + request_id = request_file.stem + + # Skip if already in progress + if request_id in self._in_progress: + continue + + # Get file creation/modification time for FIFO ordering + try: + file_stat = await asyncio.to_thread(request_file.stat) + request_file_times[request_file] = file_stat.st_mtime + except Exception: + # If we can't get the time, use 0 (oldest) + request_file_times[request_file] = 0 + + if request_id in self._forefront_requests: + forefront_requests.append(request_file) + else: + regular_requests.append(request_file) + + # Sort regular requests by creation time (FIFO order) + regular_requests.sort(key=lambda f: request_file_times[f]) + + # Prioritize forefront requests + prioritized_files = forefront_requests + regular_requests + + # Process files in prioritized order + for request_file in prioritized_files: + request_id = request_file.stem + + file = await asyncio.to_thread(open, request_file) + try: + file_content = json.load(file) + # Skip if already handled + if file_content.get('handled_at') is not None: + continue + + # Create request object + request = Request(**file_content) + + # Mark as in-progress in memory + self._in_progress.add(request.id) + + # Remove from forefront set if it was there + self._forefront_requests.discard(request.id) + + # Update accessed timestamp + await self._update_metadata(update_accessed_at=True) + + except (json.JSONDecodeError, ValidationError) as exc: + logger.warning(f'Failed to parse 
request file {request_file}: {exc!s}') + else: + return request + finally: + await asyncio.to_thread(file.close) + + return None + + @override + async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: + """Mark a request as handled after successful processing. + + Handled requests will never again be returned by the `fetch_next_request` method. + + Args: + request: The request to mark as handled. + + Returns: + Information about the queue operation. `None` if the given request was not in progress. + """ + async with self._lock: + # Check if the request is in progress + if request.id not in self._in_progress: + return None + + # Remove from in-progress set + self._in_progress.discard(request.id) + + # Update the request object - set handled_at timestamp + if request.handled_at is None: + request.handled_at = datetime.now(timezone.utc) + + # Write the updated request back to the requests directory + request_path = self.path_to_rq / f'{request.id}.json' + + if not await asyncio.to_thread(request_path.exists): + return None + + request_data = await json_dumps(request.model_dump()) + await asyncio.to_thread(request_path.write_text, request_data, encoding='utf-8') + + # Update metadata timestamps + await self._update_metadata( + update_modified_at=True, + update_accessed_at=True, + new_handled_request_count=self._metadata.handled_request_count + 1, + new_pending_request_count=self._metadata.pending_request_count - 1, + ) + + return ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=True, + ) + + @override + async def reclaim_request( + self, + request: Request, + *, + forefront: bool = False, + ) -> ProcessedRequest | None: + """Reclaim a failed request back to the queue. + + The request will be returned for processing later again by another call to `fetch_next_request`. + + Args: + request: The request to return to the queue. + forefront: Whether to add the request to the head or the end of the queue. + + Returns: + Information about the queue operation. `None` if the given request was not in progress. + """ + async with self._lock: + # Check if the request is in progress + if request.id not in self._in_progress: + return None + + # Remove from in-progress set + self._in_progress.discard(request.id) + + # If forefront is true, mark this request as priority + if forefront: + self._forefront_requests.add(request.id) + else: + # Make sure it's not in the forefront set if it was previously added there + self._forefront_requests.discard(request.id) + + # To simulate changing the file timestamp for FIFO ordering, + # we'll update the file with current timestamp + request_path = self.path_to_rq / f'{request.id}.json' + + if not await asyncio.to_thread(request_path.exists): + return None + + request_data = await json_dumps(request.model_dump()) + await asyncio.to_thread(request_path.write_text, request_data, encoding='utf-8') + + # Update metadata timestamps + await self._update_metadata(update_modified_at=True, update_accessed_at=True) + + return ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=False, + ) + + @override + async def is_empty(self) -> bool: + """Check if the queue is empty. + + Returns: + True if the queue is empty, False otherwise. 
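Since `is_empty` and the counters exposed via `metadata` are the two ways a caller can observe queue progress, a brief illustrative check follows (names are placeholders; the import path is the private subpackage from this diff).

```python
import asyncio

from crawlee import Request
from crawlee.storage_clients._file_system import FileSystemStorageClient


async def main() -> None:
    rq_client = await FileSystemStorageClient().open_request_queue_client(name='stats-demo')
    await rq_client.add_batch_of_requests([Request.from_url('https://crawlee.dev/')])

    request = await rq_client.fetch_next_request()
    if request is not None:
        await rq_client.mark_request_as_handled(request)

    # The cached metadata reflects the counters maintained by _update_metadata().
    metadata = rq_client.metadata
    print(
        f'pending={metadata.pending_request_count} '
        f'handled={metadata.handled_request_count} '
        f'total={metadata.total_request_count} '
        f'empty={await rq_client.is_empty()}'
    )


if __name__ == '__main__':
    asyncio.run(main())
```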
+ """ + # Update accessed timestamp when checking if queue is empty + await self._update_metadata(update_accessed_at=True) + + # Create the requests directory if it doesn't exist + await asyncio.to_thread(self.path_to_rq.mkdir, parents=True, exist_ok=True) + + # List all request files + request_files = await asyncio.to_thread(list, self.path_to_rq.glob('*.json')) + + # Check each file to see if there are any unhandled requests + for request_file in request_files: + # Skip metadata file + if request_file.name == METADATA_FILENAME: + continue + + file = await asyncio.to_thread(open, request_file) + try: + file_content = json.load(file) + # If any request is not handled, the queue is not empty + if file_content.get('handled_at') is None: + return False + except (json.JSONDecodeError, ValidationError): + logger.warning(f'Failed to parse request file: {request_file}') + finally: + await asyncio.to_thread(file.close) + + # If we got here, all requests are handled or there are no requests + return True + + async def _update_metadata( + self, + *, + new_handled_request_count: int | None = None, + new_pending_request_count: int | None = None, + new_total_request_count: int | None = None, + update_had_multiple_clients: bool = False, + update_accessed_at: bool = False, + update_modified_at: bool = False, + ) -> None: + """Update the dataset metadata file with current information. + + Args: + new_handled_request_count: If provided, update the handled_request_count to this value. + new_pending_request_count: If provided, update the pending_request_count to this value. + new_total_request_count: If provided, update the total_request_count to this value. + update_had_multiple_clients: If True, set had_multiple_clients to True. + update_accessed_at: If True, update the `accessed_at` timestamp to the current time. + update_modified_at: If True, update the `modified_at` timestamp to the current time. + """ + # Always create a new timestamp to ensure it's truly updated + now = datetime.now(timezone.utc) + + # Update timestamps according to parameters + if update_accessed_at: + self._metadata.accessed_at = now + + if update_modified_at: + self._metadata.modified_at = now + + # Update request counts if provided + if new_handled_request_count is not None: + self._metadata.handled_request_count = new_handled_request_count + + if new_pending_request_count is not None: + self._metadata.pending_request_count = new_pending_request_count + + if new_total_request_count is not None: + self._metadata.total_request_count = new_total_request_count + + if update_had_multiple_clients: + self._metadata.had_multiple_clients = True + + # Ensure the parent directory for the metadata file exists. + await asyncio.to_thread(self.path_to_metadata.parent.mkdir, parents=True, exist_ok=True) + + # Dump the serialized metadata to the file. 
diff --git a/src/crawlee/storage_clients/_file_system/_storage_client.py b/src/crawlee/storage_clients/_file_system/_storage_client.py
new file mode 100644
index 0000000000..2765d15536
--- /dev/null
+++ b/src/crawlee/storage_clients/_file_system/_storage_client.py
@@ -0,0 +1,65 @@
+from __future__ import annotations
+
+from typing_extensions import override
+
+from crawlee.configuration import Configuration
+from crawlee.storage_clients._base import StorageClient
+
+from ._dataset_client import FileSystemDatasetClient
+from ._key_value_store_client import FileSystemKeyValueStoreClient
+from ._request_queue_client import FileSystemRequestQueueClient
+
+
+class FileSystemStorageClient(StorageClient):
+    """File system storage client."""
+
+    @override
+    async def open_dataset_client(
+        self,
+        *,
+        id: str | None = None,
+        name: str | None = None,
+        configuration: Configuration | None = None,
+    ) -> FileSystemDatasetClient:
+        configuration = configuration or Configuration.get_global_configuration()
+        client = await FileSystemDatasetClient.open(id=id, name=name, configuration=configuration)
+
+        if configuration.purge_on_start:
+            await client.drop()
+            client = await FileSystemDatasetClient.open(id=id, name=name, configuration=configuration)
+
+        return client
+
+    @override
+    async def open_key_value_store_client(
+        self,
+        *,
+        id: str | None = None,
+        name: str | None = None,
+        configuration: Configuration | None = None,
+    ) -> FileSystemKeyValueStoreClient:
+        configuration = configuration or Configuration.get_global_configuration()
+        client = await FileSystemKeyValueStoreClient.open(id=id, name=name, configuration=configuration)
+
+        if configuration.purge_on_start:
+            await client.drop()
+            client = await FileSystemKeyValueStoreClient.open(id=id, name=name, configuration=configuration)
+
+        return client
+
+    @override
+    async def open_request_queue_client(
+        self,
+        *,
+        id: str | None = None,
+        name: str | None = None,
+        configuration: Configuration | None = None,
+    ) -> FileSystemRequestQueueClient:
+        configuration = configuration or Configuration.get_global_configuration()
+        client = await FileSystemRequestQueueClient.open(id=id, name=name, configuration=configuration)
+
+        if configuration.purge_on_start:
+            await client.drop()
+            client = await FileSystemRequestQueueClient.open(id=id, name=name, configuration=configuration)
+
+        return client
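The storage client above routes every `open_*` call through the same purge logic: when `Configuration.purge_on_start` is set, the freshly opened client is dropped and reopened empty. A minimal usage sketch, with two assumptions not shown in this diff: that `FileSystemStorageClient` is re-exported from `crawlee.storage_clients`, and that `purge_on_start` can be passed directly to the `Configuration` constructor; `push_data` is shown on the memory dataset client later in this diff and is assumed to exist on the file-system client as well:

import asyncio

from crawlee.configuration import Configuration
from crawlee.storage_clients import FileSystemStorageClient


async def main() -> None:
    # With purge_on_start enabled, each open_* call drops the storage and reopens it empty.
    configuration = Configuration(purge_on_start=True)
    storage_client = FileSystemStorageClient()

    dataset_client = await storage_client.open_dataset_client(
        name='my-dataset',
        configuration=configuration,
    )
    # The file-system dataset client persists pushed items between process runs.
    await dataset_client.push_data({'url': 'https://crawlee.dev'})


if __name__ == '__main__':
    asyncio.run(main())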
diff --git a/src/crawlee/storage_clients/_file_system/_utils.py b/src/crawlee/storage_clients/_file_system/_utils.py
new file mode 100644
index 0000000000..c172df50cc
--- /dev/null
+++ b/src/crawlee/storage_clients/_file_system/_utils.py
@@ -0,0 +1,23 @@
+from __future__ import annotations
+
+import asyncio
+import json
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from typing import Any
+
+METADATA_FILENAME = '__metadata__.json'
+"""The name of the metadata file for storage clients."""
+
+
+async def json_dumps(obj: Any) -> str:
+    """Serialize an object to a JSON-formatted string with specific settings.
+
+    Args:
+        obj: The object to serialize.
+
+    Returns:
+        A string containing the JSON representation of the input object.
+    """
+    return await asyncio.to_thread(json.dumps, obj, ensure_ascii=False, indent=2, default=str)
diff --git a/src/crawlee/storage_clients/_file_system/py.typed b/src/crawlee/storage_clients/_file_system/py.typed
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/crawlee/storage_clients/_memory/__init__.py b/src/crawlee/storage_clients/_memory/__init__.py
index 09912e124d..3746907b4f 100644
--- a/src/crawlee/storage_clients/_memory/__init__.py
+++ b/src/crawlee/storage_clients/_memory/__init__.py
@@ -1,17 +1,11 @@
-from ._dataset_client import DatasetClient
-from ._dataset_collection_client import DatasetCollectionClient
-from ._key_value_store_client import KeyValueStoreClient
-from ._key_value_store_collection_client import KeyValueStoreCollectionClient
-from ._memory_storage_client import MemoryStorageClient
-from ._request_queue_client import RequestQueueClient
-from ._request_queue_collection_client import RequestQueueCollectionClient
+from ._dataset_client import MemoryDatasetClient
+from ._key_value_store_client import MemoryKeyValueStoreClient
+from ._request_queue_client import MemoryRequestQueueClient
+from ._storage_client import MemoryStorageClient
 
 __all__ = [
-    'DatasetClient',
-    'DatasetCollectionClient',
-    'KeyValueStoreClient',
-    'KeyValueStoreCollectionClient',
+    'MemoryDatasetClient',
+    'MemoryKeyValueStoreClient',
+    'MemoryRequestQueueClient',
     'MemoryStorageClient',
-    'RequestQueueClient',
-    'RequestQueueCollectionClient',
 ]
diff --git a/src/crawlee/storage_clients/_memory/_creation_management.py b/src/crawlee/storage_clients/_memory/_creation_management.py
deleted file mode 100644
index f6d4fc1c91..0000000000
--- a/src/crawlee/storage_clients/_memory/_creation_management.py
+++ /dev/null
@@ -1,429 +0,0 @@
-from __future__ import annotations
-
-import asyncio
-import json
-import mimetypes
-import os
-import pathlib
-from datetime import datetime, timezone
-from logging import getLogger
-from typing import TYPE_CHECKING
-
-from crawlee._consts import METADATA_FILENAME
-from crawlee._utils.data_processing import maybe_parse_body
-from crawlee._utils.file import json_dumps
-from crawlee.storage_clients.models import (
-    DatasetMetadata,
-    InternalRequest,
-    KeyValueStoreMetadata,
-    KeyValueStoreRecord,
-    KeyValueStoreRecordMetadata,
-    RequestQueueMetadata,
-)
-
-if TYPE_CHECKING:
-    from ._dataset_client import DatasetClient
-    from ._key_value_store_client import KeyValueStoreClient
-    from ._memory_storage_client import MemoryStorageClient, TResourceClient
-    from ._request_queue_client import RequestQueueClient
-
-logger = getLogger(__name__)
-
-
-async def persist_metadata_if_enabled(*, data: dict, entity_directory: str, write_metadata: bool) -> None:
-    """Update or writes metadata to a specified directory.
-
-    The function writes a given metadata dictionary to a JSON file within a specified directory.
-    The writing process is skipped if `write_metadata` is False. Before writing, it ensures that
-    the target directory exists, creating it if necessary.
-
-    Args:
-        data: A dictionary containing metadata to be written.
-        entity_directory: The directory path where the metadata file should be stored.
-        write_metadata: A boolean flag indicating whether the metadata should be written to file.
- """ - # Skip metadata write; ensure directory exists first - if not write_metadata: - return - - # Ensure the directory for the entity exists - await asyncio.to_thread(os.makedirs, entity_directory, exist_ok=True) - - # Write the metadata to the file - file_path = os.path.join(entity_directory, METADATA_FILENAME) - f = await asyncio.to_thread(open, file_path, mode='wb') - try: - s = await json_dumps(data) - await asyncio.to_thread(f.write, s.encode('utf-8')) - finally: - await asyncio.to_thread(f.close) - - -def find_or_create_client_by_id_or_name_inner( - resource_client_class: type[TResourceClient], - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, -) -> TResourceClient | None: - """Locate or create a new storage client based on the given ID or name. - - This method attempts to find a storage client in the memory cache first. If not found, - it tries to locate a storage directory by name. If still not found, it searches through - storage directories for a matching ID or name in their metadata. If none exists, and the - specified ID is 'default', it checks for a default storage directory. If a storage client - is found or created, it is added to the memory cache. If no storage client can be located or - created, the method returns None. - - Args: - resource_client_class: The class of the resource client. - memory_storage_client: The memory storage client used to store and retrieve storage clients. - id: The unique identifier for the storage client. - name: The name of the storage client. - - Raises: - ValueError: If both id and name are None. - - Returns: - The found or created storage client, or None if no client could be found or created. - """ - from ._dataset_client import DatasetClient - from ._key_value_store_client import KeyValueStoreClient - from ._request_queue_client import RequestQueueClient - - if id is None and name is None: - raise ValueError('Either id or name must be specified.') - - # First check memory cache - found = memory_storage_client.get_cached_resource_client(resource_client_class, id, name) - - if found is not None: - return found - - storage_path = _determine_storage_path(resource_client_class, memory_storage_client, id, name) - - if not storage_path: - return None - - # Create from directory if storage path is found - if issubclass(resource_client_class, DatasetClient): - resource_client = create_dataset_from_directory(storage_path, memory_storage_client, id, name) - elif issubclass(resource_client_class, KeyValueStoreClient): - resource_client = create_kvs_from_directory(storage_path, memory_storage_client, id, name) - elif issubclass(resource_client_class, RequestQueueClient): - resource_client = create_rq_from_directory(storage_path, memory_storage_client, id, name) - else: - raise TypeError('Invalid resource client class.') - - memory_storage_client.add_resource_client_to_cache(resource_client) - return resource_client - - -async def get_or_create_inner( - *, - memory_storage_client: MemoryStorageClient, - storage_client_cache: list[TResourceClient], - resource_client_class: type[TResourceClient], - name: str | None = None, - id: str | None = None, -) -> TResourceClient: - """Retrieve a named storage, or create a new one when it doesn't exist. - - Args: - memory_storage_client: The memory storage client. - storage_client_cache: The cache of storage clients. - resource_client_class: The class of the storage to retrieve or create. - name: The name of the storage to retrieve or create. 
- id: ID of the storage to retrieve or create. - - Returns: - The retrieved or newly-created storage. - """ - # If the name or id is provided, try to find the dataset in the cache - if name or id: - found = find_or_create_client_by_id_or_name_inner( - resource_client_class=resource_client_class, - memory_storage_client=memory_storage_client, - name=name, - id=id, - ) - if found: - return found - - # Otherwise, create a new one and add it to the cache - resource_client = resource_client_class( - id=id, - name=name, - memory_storage_client=memory_storage_client, - ) - - storage_client_cache.append(resource_client) - - # Write to the disk - await persist_metadata_if_enabled( - data=resource_client.resource_info.model_dump(), - entity_directory=resource_client.resource_directory, - write_metadata=memory_storage_client.write_metadata, - ) - - return resource_client - - -def create_dataset_from_directory( - storage_directory: str, - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, -) -> DatasetClient: - from ._dataset_client import DatasetClient - - item_count = 0 - has_seen_metadata_file = False - created_at = datetime.now(timezone.utc) - accessed_at = datetime.now(timezone.utc) - modified_at = datetime.now(timezone.utc) - - # Load metadata if it exists - metadata_filepath = os.path.join(storage_directory, METADATA_FILENAME) - - if os.path.exists(metadata_filepath): - has_seen_metadata_file = True - with open(metadata_filepath, encoding='utf-8') as f: - json_content = json.load(f) - resource_info = DatasetMetadata(**json_content) - - id = resource_info.id - name = resource_info.name - item_count = resource_info.item_count - created_at = resource_info.created_at - accessed_at = resource_info.accessed_at - modified_at = resource_info.modified_at - - # Load dataset entries - entries: dict[str, dict] = {} - - for entry in os.scandir(storage_directory): - if entry.is_file(): - if entry.name == METADATA_FILENAME: - has_seen_metadata_file = True - continue - - with open(os.path.join(storage_directory, entry.name), encoding='utf-8') as f: - entry_content = json.load(f) - - entry_name = entry.name.split('.')[0] - entries[entry_name] = entry_content - - if not has_seen_metadata_file: - item_count += 1 - - # Create new dataset client - new_client = DatasetClient( - memory_storage_client=memory_storage_client, - id=id, - name=name, - created_at=created_at, - accessed_at=accessed_at, - modified_at=modified_at, - item_count=item_count, - ) - - new_client.dataset_entries.update(entries) - return new_client - - -def create_kvs_from_directory( - storage_directory: str, - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, -) -> KeyValueStoreClient: - from ._key_value_store_client import KeyValueStoreClient - - created_at = datetime.now(timezone.utc) - accessed_at = datetime.now(timezone.utc) - modified_at = datetime.now(timezone.utc) - - # Load metadata if it exists - metadata_filepath = os.path.join(storage_directory, METADATA_FILENAME) - - if os.path.exists(metadata_filepath): - with open(metadata_filepath, encoding='utf-8') as f: - json_content = json.load(f) - resource_info = KeyValueStoreMetadata(**json_content) - - id = resource_info.id - name = resource_info.name - created_at = resource_info.created_at - accessed_at = resource_info.accessed_at - modified_at = resource_info.modified_at - - # Create new KVS client - new_client = KeyValueStoreClient( - memory_storage_client=memory_storage_client, - id=id, - name=name, 
- accessed_at=accessed_at, - created_at=created_at, - modified_at=modified_at, - ) - - # Scan the KVS folder, check each entry in there and parse it as a store record - for entry in os.scandir(storage_directory): - if not entry.is_file(): - continue - - # Ignore metadata files on their own - if entry.name.endswith(METADATA_FILENAME): - continue - - # Try checking if this file has a metadata file associated with it - record_metadata = None - record_metadata_filepath = os.path.join(storage_directory, f'{entry.name}.__metadata__.json') - - if os.path.exists(record_metadata_filepath): - with open(record_metadata_filepath, encoding='utf-8') as metadata_file: - try: - json_content = json.load(metadata_file) - record_metadata = KeyValueStoreRecordMetadata(**json_content) - - except Exception: - logger.warning( - f'Metadata of key-value store entry "{entry.name}" for store {name or id} could ' - 'not be parsed. The metadata file will be ignored.', - exc_info=True, - ) - - if not record_metadata: - content_type, _ = mimetypes.guess_type(entry.name) - if content_type is None: - content_type = 'application/octet-stream' - - record_metadata = KeyValueStoreRecordMetadata( - key=pathlib.Path(entry.name).stem, - content_type=content_type, - ) - - with open(os.path.join(storage_directory, entry.name), 'rb') as f: - file_content = f.read() - - try: - maybe_parse_body(file_content, record_metadata.content_type) - except Exception: - record_metadata.content_type = 'application/octet-stream' - logger.warning( - f'Key-value store entry "{record_metadata.key}" for store {name or id} could not be parsed.' - 'The entry will be assumed as binary.', - exc_info=True, - ) - - new_client.records[record_metadata.key] = KeyValueStoreRecord( - key=record_metadata.key, - content_type=record_metadata.content_type, - filename=entry.name, - value=file_content, - ) - - return new_client - - -def create_rq_from_directory( - storage_directory: str, - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, -) -> RequestQueueClient: - from ._request_queue_client import RequestQueueClient - - created_at = datetime.now(timezone.utc) - accessed_at = datetime.now(timezone.utc) - modified_at = datetime.now(timezone.utc) - handled_request_count = 0 - pending_request_count = 0 - - # Load metadata if it exists - metadata_filepath = os.path.join(storage_directory, METADATA_FILENAME) - - if os.path.exists(metadata_filepath): - with open(metadata_filepath, encoding='utf-8') as f: - json_content = json.load(f) - resource_info = RequestQueueMetadata(**json_content) - - id = resource_info.id - name = resource_info.name - created_at = resource_info.created_at - accessed_at = resource_info.accessed_at - modified_at = resource_info.modified_at - handled_request_count = resource_info.handled_request_count - pending_request_count = resource_info.pending_request_count - - # Load request entries - entries: dict[str, InternalRequest] = {} - - for entry in os.scandir(storage_directory): - if entry.is_file(): - if entry.name == METADATA_FILENAME: - continue - - with open(os.path.join(storage_directory, entry.name), encoding='utf-8') as f: - content = json.load(f) - - request = InternalRequest(**content) - - entries[request.id] = request - - # Create new RQ client - new_client = RequestQueueClient( - memory_storage_client=memory_storage_client, - id=id, - name=name, - accessed_at=accessed_at, - created_at=created_at, - modified_at=modified_at, - handled_request_count=handled_request_count, - 
pending_request_count=pending_request_count, - ) - - new_client.requests.update(entries) - return new_client - - -def _determine_storage_path( - resource_client_class: type[TResourceClient], - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, -) -> str | None: - storages_dir = memory_storage_client._get_storage_dir(resource_client_class) # noqa: SLF001 - default_id = memory_storage_client._get_default_storage_id(resource_client_class) # noqa: SLF001 - - # Try to find by name directly from directories - if name: - possible_storage_path = os.path.join(storages_dir, name) - if os.access(possible_storage_path, os.F_OK): - return possible_storage_path - - # If not found, try finding by metadata - if os.access(storages_dir, os.F_OK): - for entry in os.scandir(storages_dir): - if entry.is_dir(): - metadata_path = os.path.join(entry.path, METADATA_FILENAME) - if os.access(metadata_path, os.F_OK): - with open(metadata_path, encoding='utf-8') as metadata_file: - try: - metadata = json.load(metadata_file) - if (id and metadata.get('id') == id) or (name and metadata.get('name') == name): - return entry.path - except Exception: - logger.warning( - f'Metadata of store entry "{entry.name}" for store {name or id} could not be parsed. ' - 'The metadata file will be ignored.', - exc_info=True, - ) - - # Check for default storage directory as a last resort - if id == default_id: - possible_storage_path = os.path.join(storages_dir, default_id) - if os.access(possible_storage_path, os.F_OK): - return possible_storage_path - - return None diff --git a/src/crawlee/storage_clients/_memory/_dataset_client.py b/src/crawlee/storage_clients/_memory/_dataset_client.py index 50c8c7c8d4..63e75eabb0 100644 --- a/src/crawlee/storage_clients/_memory/_dataset_client.py +++ b/src/crawlee/storage_clients/_memory/_dataset_client.py @@ -1,162 +1,119 @@ from __future__ import annotations -import asyncio -import json -import os -import shutil from datetime import datetime, timezone from logging import getLogger from typing import TYPE_CHECKING, Any from typing_extensions import override -from crawlee._types import StorageTypes from crawlee._utils.crypto import crypto_random_object_id -from crawlee._utils.data_processing import raise_on_duplicate_storage, raise_on_non_existing_storage -from crawlee._utils.file import force_rename, json_dumps -from crawlee.storage_clients._base import DatasetClient as BaseDatasetClient +from crawlee.storage_clients._base import DatasetClient from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata -from ._creation_management import find_or_create_client_by_id_or_name_inner - if TYPE_CHECKING: from collections.abc import AsyncIterator - from contextlib import AbstractAsyncContextManager - - from httpx import Response - from crawlee._types import JsonSerializable - from crawlee.storage_clients import MemoryStorageClient + from crawlee.configuration import Configuration logger = getLogger(__name__) -class DatasetClient(BaseDatasetClient): - """Subclient for manipulating a single dataset.""" +class MemoryDatasetClient(DatasetClient): + """Memory implementation of the dataset client. - _LIST_ITEMS_LIMIT = 999_999_999_999 - """This is what API returns in the x-apify-pagination-limit header when no limit query parameter is used.""" + This client stores dataset items in memory using Python lists and dictionaries. No data is persisted + between process runs, meaning all stored data is lost when the program terminates. 
This implementation + is primarily useful for testing, development, and short-lived crawler operations where persistent + storage is not required. - _LOCAL_ENTRY_NAME_DIGITS = 9 - """Number of characters of the dataset item file names, e.g.: 000000019.json - 9 digits.""" + The memory implementation provides fast access to data but is limited by available memory and + does not support data sharing across different processes. It supports all dataset operations including + sorting, filtering, and pagination, but performs them entirely in memory. + """ def __init__( self, *, - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, - created_at: datetime | None = None, - accessed_at: datetime | None = None, - modified_at: datetime | None = None, - item_count: int = 0, + id: str, + name: str, + created_at: datetime, + accessed_at: datetime, + modified_at: datetime, + item_count: int, ) -> None: - self._memory_storage_client = memory_storage_client - self.id = id or crypto_random_object_id() - self.name = name - self._created_at = created_at or datetime.now(timezone.utc) - self._accessed_at = accessed_at or datetime.now(timezone.utc) - self._modified_at = modified_at or datetime.now(timezone.utc) - - self.dataset_entries: dict[str, dict] = {} - self.file_operation_lock = asyncio.Lock() - self.item_count = item_count + """Initialize a new instance. - @property - def resource_info(self) -> DatasetMetadata: - """Get the resource info for the dataset client.""" - return DatasetMetadata( - id=self.id, - name=self.name, - accessed_at=self._accessed_at, - created_at=self._created_at, - modified_at=self._modified_at, - item_count=self.item_count, + Preferably use the `MemoryDatasetClient.open` class method to create a new instance. 
+ """ + self._metadata = DatasetMetadata( + id=id, + name=name, + created_at=created_at, + accessed_at=accessed_at, + modified_at=modified_at, + item_count=item_count, ) - @property - def resource_directory(self) -> str: - """Get the resource directory for the client.""" - return os.path.join(self._memory_storage_client.datasets_directory, self.name or self.id) + # List to hold dataset items + self._records = list[dict[str, Any]]() @override - async def get(self) -> DatasetMetadata | None: - found = find_or_create_client_by_id_or_name_inner( - resource_client_class=DatasetClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if found: - async with found.file_operation_lock: - await found.update_timestamps(has_been_modified=False) - return found.resource_info - - return None + @property + def metadata(self) -> DatasetMetadata: + return self._metadata @override - async def update(self, *, name: str | None = None) -> DatasetMetadata: - # Check by id - existing_dataset_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=DatasetClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, + @classmethod + async def open( + cls, + *, + id: str | None, + name: str | None, + configuration: Configuration, + ) -> MemoryDatasetClient: + name = name or configuration.default_dataset_id + + dataset_id = id or crypto_random_object_id() + now = datetime.now(timezone.utc) + + return cls( + id=dataset_id, + name=name, + created_at=now, + accessed_at=now, + modified_at=now, + item_count=0, ) - if existing_dataset_by_id is None: - raise_on_non_existing_storage(StorageTypes.DATASET, self.id) - - # Skip if no changes - if name is None: - return existing_dataset_by_id.resource_info - - async with existing_dataset_by_id.file_operation_lock: - # Check that name is not in use already - existing_dataset_by_name = next( - ( - dataset - for dataset in self._memory_storage_client.datasets_handled - if dataset.name and dataset.name.lower() == name.lower() - ), - None, - ) - - if existing_dataset_by_name is not None: - raise_on_duplicate_storage(StorageTypes.DATASET, 'name', name) - - previous_dir = existing_dataset_by_id.resource_directory - existing_dataset_by_id.name = name - - await force_rename(previous_dir, existing_dataset_by_id.resource_directory) - - # Update timestamps - await existing_dataset_by_id.update_timestamps(has_been_modified=True) - - return existing_dataset_by_id.resource_info + @override + async def drop(self) -> None: + self._records.clear() + self._metadata.item_count = 0 @override - async def delete(self) -> None: - dataset = next( - (dataset for dataset in self._memory_storage_client.datasets_handled if dataset.id == self.id), None + async def push_data(self, data: list[Any] | dict[str, Any]) -> None: + new_item_count = self.metadata.item_count + + if isinstance(data, list): + for item in data: + new_item_count += 1 + await self._push_item(item) + else: + new_item_count += 1 + await self._push_item(data) + + await self._update_metadata( + update_accessed_at=True, + update_modified_at=True, + new_item_count=new_item_count, ) - if dataset is not None: - async with dataset.file_operation_lock: - self._memory_storage_client.datasets_handled.remove(dataset) - dataset.item_count = 0 - dataset.dataset_entries.clear() - - if os.path.exists(dataset.resource_directory): - await asyncio.to_thread(shutil.rmtree, dataset.resource_directory) - @override - async def list_items( + async def get_data( self, *, - 
offset: int | None = 0, - limit: int | None = _LIST_ITEMS_LIMIT, + offset: int = 0, + limit: int | None = 999_999_999_999, clean: bool = False, desc: bool = False, fields: list[str] | None = None, @@ -167,44 +124,48 @@ async def list_items( flatten: list[str] | None = None, view: str | None = None, ) -> DatasetItemsListPage: - # Check by id - existing_dataset_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=DatasetClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_dataset_by_id is None: - raise_on_non_existing_storage(StorageTypes.DATASET, self.id) - - async with existing_dataset_by_id.file_operation_lock: - start, end = existing_dataset_by_id.get_start_and_end_indexes( - max(existing_dataset_by_id.item_count - (offset or 0) - (limit or self._LIST_ITEMS_LIMIT), 0) - if desc - else offset or 0, - limit, + # Check for unsupported arguments and log a warning if found + unsupported_args = { + 'clean': clean, + 'fields': fields, + 'omit': omit, + 'unwind': unwind, + 'skip_hidden': skip_hidden, + 'flatten': flatten, + 'view': view, + } + unsupported = {k: v for k, v in unsupported_args.items() if v not in (False, None)} + + if unsupported: + logger.warning( + f'The arguments {list(unsupported.keys())} of get_data are not supported ' + f'by the {self.__class__.__name__} client.' ) - items = [] + total = len(self._records) + items = self._records.copy() - for idx in range(start, end): - entry_number = self._generate_local_entry_name(idx) - items.append(existing_dataset_by_id.dataset_entries[entry_number]) + # Apply skip_empty filter if requested + if skip_empty: + items = [item for item in items if item] - await existing_dataset_by_id.update_timestamps(has_been_modified=False) + # Apply sorting + if desc: + items = list(reversed(items)) - if desc: - items.reverse() + # Apply pagination + sliced_items = items[offset : (offset + limit) if limit is not None else total] - return DatasetItemsListPage( - count=len(items), - desc=desc or False, - items=items, - limit=limit or self._LIST_ITEMS_LIMIT, - offset=offset or 0, - total=existing_dataset_by_id.item_count, - ) + await self._update_metadata(update_accessed_at=True) + + return DatasetItemsListPage( + count=len(sliced_items), + offset=offset, + limit=limit or (total - offset), + total=total, + desc=desc, + items=sliced_items, + ) @override async def iterate_items( @@ -220,191 +181,66 @@ async def iterate_items( skip_empty: bool = False, skip_hidden: bool = False, ) -> AsyncIterator[dict]: - cache_size = 1000 - first_item = offset - - # If there is no limit, set last_item to None until we get the total from the first API response - last_item = None if limit is None else offset + limit - current_offset = first_item - - while last_item is None or current_offset < last_item: - current_limit = cache_size if last_item is None else min(cache_size, last_item - current_offset) - - current_items_page = await self.list_items( - offset=current_offset, - limit=current_limit, - desc=desc, + # Check for unsupported arguments and log a warning if found + unsupported_args = { + 'clean': clean, + 'fields': fields, + 'omit': omit, + 'unwind': unwind, + 'skip_hidden': skip_hidden, + } + unsupported = {k: v for k, v in unsupported_args.items() if v not in (False, None)} + + if unsupported: + logger.warning( + f'The arguments {list(unsupported.keys())} of iterate are not supported ' + f'by the {self.__class__.__name__} client.' 
) - current_offset += current_items_page.count - if last_item is None or current_items_page.total < last_item: - last_item = current_items_page.total - - for item in current_items_page.items: - yield item - - @override - async def get_items_as_bytes( - self, - *, - item_format: str = 'json', - offset: int | None = None, - limit: int | None = None, - desc: bool = False, - clean: bool = False, - bom: bool = False, - delimiter: str | None = None, - fields: list[str] | None = None, - omit: list[str] | None = None, - unwind: str | None = None, - skip_empty: bool = False, - skip_header_row: bool = False, - skip_hidden: bool = False, - xml_root: str | None = None, - xml_row: str | None = None, - flatten: list[str] | None = None, - ) -> bytes: - raise NotImplementedError('This method is not supported in memory storage.') - - @override - async def stream_items( - self, - *, - item_format: str = 'json', - offset: int | None = None, - limit: int | None = None, - desc: bool = False, - clean: bool = False, - bom: bool = False, - delimiter: str | None = None, - fields: list[str] | None = None, - omit: list[str] | None = None, - unwind: str | None = None, - skip_empty: bool = False, - skip_header_row: bool = False, - skip_hidden: bool = False, - xml_root: str | None = None, - xml_row: str | None = None, - ) -> AbstractAsyncContextManager[Response | None]: - raise NotImplementedError('This method is not supported in memory storage.') + items = self._records.copy() - @override - async def push_items( - self, - items: JsonSerializable, - ) -> None: - # Check by id - existing_dataset_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=DatasetClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_dataset_by_id is None: - raise_on_non_existing_storage(StorageTypes.DATASET, self.id) + # Apply sorting + if desc: + items = list(reversed(items)) - normalized = self._normalize_items(items) + # Apply pagination + sliced_items = items[offset : (offset + limit) if limit is not None else len(items)] - added_ids: list[str] = [] - for entry in normalized: - existing_dataset_by_id.item_count += 1 - idx = self._generate_local_entry_name(existing_dataset_by_id.item_count) + # Yield items one by one + for item in sliced_items: + if skip_empty and not item: + continue + yield item - existing_dataset_by_id.dataset_entries[idx] = entry - added_ids.append(idx) - - data_entries = [(id, existing_dataset_by_id.dataset_entries[id]) for id in added_ids] - - async with existing_dataset_by_id.file_operation_lock: - await existing_dataset_by_id.update_timestamps(has_been_modified=True) - - await self._persist_dataset_items_to_disk( - data=data_entries, - entity_directory=existing_dataset_by_id.resource_directory, - persist_storage=self._memory_storage_client.persist_storage, - ) + await self._update_metadata(update_accessed_at=True) - async def _persist_dataset_items_to_disk( + async def _update_metadata( self, *, - data: list[tuple[str, dict]], - entity_directory: str, - persist_storage: bool, + new_item_count: int | None = None, + update_accessed_at: bool = False, + update_modified_at: bool = False, ) -> None: - """Write dataset items to the disk. - - The function iterates over a list of dataset items, each represented as a tuple of an identifier - and a dictionary, and writes them as individual JSON files in a specified directory. The function - will skip writing if `persist_storage` is False. 
Before writing, it ensures that the target - directory exists, creating it if necessary. + """Update the dataset metadata with current information. Args: - data: A list of tuples, each containing an identifier (string) and a data dictionary. - entity_directory: The directory path where the dataset items should be stored. - persist_storage: A boolean flag indicating whether the data should be persisted to the disk. + new_item_count: If provided, update the item count to this value. + update_accessed_at: If True, update the `accessed_at` timestamp to the current time. + update_modified_at: If True, update the `modified_at` timestamp to the current time. """ - # Skip writing files to the disk if the client has the option set to false - if not persist_storage: - return - - # Ensure the directory for the entity exists - await asyncio.to_thread(os.makedirs, entity_directory, exist_ok=True) - - # Save all the new items to the disk - for idx, item in data: - file_path = os.path.join(entity_directory, f'{idx}.json') - f = await asyncio.to_thread(open, file_path, mode='w', encoding='utf-8') - try: - s = await json_dumps(item) - await asyncio.to_thread(f.write, s) - finally: - await asyncio.to_thread(f.close) - - async def update_timestamps(self, *, has_been_modified: bool) -> None: - """Update the timestamps of the dataset.""" - from ._creation_management import persist_metadata_if_enabled - - self._accessed_at = datetime.now(timezone.utc) - - if has_been_modified: - self._modified_at = datetime.now(timezone.utc) - - await persist_metadata_if_enabled( - data=self.resource_info.model_dump(), - entity_directory=self.resource_directory, - write_metadata=self._memory_storage_client.write_metadata, - ) - - def get_start_and_end_indexes(self, offset: int, limit: int | None = None) -> tuple[int, int]: - """Calculate the start and end indexes for listing items.""" - actual_limit = limit or self.item_count - start = offset + 1 - end = min(offset + actual_limit, self.item_count) + 1 - return (start, end) - - def _generate_local_entry_name(self, idx: int) -> str: - return str(idx).zfill(self._LOCAL_ENTRY_NAME_DIGITS) - - def _normalize_items(self, items: JsonSerializable) -> list[dict]: - def normalize_item(item: Any) -> dict | None: - if isinstance(item, str): - item = json.loads(item) + now = datetime.now(timezone.utc) - if isinstance(item, list): - received = ',\n'.join(item) - raise TypeError( - f'Each dataset item can only be a single JSON object, not an array. Received: [{received}]' - ) + if update_accessed_at: + self._metadata.accessed_at = now + if update_modified_at: + self._metadata.modified_at = now + if new_item_count: + self._metadata.item_count = new_item_count - if (not isinstance(item, dict)) and item is not None: - raise TypeError(f'Each dataset item must be a JSON object. Received: {item}') + async def _push_item(self, item: dict[str, Any]) -> None: + """Push a single item to the dataset. - return item - - if isinstance(items, str): - items = json.loads(items) - - result = list(map(normalize_item, items)) if isinstance(items, list) else [normalize_item(items)] - # filter(None, ..) returns items that are True - return list(filter(None, result)) + Args: + item: The data item to add to the dataset. 
+ """ + self._records.append(item) diff --git a/src/crawlee/storage_clients/_memory/_dataset_collection_client.py b/src/crawlee/storage_clients/_memory/_dataset_collection_client.py deleted file mode 100644 index 9e32b4086b..0000000000 --- a/src/crawlee/storage_clients/_memory/_dataset_collection_client.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from typing_extensions import override - -from crawlee.storage_clients._base import DatasetCollectionClient as BaseDatasetCollectionClient -from crawlee.storage_clients.models import DatasetListPage, DatasetMetadata - -from ._creation_management import get_or_create_inner -from ._dataset_client import DatasetClient - -if TYPE_CHECKING: - from ._memory_storage_client import MemoryStorageClient - - -class DatasetCollectionClient(BaseDatasetCollectionClient): - """Subclient for manipulating datasets.""" - - def __init__(self, *, memory_storage_client: MemoryStorageClient) -> None: - self._memory_storage_client = memory_storage_client - - @property - def _storage_client_cache(self) -> list[DatasetClient]: - return self._memory_storage_client.datasets_handled - - @override - async def get_or_create( - self, - *, - name: str | None = None, - schema: dict | None = None, - id: str | None = None, - ) -> DatasetMetadata: - resource_client = await get_or_create_inner( - memory_storage_client=self._memory_storage_client, - storage_client_cache=self._storage_client_cache, - resource_client_class=DatasetClient, - name=name, - id=id, - ) - return resource_client.resource_info - - @override - async def list( - self, - *, - unnamed: bool = False, - limit: int | None = None, - offset: int | None = None, - desc: bool = False, - ) -> DatasetListPage: - items = [storage.resource_info for storage in self._storage_client_cache] - - return DatasetListPage( - total=len(items), - count=len(items), - offset=0, - limit=len(items), - desc=False, - items=sorted(items, key=lambda item: item.created_at), - ) diff --git a/src/crawlee/storage_clients/_memory/_key_value_store_client.py b/src/crawlee/storage_clients/_memory/_key_value_store_client.py index ab9def0f06..e9b91702c7 100644 --- a/src/crawlee/storage_clients/_memory/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_memory/_key_value_store_client.py @@ -1,425 +1,172 @@ from __future__ import annotations -import asyncio -import io -import os -import shutil +import sys from datetime import datetime, timezone from logging import getLogger from typing import TYPE_CHECKING, Any from typing_extensions import override -from crawlee._types import StorageTypes from crawlee._utils.crypto import crypto_random_object_id -from crawlee._utils.data_processing import maybe_parse_body, raise_on_duplicate_storage, raise_on_non_existing_storage -from crawlee._utils.file import determine_file_extension, force_remove, force_rename, is_file_or_bytes, json_dumps -from crawlee.storage_clients._base import KeyValueStoreClient as BaseKeyValueStoreClient -from crawlee.storage_clients.models import ( - KeyValueStoreKeyInfo, - KeyValueStoreListKeysPage, - KeyValueStoreMetadata, - KeyValueStoreRecord, - KeyValueStoreRecordMetadata, -) - -from ._creation_management import find_or_create_client_by_id_or_name_inner, persist_metadata_if_enabled +from crawlee._utils.file import infer_mime_type +from crawlee.storage_clients._base import KeyValueStoreClient +from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecord, KeyValueStoreRecordMetadata if 
TYPE_CHECKING: - from contextlib import AbstractAsyncContextManager + from collections.abc import AsyncIterator - from httpx import Response - - from crawlee.storage_clients import MemoryStorageClient + from crawlee.configuration import Configuration logger = getLogger(__name__) -class KeyValueStoreClient(BaseKeyValueStoreClient): - """Subclient for manipulating a single key-value store.""" +class MemoryKeyValueStoreClient(KeyValueStoreClient): + """Memory implementation of the key-value store client. + + This client stores data in memory as Python dictionaries. No data is persisted between + process runs, meaning all stored data is lost when the program terminates. This implementation + is primarily useful for testing, development, and short-lived crawler operations where + persistence is not required. + + The memory implementation provides fast access to data but is limited by available memory and + does not support data sharing across different processes. + """ def __init__( self, *, - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, - created_at: datetime | None = None, - accessed_at: datetime | None = None, - modified_at: datetime | None = None, + id: str, + name: str, + created_at: datetime, + accessed_at: datetime, + modified_at: datetime, ) -> None: - self.id = id or crypto_random_object_id() - self.name = name - - self._memory_storage_client = memory_storage_client - self._created_at = created_at or datetime.now(timezone.utc) - self._accessed_at = accessed_at or datetime.now(timezone.utc) - self._modified_at = modified_at or datetime.now(timezone.utc) - - self.records: dict[str, KeyValueStoreRecord] = {} - self.file_operation_lock = asyncio.Lock() - - @property - def resource_info(self) -> KeyValueStoreMetadata: - """Get the resource info for the key-value store client.""" - return KeyValueStoreMetadata( - id=self.id, - name=self.name, - accessed_at=self._accessed_at, - created_at=self._created_at, - modified_at=self._modified_at, - user_id='1', + """Initialize a new instance. + + Preferably use the `MemoryKeyValueStoreClient.open` class method to create a new instance. 
+ """ + self._metadata = KeyValueStoreMetadata( + id=id, + name=name, + created_at=created_at, + accessed_at=accessed_at, + modified_at=modified_at, ) + # Dictionary to hold key-value records with metadata + self._records = dict[str, KeyValueStoreRecord]() + + @override @property - def resource_directory(self) -> str: - """Get the resource directory for the client.""" - return os.path.join(self._memory_storage_client.key_value_stores_directory, self.name or self.id) + def metadata(self) -> KeyValueStoreMetadata: + return self._metadata @override - async def get(self) -> KeyValueStoreMetadata | None: - found = find_or_create_client_by_id_or_name_inner( - resource_client_class=KeyValueStoreClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, + @classmethod + async def open( + cls, + *, + id: str | None, + name: str | None, + configuration: Configuration, + ) -> MemoryKeyValueStoreClient: + name = name or configuration.default_key_value_store_id + + # If specific id is provided, use it; otherwise, generate a new one + id = id or crypto_random_object_id() + now = datetime.now(timezone.utc) + + return cls( + id=id, + name=name, + created_at=now, + accessed_at=now, + modified_at=now, ) - if found: - async with found.file_operation_lock: - await found.update_timestamps(has_been_modified=False) - return found.resource_info - - return None - @override - async def update(self, *, name: str | None = None) -> KeyValueStoreMetadata: - # Check by id - existing_store_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=KeyValueStoreClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) + async def drop(self) -> None: + # Clear all data + self._records.clear() - if existing_store_by_id is None: - raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self.id) - - # Skip if no changes - if name is None: - return existing_store_by_id.resource_info - - async with existing_store_by_id.file_operation_lock: - # Check that name is not in use already - existing_store_by_name = next( - ( - store - for store in self._memory_storage_client.key_value_stores_handled - if store.name and store.name.lower() == name.lower() - ), - None, - ) + @override + async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: + await self._update_metadata(update_accessed_at=True) - if existing_store_by_name is not None: - raise_on_duplicate_storage(StorageTypes.KEY_VALUE_STORE, 'name', name) + # Return None if key doesn't exist + return self._records.get(key, None) - previous_dir = existing_store_by_id.resource_directory - existing_store_by_id.name = name + @override + async def set_value(self, *, key: str, value: Any, content_type: str | None = None) -> None: + content_type = content_type or infer_mime_type(value) + size = sys.getsizeof(value) - await force_rename(previous_dir, existing_store_by_id.resource_directory) + # Create and store the record + record = KeyValueStoreRecord( + key=key, + value=value, + content_type=content_type, + size=size, + ) - # Update timestamps - await existing_store_by_id.update_timestamps(has_been_modified=True) + self._records[key] = record - return existing_store_by_id.resource_info + await self._update_metadata(update_accessed_at=True, update_modified_at=True) @override - async def delete(self) -> None: - store = next( - (store for store in self._memory_storage_client.key_value_stores_handled if store.id == self.id), None - ) - - if store is not None: - async with 
store.file_operation_lock: - self._memory_storage_client.key_value_stores_handled.remove(store) - store.records.clear() - - if os.path.exists(store.resource_directory): - await asyncio.to_thread(shutil.rmtree, store.resource_directory) + async def delete_value(self, *, key: str) -> None: + if key in self._records: + del self._records[key] + await self._update_metadata(update_accessed_at=True, update_modified_at=True) @override - async def list_keys( + async def iterate_keys( self, *, - limit: int = 1000, exclusive_start_key: str | None = None, - ) -> KeyValueStoreListKeysPage: - # Check by id - existing_store_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=KeyValueStoreClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_store_by_id is None: - raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self.id) + limit: int | None = None, + ) -> AsyncIterator[KeyValueStoreRecordMetadata]: + await self._update_metadata(update_accessed_at=True) - items: list[KeyValueStoreKeyInfo] = [] + # Get all keys, sorted alphabetically + keys = sorted(self._records.keys()) - for record in existing_store_by_id.records.values(): - size = len(record.value) - items.append(KeyValueStoreKeyInfo(key=record.key, size=size)) - - if len(items) == 0: - return KeyValueStoreListKeysPage( - count=len(items), - limit=limit, - exclusive_start_key=exclusive_start_key, - is_truncated=False, - next_exclusive_start_key=None, - items=items, - ) - - # Lexically sort to emulate the API - items = sorted(items, key=lambda item: item.key) - - truncated_items = items + # Apply exclusive_start_key filter if provided if exclusive_start_key is not None: - key_pos = next((idx for idx, item in enumerate(items) if item.key == exclusive_start_key), None) - if key_pos is not None: - truncated_items = items[(key_pos + 1) :] - - limited_items = truncated_items[:limit] - - last_item_in_store = items[-1] - last_selected_item = limited_items[-1] - is_last_selected_item_absolutely_last = last_item_in_store == last_selected_item - next_exclusive_start_key = None if is_last_selected_item_absolutely_last else last_selected_item.key - - async with existing_store_by_id.file_operation_lock: - await existing_store_by_id.update_timestamps(has_been_modified=False) - - return KeyValueStoreListKeysPage( - count=len(items), - limit=limit, - exclusive_start_key=exclusive_start_key, - is_truncated=not is_last_selected_item_absolutely_last, - next_exclusive_start_key=next_exclusive_start_key, - items=limited_items, - ) - - @override - async def get_record(self, key: str) -> KeyValueStoreRecord | None: - return await self._get_record_internal(key) - - @override - async def get_record_as_bytes(self, key: str) -> KeyValueStoreRecord[bytes] | None: - return await self._get_record_internal(key, as_bytes=True) - - @override - async def stream_record(self, key: str) -> AbstractAsyncContextManager[KeyValueStoreRecord[Response] | None]: - raise NotImplementedError('This method is not supported in memory storage.') - - @override - async def set_record(self, key: str, value: Any, content_type: str | None = None) -> None: - # Check by id - existing_store_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=KeyValueStoreClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_store_by_id is None: - raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self.id) - - if isinstance(value, io.IOBase): - 
raise NotImplementedError('File-like values are not supported in local memory storage') - - if content_type is None: - if is_file_or_bytes(value): - content_type = 'application/octet-stream' - elif isinstance(value, str): - content_type = 'text/plain; charset=utf-8' - else: - content_type = 'application/json; charset=utf-8' - - if 'application/json' in content_type and not is_file_or_bytes(value) and not isinstance(value, str): - s = await json_dumps(value) - value = s.encode('utf-8') - - async with existing_store_by_id.file_operation_lock: - await existing_store_by_id.update_timestamps(has_been_modified=True) - record = KeyValueStoreRecord(key=key, value=value, content_type=content_type, filename=None) - - old_record = existing_store_by_id.records.get(key) - existing_store_by_id.records[key] = record - - if self._memory_storage_client.persist_storage: - record_filename = self._filename_from_record(record) - record.filename = record_filename - - if old_record is not None and self._filename_from_record(old_record) != record_filename: - await existing_store_by_id.delete_persisted_record(old_record) - - await existing_store_by_id.persist_record(record) - - @override - async def delete_record(self, key: str) -> None: - # Check by id - existing_store_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=KeyValueStoreClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_store_by_id is None: - raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self.id) - - record = existing_store_by_id.records.get(key) - - if record is not None: - async with existing_store_by_id.file_operation_lock: - del existing_store_by_id.records[key] - await existing_store_by_id.update_timestamps(has_been_modified=True) - if self._memory_storage_client.persist_storage: - await existing_store_by_id.delete_persisted_record(record) + keys = [k for k in keys if k > exclusive_start_key] + + # Apply limit if provided + if limit is not None: + keys = keys[:limit] + + # Yield metadata for each key + for key in keys: + record = self._records[key] + yield KeyValueStoreRecordMetadata( + key=key, + content_type=record.content_type, + size=record.size, + ) @override - async def get_public_url(self, key: str) -> str: - existing_store_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=KeyValueStoreClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_store_by_id is None: - raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self.id) - - record = await self._get_record_internal(key) - - if not record: - raise ValueError(f'Record with key "{key}" was not found.') - - resource_dir = existing_store_by_id.resource_directory - record_filename = self._filename_from_record(record) - record_path = os.path.join(resource_dir, record_filename) - return f'file://{record_path}' - - async def persist_record(self, record: KeyValueStoreRecord) -> None: - """Persist the specified record to the key-value store.""" - store_directory = self.resource_directory - record_filename = self._filename_from_record(record) - record.filename = record_filename - record.content_type = record.content_type or 'application/octet-stream' - - # Ensure the directory for the entity exists - await asyncio.to_thread(os.makedirs, store_directory, exist_ok=True) - - # Create files for the record - record_path = os.path.join(store_directory, record_filename) - record_metadata_path = 
os.path.join(store_directory, f'{record_filename}.__metadata__.json') - - # Convert to bytes if string - if isinstance(record.value, str): - record.value = record.value.encode('utf-8') - - f = await asyncio.to_thread(open, record_path, mode='wb') - try: - await asyncio.to_thread(f.write, record.value) - finally: - await asyncio.to_thread(f.close) - - if self._memory_storage_client.write_metadata: - metadata_f = await asyncio.to_thread(open, record_metadata_path, mode='wb') - - try: - record_metadata = KeyValueStoreRecordMetadata(key=record.key, content_type=record.content_type) - await asyncio.to_thread(metadata_f.write, record_metadata.model_dump_json(indent=2).encode('utf-8')) - finally: - await asyncio.to_thread(metadata_f.close) - - async def delete_persisted_record(self, record: KeyValueStoreRecord) -> None: - """Delete the specified record from the key-value store.""" - store_directory = self.resource_directory - record_filename = self._filename_from_record(record) - - # Ensure the directory for the entity exists - await asyncio.to_thread(os.makedirs, store_directory, exist_ok=True) + async def get_public_url(self, *, key: str) -> str: + raise NotImplementedError('Public URLs are not supported for memory key-value stores.') - # Create files for the record - record_path = os.path.join(store_directory, record_filename) - record_metadata_path = os.path.join(store_directory, record_filename + '.__metadata__.json') - - await force_remove(record_path) - await force_remove(record_metadata_path) - - async def update_timestamps(self, *, has_been_modified: bool) -> None: - """Update the timestamps of the key-value store.""" - self._accessed_at = datetime.now(timezone.utc) - - if has_been_modified: - self._modified_at = datetime.now(timezone.utc) - - await persist_metadata_if_enabled( - data=self.resource_info.model_dump(), - entity_directory=self.resource_directory, - write_metadata=self._memory_storage_client.write_metadata, - ) - - async def _get_record_internal( + async def _update_metadata( self, - key: str, *, - as_bytes: bool = False, - ) -> KeyValueStoreRecord | None: - # Check by id - existing_store_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=KeyValueStoreClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_store_by_id is None: - raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self.id) - - stored_record = existing_store_by_id.records.get(key) - - if stored_record is None: - return None - - record = KeyValueStoreRecord( - key=stored_record.key, - value=stored_record.value, - content_type=stored_record.content_type, - filename=stored_record.filename, - ) - - if not as_bytes: - try: - record.value = maybe_parse_body(record.value, str(record.content_type)) - except ValueError: - logger.exception('Error parsing key-value store record') - - async with existing_store_by_id.file_operation_lock: - await existing_store_by_id.update_timestamps(has_been_modified=False) - - return record - - def _filename_from_record(self, record: KeyValueStoreRecord) -> str: - if record.filename is not None: - return record.filename - - if not record.content_type or record.content_type == 'application/octet-stream': - return record.key - - extension = determine_file_extension(record.content_type) - - if record.key.endswith(f'.{extension}'): - return record.key - - return f'{record.key}.{extension}' + update_accessed_at: bool = False, + update_modified_at: bool = False, + ) -> None: + """Update the key-value store 
metadata with current information. + + Args: + update_accessed_at: If True, update the `accessed_at` timestamp to the current time. + update_modified_at: If True, update the `modified_at` timestamp to the current time. + """ + now = datetime.now(timezone.utc) + + if update_accessed_at: + self._metadata.accessed_at = now + if update_modified_at: + self._metadata.modified_at = now diff --git a/src/crawlee/storage_clients/_memory/_key_value_store_collection_client.py b/src/crawlee/storage_clients/_memory/_key_value_store_collection_client.py deleted file mode 100644 index 939780449f..0000000000 --- a/src/crawlee/storage_clients/_memory/_key_value_store_collection_client.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from typing_extensions import override - -from crawlee.storage_clients._base import KeyValueStoreCollectionClient as BaseKeyValueStoreCollectionClient -from crawlee.storage_clients.models import KeyValueStoreListPage, KeyValueStoreMetadata - -from ._creation_management import get_or_create_inner -from ._key_value_store_client import KeyValueStoreClient - -if TYPE_CHECKING: - from ._memory_storage_client import MemoryStorageClient - - -class KeyValueStoreCollectionClient(BaseKeyValueStoreCollectionClient): - """Subclient for manipulating key-value stores.""" - - def __init__(self, *, memory_storage_client: MemoryStorageClient) -> None: - self._memory_storage_client = memory_storage_client - - @property - def _storage_client_cache(self) -> list[KeyValueStoreClient]: - return self._memory_storage_client.key_value_stores_handled - - @override - async def get_or_create( - self, - *, - name: str | None = None, - schema: dict | None = None, - id: str | None = None, - ) -> KeyValueStoreMetadata: - resource_client = await get_or_create_inner( - memory_storage_client=self._memory_storage_client, - storage_client_cache=self._storage_client_cache, - resource_client_class=KeyValueStoreClient, - name=name, - id=id, - ) - return resource_client.resource_info - - @override - async def list( - self, - *, - unnamed: bool = False, - limit: int | None = None, - offset: int | None = None, - desc: bool = False, - ) -> KeyValueStoreListPage: - items = [storage.resource_info for storage in self._storage_client_cache] - - return KeyValueStoreListPage( - total=len(items), - count=len(items), - offset=0, - limit=len(items), - desc=False, - items=sorted(items, key=lambda item: item.created_at), - ) diff --git a/src/crawlee/storage_clients/_memory/_memory_storage_client.py b/src/crawlee/storage_clients/_memory/_memory_storage_client.py deleted file mode 100644 index 8000f41274..0000000000 --- a/src/crawlee/storage_clients/_memory/_memory_storage_client.py +++ /dev/null @@ -1,358 +0,0 @@ -from __future__ import annotations - -import asyncio -import contextlib -import os -import shutil -from logging import getLogger -from pathlib import Path -from typing import TYPE_CHECKING, TypeVar - -from typing_extensions import override - -from crawlee._utils.docs import docs_group -from crawlee.configuration import Configuration -from crawlee.storage_clients import StorageClient - -from ._dataset_client import DatasetClient -from ._dataset_collection_client import DatasetCollectionClient -from ._key_value_store_client import KeyValueStoreClient -from ._key_value_store_collection_client import KeyValueStoreCollectionClient -from ._request_queue_client import RequestQueueClient -from ._request_queue_collection_client import RequestQueueCollectionClient - -if 
TYPE_CHECKING: - from crawlee.storage_clients._base import ResourceClient - - -TResourceClient = TypeVar('TResourceClient', DatasetClient, KeyValueStoreClient, RequestQueueClient) - -logger = getLogger(__name__) - - -@docs_group('Classes') -class MemoryStorageClient(StorageClient): - """Represents an in-memory storage client for managing datasets, key-value stores, and request queues. - - It emulates in-memory storage similar to the Apify platform, supporting both in-memory and local file system-based - persistence. - - The behavior of the storage, such as data persistence and metadata writing, can be customized via initialization - parameters or environment variables. - """ - - _MIGRATING_KEY_VALUE_STORE_DIR_NAME = '__CRAWLEE_MIGRATING_KEY_VALUE_STORE' - """Name of the directory used to temporarily store files during the migration of the default key-value store.""" - - _TEMPORARY_DIR_NAME = '__CRAWLEE_TEMPORARY' - """Name of the directory used to temporarily store files during purges.""" - - _DATASETS_DIR_NAME = 'datasets' - """Name of the directory containing datasets.""" - - _KEY_VALUE_STORES_DIR_NAME = 'key_value_stores' - """Name of the directory containing key-value stores.""" - - _REQUEST_QUEUES_DIR_NAME = 'request_queues' - """Name of the directory containing request queues.""" - - def __init__( - self, - *, - write_metadata: bool, - persist_storage: bool, - storage_dir: str, - default_request_queue_id: str, - default_key_value_store_id: str, - default_dataset_id: str, - ) -> None: - """Initialize a new instance. - - In most cases, you should use the `from_config` constructor to create a new instance based on - the provided configuration. - - Args: - write_metadata: Whether to write metadata to the storage. - persist_storage: Whether to persist the storage. - storage_dir: Path to the storage directory. - default_request_queue_id: The default request queue ID. - default_key_value_store_id: The default key-value store ID. - default_dataset_id: The default dataset ID. - """ - # Set the internal attributes. - self._write_metadata = write_metadata - self._persist_storage = persist_storage - self._storage_dir = storage_dir - self._default_request_queue_id = default_request_queue_id - self._default_key_value_store_id = default_key_value_store_id - self._default_dataset_id = default_dataset_id - - self.datasets_handled: list[DatasetClient] = [] - self.key_value_stores_handled: list[KeyValueStoreClient] = [] - self.request_queues_handled: list[RequestQueueClient] = [] - - self._purged_on_start = False # Indicates whether a purge was already performed on this instance. - self._purge_lock = asyncio.Lock() - - @classmethod - def from_config(cls, config: Configuration | None = None) -> MemoryStorageClient: - """Initialize a new instance based on the provided `Configuration`. - - Args: - config: The `Configuration` instance. Uses the global (default) one if not provided. 
- """ - config = config or Configuration.get_global_configuration() - - return cls( - write_metadata=config.write_metadata, - persist_storage=config.persist_storage, - storage_dir=config.storage_dir, - default_request_queue_id=config.default_request_queue_id, - default_key_value_store_id=config.default_key_value_store_id, - default_dataset_id=config.default_dataset_id, - ) - - @property - def write_metadata(self) -> bool: - """Whether to write metadata to the storage.""" - return self._write_metadata - - @property - def persist_storage(self) -> bool: - """Whether to persist the storage.""" - return self._persist_storage - - @property - def storage_dir(self) -> str: - """Path to the storage directory.""" - return self._storage_dir - - @property - def datasets_directory(self) -> str: - """Path to the directory containing datasets.""" - return os.path.join(self.storage_dir, self._DATASETS_DIR_NAME) - - @property - def key_value_stores_directory(self) -> str: - """Path to the directory containing key-value stores.""" - return os.path.join(self.storage_dir, self._KEY_VALUE_STORES_DIR_NAME) - - @property - def request_queues_directory(self) -> str: - """Path to the directory containing request queues.""" - return os.path.join(self.storage_dir, self._REQUEST_QUEUES_DIR_NAME) - - @override - def dataset(self, id: str) -> DatasetClient: - return DatasetClient(memory_storage_client=self, id=id) - - @override - def datasets(self) -> DatasetCollectionClient: - return DatasetCollectionClient(memory_storage_client=self) - - @override - def key_value_store(self, id: str) -> KeyValueStoreClient: - return KeyValueStoreClient(memory_storage_client=self, id=id) - - @override - def key_value_stores(self) -> KeyValueStoreCollectionClient: - return KeyValueStoreCollectionClient(memory_storage_client=self) - - @override - def request_queue(self, id: str) -> RequestQueueClient: - return RequestQueueClient(memory_storage_client=self, id=id) - - @override - def request_queues(self) -> RequestQueueCollectionClient: - return RequestQueueCollectionClient(memory_storage_client=self) - - @override - async def purge_on_start(self) -> None: - # Optimistic, non-blocking check - if self._purged_on_start is True: - logger.debug('Storage was already purged on start.') - return - - async with self._purge_lock: - # Another check under the lock just to be sure - if self._purged_on_start is True: - # Mypy doesn't understand that the _purged_on_start can change while we're getting the async lock - return # type: ignore[unreachable] - - await self._purge_default_storages() - self._purged_on_start = True - - def get_cached_resource_client( - self, - resource_client_class: type[TResourceClient], - id: str | None, - name: str | None, - ) -> TResourceClient | None: - """Try to return a resource client from the internal cache.""" - if issubclass(resource_client_class, DatasetClient): - cache = self.datasets_handled - elif issubclass(resource_client_class, KeyValueStoreClient): - cache = self.key_value_stores_handled - elif issubclass(resource_client_class, RequestQueueClient): - cache = self.request_queues_handled - else: - return None - - for storage_client in cache: - if storage_client.id == id or ( - storage_client.name and name and storage_client.name.lower() == name.lower() - ): - return storage_client - - return None - - def add_resource_client_to_cache(self, resource_client: ResourceClient) -> None: - """Add a new resource client to the internal cache.""" - if isinstance(resource_client, DatasetClient): - 
self.datasets_handled.append(resource_client) - if isinstance(resource_client, KeyValueStoreClient): - self.key_value_stores_handled.append(resource_client) - if isinstance(resource_client, RequestQueueClient): - self.request_queues_handled.append(resource_client) - - async def _purge_default_storages(self) -> None: - """Clean up the storage directories, preparing the environment for a new run. - - It aims to remove residues from previous executions to avoid data contamination between runs. - - It specifically targets: - - The local directory containing the default dataset. - - All records from the default key-value store in the local directory, except for the 'INPUT' key. - - The local directory containing the default request queue. - """ - # Key-value stores - if await asyncio.to_thread(os.path.exists, self.key_value_stores_directory): - key_value_store_folders = await asyncio.to_thread(os.scandir, self.key_value_stores_directory) - for key_value_store_folder in key_value_store_folders: - if key_value_store_folder.name.startswith( - self._TEMPORARY_DIR_NAME - ) or key_value_store_folder.name.startswith('__OLD'): - await self._batch_remove_files(key_value_store_folder.path) - elif key_value_store_folder.name == self._default_key_value_store_id: - await self._handle_default_key_value_store(key_value_store_folder.path) - - # Datasets - if await asyncio.to_thread(os.path.exists, self.datasets_directory): - dataset_folders = await asyncio.to_thread(os.scandir, self.datasets_directory) - for dataset_folder in dataset_folders: - if dataset_folder.name == self._default_dataset_id or dataset_folder.name.startswith( - self._TEMPORARY_DIR_NAME - ): - await self._batch_remove_files(dataset_folder.path) - - # Request queues - if await asyncio.to_thread(os.path.exists, self.request_queues_directory): - request_queue_folders = await asyncio.to_thread(os.scandir, self.request_queues_directory) - for request_queue_folder in request_queue_folders: - if request_queue_folder.name == self._default_request_queue_id or request_queue_folder.name.startswith( - self._TEMPORARY_DIR_NAME - ): - await self._batch_remove_files(request_queue_folder.path) - - async def _handle_default_key_value_store(self, folder: str) -> None: - """Manage the cleanup of the default key-value store. - - It removes all files to ensure a clean state except for a set of predefined input keys (`possible_input_keys`). - - Args: - folder: Path to the default key-value store directory to clean. 
- """ - folder_exists = await asyncio.to_thread(os.path.exists, folder) - temporary_path = os.path.normpath(os.path.join(folder, '..', self._MIGRATING_KEY_VALUE_STORE_DIR_NAME)) - - # For optimization, we want to only attempt to copy a few files from the default key-value store - possible_input_keys = [ - 'INPUT', - 'INPUT.json', - 'INPUT.bin', - 'INPUT.txt', - ] - - if folder_exists: - # Create a temporary folder to save important files in - Path(temporary_path).mkdir(parents=True, exist_ok=True) - - # Go through each file and save the ones that are important - for entity in possible_input_keys: - original_file_path = os.path.join(folder, entity) - temp_file_path = os.path.join(temporary_path, entity) - with contextlib.suppress(Exception): - await asyncio.to_thread(os.rename, original_file_path, temp_file_path) - - # Remove the original folder and all its content - counter = 0 - temp_path_for_old_folder = os.path.normpath(os.path.join(folder, f'../__OLD_DEFAULT_{counter}__')) - done = False - try: - while not done: - await asyncio.to_thread(os.rename, folder, temp_path_for_old_folder) - done = True - except Exception: - counter += 1 - temp_path_for_old_folder = os.path.normpath(os.path.join(folder, f'../__OLD_DEFAULT_{counter}__')) - - # Replace the temporary folder with the original folder - await asyncio.to_thread(os.rename, temporary_path, folder) - - # Remove the old folder - await self._batch_remove_files(temp_path_for_old_folder) - - async def _batch_remove_files(self, folder: str, counter: int = 0) -> None: - """Remove a folder and its contents in batches to minimize blocking time. - - This method first renames the target folder to a temporary name, then deletes the temporary folder, - allowing the file system operations to proceed without hindering other asynchronous tasks. - - Args: - folder: The directory path to remove. - counter: A counter used for generating temporary directory names in case of conflicts. 
- """ - folder_exists = await asyncio.to_thread(os.path.exists, folder) - - if folder_exists: - temporary_folder = ( - folder - if os.path.basename(folder).startswith(f'{self._TEMPORARY_DIR_NAME}_') - else os.path.normpath(os.path.join(folder, '..', f'{self._TEMPORARY_DIR_NAME}_{counter}')) - ) - - try: - # Rename the old folder to the new one to allow background deletions - await asyncio.to_thread(os.rename, folder, temporary_folder) - except Exception: - # Folder exists already, try again with an incremented counter - return await self._batch_remove_files(folder, counter + 1) - - await asyncio.to_thread(shutil.rmtree, temporary_folder, ignore_errors=True) - return None - - def _get_default_storage_id(self, storage_client_class: type[TResourceClient]) -> str: - """Get the default storage ID based on the storage class.""" - if issubclass(storage_client_class, DatasetClient): - return self._default_dataset_id - - if issubclass(storage_client_class, KeyValueStoreClient): - return self._default_key_value_store_id - - if issubclass(storage_client_class, RequestQueueClient): - return self._default_request_queue_id - - raise ValueError(f'Invalid storage class: {storage_client_class.__name__}') - - def _get_storage_dir(self, storage_client_class: type[TResourceClient]) -> str: - """Get the storage directory based on the storage class.""" - if issubclass(storage_client_class, DatasetClient): - return self.datasets_directory - - if issubclass(storage_client_class, KeyValueStoreClient): - return self.key_value_stores_directory - - if issubclass(storage_client_class, RequestQueueClient): - return self.request_queues_directory - - raise ValueError(f'Invalid storage class: {storage_client_class.__name__}') diff --git a/src/crawlee/storage_clients/_memory/_request_queue_client.py b/src/crawlee/storage_clients/_memory/_request_queue_client.py index 477d53df07..d360695884 100644 --- a/src/crawlee/storage_clients/_memory/_request_queue_client.py +++ b/src/crawlee/storage_clients/_memory/_request_queue_client.py @@ -1,558 +1,344 @@ from __future__ import annotations -import asyncio -import os -import shutil from datetime import datetime, timezone -from decimal import Decimal from logging import getLogger from typing import TYPE_CHECKING -from sortedcollections import ValueSortedDict from typing_extensions import override -from crawlee._types import StorageTypes +from crawlee import Request from crawlee._utils.crypto import crypto_random_object_id -from crawlee._utils.data_processing import raise_on_duplicate_storage, raise_on_non_existing_storage -from crawlee._utils.file import force_remove, force_rename, json_dumps -from crawlee._utils.requests import unique_key_to_request_id -from crawlee.storage_clients._base import RequestQueueClient as BaseRequestQueueClient +from crawlee.storage_clients._base import RequestQueueClient from crawlee.storage_clients.models import ( - BatchRequestsOperationResponse, - InternalRequest, + AddRequestsResponse, ProcessedRequest, - ProlongRequestLockResponse, - RequestQueueHead, - RequestQueueHeadWithLocks, RequestQueueMetadata, - UnprocessedRequest, ) -from ._creation_management import find_or_create_client_by_id_or_name_inner, persist_metadata_if_enabled - if TYPE_CHECKING: from collections.abc import Sequence - from sortedcontainers import SortedDict + from crawlee.configuration import Configuration - from crawlee import Request +logger = getLogger(__name__) - from ._memory_storage_client import MemoryStorageClient -logger = getLogger(__name__) +class 
MemoryRequestQueueClient(RequestQueueClient): + """Memory implementation of the request queue client. + This client stores requests in memory using a Python list and dictionary. No data is persisted between + process runs, which means all requests are lost when the program terminates. This implementation + is primarily useful for testing, development, and short-lived crawler runs where persistence + is not required. -class RequestQueueClient(BaseRequestQueueClient): - """Subclient for manipulating a single request queue.""" + This client provides fast access to request data but is limited by available memory and + does not support data sharing across different processes. + """ def __init__( self, *, - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, - created_at: datetime | None = None, - accessed_at: datetime | None = None, - modified_at: datetime | None = None, - handled_request_count: int = 0, - pending_request_count: int = 0, + id: str, + name: str, + created_at: datetime, + accessed_at: datetime, + modified_at: datetime, + had_multiple_clients: bool, + handled_request_count: int, + pending_request_count: int, + stats: dict, + total_request_count: int, ) -> None: - self._memory_storage_client = memory_storage_client - self.id = id or crypto_random_object_id() - self.name = name - self._created_at = created_at or datetime.now(timezone.utc) - self._accessed_at = accessed_at or datetime.now(timezone.utc) - self._modified_at = modified_at or datetime.now(timezone.utc) - self.handled_request_count = handled_request_count - self.pending_request_count = pending_request_count - - self.requests: SortedDict[str, InternalRequest] = ValueSortedDict( - lambda request: request.order_no or -float('inf') - ) - self.file_operation_lock = asyncio.Lock() - self._last_used_timestamp = Decimal(0) - - self._in_progress = set[str]() - - @property - def resource_info(self) -> RequestQueueMetadata: - """Get the resource info for the request queue client.""" - return RequestQueueMetadata( - id=self.id, - name=self.name, - accessed_at=self._accessed_at, - created_at=self._created_at, - modified_at=self._modified_at, - had_multiple_clients=False, - handled_request_count=self.handled_request_count, - pending_request_count=self.pending_request_count, - stats={}, - total_request_count=len(self.requests), - user_id='1', - resource_directory=self.resource_directory, - ) - - @property - def resource_directory(self) -> str: - """Get the resource directory for the client.""" - return os.path.join(self._memory_storage_client.request_queues_directory, self.name or self.id) - - @override - async def get(self) -> RequestQueueMetadata | None: - found = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if found: - async with found.file_operation_lock: - await found.update_timestamps(has_been_modified=False) - return found.resource_info - - return None + """Initialize a new instance. - @override - async def update(self, *, name: str | None = None) -> RequestQueueMetadata: - # Check by id - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, + Preferably use the `MemoryRequestQueueClient.open` class method to create a new instance. 
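# --- Illustrative sketch, not part of the diff: a minimal, hedged example of using the
# request-queue client introduced here on its own. It relies only on the `open`,
# `add_batch_of_requests`, `fetch_next_request`, `mark_request_as_handled` and `is_empty`
# methods defined in this file; the private import path below is an assumption based on
# the file location in this diff.
import asyncio

from crawlee import Request
from crawlee.configuration import Configuration
from crawlee.storage_clients._memory._request_queue_client import MemoryRequestQueueClient


async def main() -> None:
    # Open an in-memory queue; nothing is written to disk.
    rq_client = await MemoryRequestQueueClient.open(id=None, name=None, configuration=Configuration())

    # Enqueue a request and process it.
    await rq_client.add_batch_of_requests([Request.from_url('https://crawlee.dev/')])
    request = await rq_client.fetch_next_request()
    if request is not None:
        await rq_client.mark_request_as_handled(request)

    print(await rq_client.is_empty())  # -> True, the only request has been handled.


if __name__ == '__main__':
    asyncio.run(main())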
+ """ + self._metadata = RequestQueueMetadata( + id=id, + name=name, + created_at=created_at, + accessed_at=accessed_at, + modified_at=modified_at, + had_multiple_clients=had_multiple_clients, + handled_request_count=handled_request_count, + pending_request_count=pending_request_count, + stats=stats, + total_request_count=total_request_count, ) - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) - - # Skip if no changes - if name is None: - return existing_queue_by_id.resource_info - - async with existing_queue_by_id.file_operation_lock: - # Check that name is not in use already - existing_queue_by_name = next( - ( - queue - for queue in self._memory_storage_client.request_queues_handled - if queue.name and queue.name.lower() == name.lower() - ), - None, - ) - - if existing_queue_by_name is not None: - raise_on_duplicate_storage(StorageTypes.REQUEST_QUEUE, 'name', name) + # List to hold RQ items + self._records = list[Request]() - previous_dir = existing_queue_by_id.resource_directory - existing_queue_by_id.name = name - - await force_rename(previous_dir, existing_queue_by_id.resource_directory) - - # Update timestamps - await existing_queue_by_id.update_timestamps(has_been_modified=True) - - return existing_queue_by_id.resource_info + # Dictionary to track in-progress requests (fetched but not yet handled or reclaimed) + self._in_progress = dict[str, Request]() @override - async def delete(self) -> None: - queue = next( - (queue for queue in self._memory_storage_client.request_queues_handled if queue.id == self.id), - None, - ) - - if queue is not None: - async with queue.file_operation_lock: - self._memory_storage_client.request_queues_handled.remove(queue) - queue.pending_request_count = 0 - queue.handled_request_count = 0 - queue.requests.clear() - - if os.path.exists(queue.resource_directory): - await asyncio.to_thread(shutil.rmtree, queue.resource_directory) + @property + def metadata(self) -> RequestQueueMetadata: + return self._metadata @override - async def list_head(self, *, limit: int | None = None, skip_in_progress: bool = False) -> RequestQueueHead: - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, + @classmethod + async def open( + cls, + *, + id: str | None, + name: str | None, + configuration: Configuration, + ) -> MemoryRequestQueueClient: + name = name or configuration.default_request_queue_id + + # If specific id is provided, use it; otherwise, generate a new one + id = id or crypto_random_object_id() + now = datetime.now(timezone.utc) + + return cls( + id=crypto_random_object_id(), + name=name, + created_at=now, + accessed_at=now, + modified_at=now, + had_multiple_clients=False, + handled_request_count=0, + pending_request_count=0, + stats={}, + total_request_count=0, ) - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) - - async with existing_queue_by_id.file_operation_lock: - await existing_queue_by_id.update_timestamps(has_been_modified=False) - - requests: list[Request] = [] - - # Iterate all requests in the queue which have sorted key larger than infinity, which means - # `order_no` is not `None`. This will iterate them in order of `order_no`. 
- for request_key in existing_queue_by_id.requests.irange_key( # type: ignore[attr-defined] # irange_key is a valid SortedDict method but not recognized by mypy - min_key=-float('inf'), inclusive=(False, True) - ): - if len(requests) == limit: - break - - if skip_in_progress and request_key in existing_queue_by_id._in_progress: # noqa: SLF001 - continue - internal_request = existing_queue_by_id.requests.get(request_key) - - # Check that the request still exists and was not handled, - # in case something deleted it or marked it as handled concurrenctly - if internal_request and not internal_request.handled_at: - requests.append(internal_request.to_request()) - - return RequestQueueHead( - limit=limit, - had_multiple_clients=False, - queue_modified_at=existing_queue_by_id._modified_at, # noqa: SLF001 - items=requests, - ) - @override - async def list_and_lock_head(self, *, lock_secs: int, limit: int | None = None) -> RequestQueueHeadWithLocks: - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) - - result = await self.list_head(limit=limit, skip_in_progress=True) - - for item in result.items: - existing_queue_by_id._in_progress.add(item.id) # noqa: SLF001 - - return RequestQueueHeadWithLocks( - queue_has_locked_requests=len(existing_queue_by_id._in_progress) > 0, # noqa: SLF001 - lock_secs=lock_secs, - limit=result.limit, - had_multiple_clients=result.had_multiple_clients, - queue_modified_at=result.queue_modified_at, - items=result.items, - ) + async def drop(self) -> None: + # Clear all data + self._records.clear() + self._in_progress.clear() @override - async def add_request( + async def add_batch_of_requests( self, - request: Request, + requests: Sequence[Request], *, forefront: bool = False, - ) -> ProcessedRequest: - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) + ) -> AddRequestsResponse: + """Add a batch of requests to the queue. - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) + Args: + requests: The requests to add. + forefront: Whether to add the requests to the beginning of the queue. - internal_request = await self._create_internal_request(request, forefront) + Returns: + Response containing information about the added requests. 
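# --- Illustrative sketch, not part of the diff: the `forefront` flag handled by this method
# inserts new requests at the head of the in-memory list, so they are fetched first. The
# private import path is an assumption based on the file location in this diff.
import asyncio

from crawlee import Request
from crawlee.configuration import Configuration
from crawlee.storage_clients._memory._request_queue_client import MemoryRequestQueueClient


async def main() -> None:
    rq_client = await MemoryRequestQueueClient.open(id=None, name=None, configuration=Configuration())

    await rq_client.add_batch_of_requests([Request.from_url('https://crawlee.dev/')])
    # Added with forefront=True, so this URL jumps ahead of the one above.
    await rq_client.add_batch_of_requests(
        [Request.from_url('https://crawlee.dev/python/')], forefront=True
    )

    request = await rq_client.fetch_next_request()
    print(request.url if request else None)  # -> 'https://crawlee.dev/python/'


if __name__ == '__main__':
    asyncio.run(main())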
+ """ + processed_requests = [] + for request in requests: + # Ensure the request has an ID + if not request.id: + request.id = crypto_random_object_id() - async with existing_queue_by_id.file_operation_lock: - existing_internal_request_with_id = existing_queue_by_id.requests.get(internal_request.id) + # Check if the request is already in the queue by unique_key + existing_request = next((r for r in self._records if r.unique_key == request.unique_key), None) - # We already have the request present, so we return information about it - if existing_internal_request_with_id is not None: - await existing_queue_by_id.update_timestamps(has_been_modified=False) + was_already_present = existing_request is not None + was_already_handled = was_already_present and existing_request and existing_request.handled_at is not None - return ProcessedRequest( - id=internal_request.id, - unique_key=internal_request.unique_key, - was_already_present=True, - was_already_handled=existing_internal_request_with_id.handled_at is not None, + # If the request is already in the queue and handled, don't add it again + if was_already_handled: + processed_requests.append( + ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=True, + ) ) - - existing_queue_by_id.requests[internal_request.id] = internal_request - if internal_request.handled_at: - existing_queue_by_id.handled_request_count += 1 + continue + + # If the request is already in the queue but not handled, update it + if was_already_present: + # Update the existing request with any new data + for idx, rec in enumerate(self._records): + if rec.unique_key == request.unique_key: + self._records[idx] = request + break else: - existing_queue_by_id.pending_request_count += 1 - await existing_queue_by_id.update_timestamps(has_been_modified=True) - await self._persist_single_request_to_storage( - request=internal_request, - entity_directory=existing_queue_by_id.resource_directory, - persist_storage=self._memory_storage_client.persist_storage, + # Add the new request to the queue + if forefront: + self._records.insert(0, request) + else: + self._records.append(request) + + # Update metadata counts + self._metadata.total_request_count += 1 + self._metadata.pending_request_count += 1 + + processed_requests.append( + ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=was_already_present, + was_already_handled=False, + ) ) - # We return was_already_handled=False even though the request may have been added as handled, - # because that's how API behaves. - return ProcessedRequest( - id=internal_request.id, - unique_key=internal_request.unique_key, - was_already_present=False, - was_already_handled=False, - ) + await self._update_metadata(update_accessed_at=True, update_modified_at=True) - @override - async def get_request(self, request_id: str) -> Request | None: - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, + return AddRequestsResponse( + processed_requests=processed_requests, + unprocessed_requests=[], ) - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) + @override + async def fetch_next_request(self) -> Request | None: + """Return the next request in the queue to be processed. 
- async with existing_queue_by_id.file_operation_lock: - await existing_queue_by_id.update_timestamps(has_been_modified=False) + Returns: + The request or `None` if there are no more pending requests. + """ + # Find the first request that's not handled or in progress + for request in self._records: + if request.handled_at is None and request.id not in self._in_progress: + # Mark as in progress + self._in_progress[request.id] = request + return request - internal_request = existing_queue_by_id.requests.get(request_id) - return internal_request.to_request() if internal_request else None + return None @override - async def update_request( - self, - request: Request, - *, - forefront: bool = False, - ) -> ProcessedRequest: - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) + async def get_request(self, request_id: str) -> Request | None: + """Retrieve a request from the queue. - internal_request = await self._create_internal_request(request, forefront) + Args: + request_id: ID of the request to retrieve. - # First we need to check the existing request to be able to return information about its handled state. - existing_internal_request = existing_queue_by_id.requests.get(internal_request.id) + Returns: + The retrieved request, or None, if it did not exist. + """ + # Check in-progress requests first + if request_id in self._in_progress: + return self._in_progress[request_id] - # Undefined means that the request is not present in the queue. - # We need to insert it, to behave the same as API. - if existing_internal_request is None: - return await self.add_request(request, forefront=forefront) + # Otherwise search in the records + for request in self._records: + if request.id == request_id: + return request - async with existing_queue_by_id.file_operation_lock: - # When updating the request, we need to make sure that - # the handled counts are updated correctly in all cases. - existing_queue_by_id.requests[internal_request.id] = internal_request + return None - pending_count_adjustment = 0 - is_request_handled_state_changing = existing_internal_request.handled_at != internal_request.handled_at + @override + async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: + """Mark a request as handled after successful processing. - request_was_handled_before_update = existing_internal_request.handled_at is not None + Handled requests will never again be returned by the `fetch_next_request` method. - # We add 1 pending request if previous state was handled - if is_request_handled_state_changing: - pending_count_adjustment = 1 if request_was_handled_before_update else -1 + Args: + request: The request to mark as handled. - existing_queue_by_id.pending_request_count += pending_count_adjustment - existing_queue_by_id.handled_request_count -= pending_count_adjustment - await existing_queue_by_id.update_timestamps(has_been_modified=True) - await self._persist_single_request_to_storage( - request=internal_request, - entity_directory=existing_queue_by_id.resource_directory, - persist_storage=self._memory_storage_client.persist_storage, - ) + Returns: + Information about the queue operation. `None` if the given request was not in progress. 
+ """ + # Check if the request is in progress + if request.id not in self._in_progress: + return None - if request.handled_at is not None: - existing_queue_by_id._in_progress.discard(request.id) # noqa: SLF001 + # Set handled_at timestamp if not already set + if request.handled_at is None: + request.handled_at = datetime.now(timezone.utc) - return ProcessedRequest( - id=internal_request.id, - unique_key=internal_request.unique_key, - was_already_present=True, - was_already_handled=request_was_handled_before_update, - ) + # Update the request in records + for idx, rec in enumerate(self._records): + if rec.id == request.id: + self._records[idx] = request + break - @override - async def delete_request(self, request_id: str) -> None: - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) + # Remove from in-progress + del self._in_progress[request.id] - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) + # Update metadata counts + self._metadata.handled_request_count += 1 + self._metadata.pending_request_count -= 1 - async with existing_queue_by_id.file_operation_lock: - internal_request = existing_queue_by_id.requests.get(request_id) + # Update metadata timestamps + await self._update_metadata(update_modified_at=True) - if internal_request: - del existing_queue_by_id.requests[request_id] - if internal_request.handled_at: - existing_queue_by_id.handled_request_count -= 1 - else: - existing_queue_by_id.pending_request_count -= 1 - await existing_queue_by_id.update_timestamps(has_been_modified=True) - await self._delete_request_file_from_storage( - entity_directory=existing_queue_by_id.resource_directory, - request_id=request_id, - ) + return ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=True, + ) @override - async def prolong_request_lock( + async def reclaim_request( self, - request_id: str, + request: Request, *, forefront: bool = False, - lock_secs: int, - ) -> ProlongRequestLockResponse: - return ProlongRequestLockResponse(lock_expires_at=datetime.now(timezone.utc)) + ) -> ProcessedRequest | None: + """Reclaim a failed request back to the queue. - @override - async def delete_request_lock( - self, - request_id: str, - *, - forefront: bool = False, - ) -> None: - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) + The request will be returned for processing later again by another call to `fetch_next_request`. - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) + Args: + request: The request to return to the queue. + forefront: Whether to add the request to the head or the end of the queue. - existing_queue_by_id._in_progress.discard(request_id) # noqa: SLF001 + Returns: + Information about the queue operation. `None` if the given request was not in progress. 
+ """ + # Check if the request is in progress + if request.id not in self._in_progress: + return None - @override - async def batch_add_requests( - self, - requests: Sequence[Request], - *, - forefront: bool = False, - ) -> BatchRequestsOperationResponse: - processed_requests = list[ProcessedRequest]() - unprocessed_requests = list[UnprocessedRequest]() + # Remove from in-progress + del self._in_progress[request.id] - for request in requests: - try: - processed_request = await self.add_request(request, forefront=forefront) - processed_requests.append( - ProcessedRequest( - id=processed_request.id, - unique_key=processed_request.unique_key, - was_already_present=processed_request.was_already_present, - was_already_handled=processed_request.was_already_handled, - ) - ) - except Exception as exc: # noqa: PERF203 - logger.warning(f'Error adding request to the queue: {exc}') - unprocessed_requests.append( - UnprocessedRequest( - unique_key=request.unique_key, - url=request.url, - method=request.method, - ) - ) + # If forefront is true, move the request to the beginning of the queue + if forefront: + # First remove the request from its current position + for idx, rec in enumerate(self._records): + if rec.id == request.id: + self._records.pop(idx) + break - return BatchRequestsOperationResponse( - processed_requests=processed_requests, - unprocessed_requests=unprocessed_requests, + # Then insert it at the beginning + self._records.insert(0, request) + + # Update metadata timestamps + await self._update_metadata(update_modified_at=True) + + return ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=False, ) @override - async def batch_delete_requests(self, requests: list[Request]) -> BatchRequestsOperationResponse: - raise NotImplementedError('This method is not supported in memory storage.') + async def is_empty(self) -> bool: + """Check if the queue is empty. - async def update_timestamps(self, *, has_been_modified: bool) -> None: - """Update the timestamps of the request queue.""" - self._accessed_at = datetime.now(timezone.utc) - - if has_been_modified: - self._modified_at = datetime.now(timezone.utc) + Returns: + True if the queue is empty, False otherwise. + """ + await self._update_metadata(update_accessed_at=True) - await persist_metadata_if_enabled( - data=self.resource_info.model_dump(), - entity_directory=self.resource_directory, - write_metadata=self._memory_storage_client.write_metadata, - ) + # Queue is empty if there are no pending requests + pending_requests = [r for r in self._records if r.handled_at is None] + return len(pending_requests) == 0 - async def _persist_single_request_to_storage( + async def _update_metadata( self, *, - request: InternalRequest, - entity_directory: str, - persist_storage: bool, + update_accessed_at: bool = False, + update_modified_at: bool = False, ) -> None: - """Update or writes a single request item to the disk. - - This function writes a given request dictionary to a JSON file, named after the request's ID, - within a specified directory. The writing process is skipped if `persist_storage` is False. - Before writing, it ensures that the target directory exists, creating it if necessary. + """Update the request queue metadata with current information. Args: - request: The dictionary containing the request data. - entity_directory: The directory path where the request file should be stored. 
- persist_storage: A boolean flag indicating whether the request should be persisted to the disk. + update_accessed_at: If True, update the `accessed_at` timestamp to the current time. + update_modified_at: If True, update the `modified_at` timestamp to the current time. """ - # Skip writing files to the disk if the client has the option set to false - if not persist_storage: - return - - # Ensure the directory for the entity exists - await asyncio.to_thread(os.makedirs, entity_directory, exist_ok=True) - - # Write the request to the file - file_path = os.path.join(entity_directory, f'{request.id}.json') - f = await asyncio.to_thread(open, file_path, mode='w', encoding='utf-8') - try: - s = await json_dumps(request.model_dump()) - await asyncio.to_thread(f.write, s) - finally: - f.close() - - async def _delete_request_file_from_storage(self, *, request_id: str, entity_directory: str) -> None: - """Delete a specific request item from the disk. - - This function removes a file representing a request, identified by the request's ID, from a - specified directory. Before attempting to remove the file, it ensures that the target directory - exists, creating it if necessary. - - Args: - request_id: The identifier of the request to be deleted. - entity_directory: The directory path where the request file is stored. - """ - # Ensure the directory for the entity exists - await asyncio.to_thread(os.makedirs, entity_directory, exist_ok=True) - - file_path = os.path.join(entity_directory, f'{request_id}.json') - await force_remove(file_path) - - async def _create_internal_request(self, request: Request, forefront: bool | None) -> InternalRequest: - order_no = self._calculate_order_no(request, forefront) - id = unique_key_to_request_id(request.unique_key) - - if request.id is not None and request.id != id: - logger.warning( - f'The request ID does not match the ID from the unique_key (request.id={request.id}, id={id}).' 
- ) - - return InternalRequest.from_request(request=request, id=id, order_no=order_no) - - def _calculate_order_no(self, request: Request, forefront: bool | None) -> Decimal | None: - if request.handled_at is not None: - return None - - # Get the current timestamp in milliseconds - timestamp = Decimal(str(datetime.now(tz=timezone.utc).timestamp())) * Decimal(1000) - timestamp = round(timestamp, 6) - - # Make sure that this timestamp was not used yet, so that we have unique order_nos - if timestamp <= self._last_used_timestamp: - timestamp = self._last_used_timestamp + Decimal('0.000001') - - self._last_used_timestamp = timestamp + now = datetime.now(timezone.utc) - return -timestamp if forefront else timestamp + if update_accessed_at: + self._metadata.accessed_at = now + if update_modified_at: + self._metadata.modified_at = now diff --git a/src/crawlee/storage_clients/_memory/_request_queue_collection_client.py b/src/crawlee/storage_clients/_memory/_request_queue_collection_client.py deleted file mode 100644 index 2f2df2be89..0000000000 --- a/src/crawlee/storage_clients/_memory/_request_queue_collection_client.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from typing_extensions import override - -from crawlee.storage_clients._base import RequestQueueCollectionClient as BaseRequestQueueCollectionClient -from crawlee.storage_clients.models import RequestQueueListPage, RequestQueueMetadata - -from ._creation_management import get_or_create_inner -from ._request_queue_client import RequestQueueClient - -if TYPE_CHECKING: - from ._memory_storage_client import MemoryStorageClient - - -class RequestQueueCollectionClient(BaseRequestQueueCollectionClient): - """Subclient for manipulating request queues.""" - - def __init__(self, *, memory_storage_client: MemoryStorageClient) -> None: - self._memory_storage_client = memory_storage_client - - @property - def _storage_client_cache(self) -> list[RequestQueueClient]: - return self._memory_storage_client.request_queues_handled - - @override - async def get_or_create( - self, - *, - name: str | None = None, - schema: dict | None = None, - id: str | None = None, - ) -> RequestQueueMetadata: - resource_client = await get_or_create_inner( - memory_storage_client=self._memory_storage_client, - storage_client_cache=self._storage_client_cache, - resource_client_class=RequestQueueClient, - name=name, - id=id, - ) - return resource_client.resource_info - - @override - async def list( - self, - *, - unnamed: bool = False, - limit: int | None = None, - offset: int | None = None, - desc: bool = False, - ) -> RequestQueueListPage: - items = [storage.resource_info for storage in self._storage_client_cache] - - return RequestQueueListPage( - total=len(items), - count=len(items), - offset=0, - limit=len(items), - desc=False, - items=sorted(items, key=lambda item: item.created_at), - ) diff --git a/src/crawlee/storage_clients/_memory/_storage_client.py b/src/crawlee/storage_clients/_memory/_storage_client.py new file mode 100644 index 0000000000..6123a6ca53 --- /dev/null +++ b/src/crawlee/storage_clients/_memory/_storage_client.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from typing_extensions import override + +from crawlee.configuration import Configuration +from crawlee.storage_clients._base import StorageClient + +from ._dataset_client import MemoryDatasetClient +from ._key_value_store_client import MemoryKeyValueStoreClient +from ._request_queue_client import MemoryRequestQueueClient + + 
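# --- Illustrative sketch, not part of the diff: the `MemoryStorageClient` defined just below
# hands out the three in-memory clients. The `open_*_client` signatures come from this diff;
# passing `purge_on_start` to the `Configuration` constructor is an assumption.
import asyncio

from crawlee.configuration import Configuration
from crawlee.storage_clients import MemoryStorageClient


async def main() -> None:
    storage_client = MemoryStorageClient()
    configuration = Configuration(purge_on_start=True)

    # Each call opens a fresh in-memory client; with purge_on_start the data is dropped
    # and the client is re-opened, as implemented below.
    dataset_client = await storage_client.open_dataset_client(configuration=configuration)
    kvs_client = await storage_client.open_key_value_store_client(configuration=configuration)
    rq_client = await storage_client.open_request_queue_client(configuration=configuration)

    print(type(dataset_client).__name__, type(kvs_client).__name__, type(rq_client).__name__)


if __name__ == '__main__':
    asyncio.run(main())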
+class MemoryStorageClient(StorageClient): + """Memory storage client.""" + + @override + async def open_dataset_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> MemoryDatasetClient: + configuration = configuration or Configuration.get_global_configuration() + client = await MemoryDatasetClient.open(id=id, name=name, configuration=configuration) + + if configuration.purge_on_start: + await client.drop() + client = await MemoryDatasetClient.open(id=id, name=name, configuration=configuration) + + return client + + @override + async def open_key_value_store_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> MemoryKeyValueStoreClient: + configuration = configuration or Configuration.get_global_configuration() + client = await MemoryKeyValueStoreClient.open(id=id, name=name, configuration=configuration) + + if configuration.purge_on_start: + await client.drop() + client = await MemoryKeyValueStoreClient.open(id=id, name=name, configuration=configuration) + + return client + + @override + async def open_request_queue_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> MemoryRequestQueueClient: + configuration = configuration or Configuration.get_global_configuration() + client = await MemoryRequestQueueClient.open(id=id, name=name, configuration=configuration) + + if configuration.purge_on_start: + await client.drop() + client = await MemoryRequestQueueClient.open(id=id, name=name, configuration=configuration) + + return client diff --git a/src/crawlee/storage_clients/models.py b/src/crawlee/storage_clients/models.py index f016e24730..f680ba945f 100644 --- a/src/crawlee/storage_clients/models.py +++ b/src/crawlee/storage_clients/models.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -from datetime import datetime +from datetime import datetime, timedelta from decimal import Decimal from typing import Annotated, Any, Generic @@ -26,10 +26,19 @@ class StorageMetadata(BaseModel): model_config = ConfigDict(populate_by_name=True, extra='allow') id: Annotated[str, Field(alias='id')] - name: Annotated[str | None, Field(alias='name', default='')] + """The unique identifier of the storage.""" + + name: Annotated[str, Field(alias='name', default='default')] + """The name of the storage.""" + accessed_at: Annotated[datetime, Field(alias='accessedAt')] + """The timestamp when the storage was last accessed.""" + created_at: Annotated[datetime, Field(alias='createdAt')] + """The timestamp when the storage was created.""" + modified_at: Annotated[datetime, Field(alias='modifiedAt')] + """The timestamp when the storage was last modified.""" @docs_group('Data structures') @@ -39,6 +48,7 @@ class DatasetMetadata(StorageMetadata): model_config = ConfigDict(populate_by_name=True) item_count: Annotated[int, Field(alias='itemCount')] + """The number of items in the dataset.""" @docs_group('Data structures') @@ -47,8 +57,6 @@ class KeyValueStoreMetadata(StorageMetadata): model_config = ConfigDict(populate_by_name=True) - user_id: Annotated[str, Field(alias='userId')] - @docs_group('Data structures') class RequestQueueMetadata(StorageMetadata): @@ -57,44 +65,51 @@ class RequestQueueMetadata(StorageMetadata): model_config = ConfigDict(populate_by_name=True) had_multiple_clients: Annotated[bool, Field(alias='hadMultipleClients')] + """Indicates whether the queue has been accessed 
by multiple clients (consumers).""" + handled_request_count: Annotated[int, Field(alias='handledRequestCount')] + """The number of requests that have been handled from the queue.""" + pending_request_count: Annotated[int, Field(alias='pendingRequestCount')] + """The number of requests that are still pending in the queue.""" + stats: Annotated[dict, Field(alias='stats')] + """Statistics about the request queue, TODO?""" + total_request_count: Annotated[int, Field(alias='totalRequestCount')] - user_id: Annotated[str, Field(alias='userId')] - resource_directory: Annotated[str, Field(alias='resourceDirectory')] + """The total number of requests that have been added to the queue.""" @docs_group('Data structures') -class KeyValueStoreRecord(BaseModel, Generic[KvsValueType]): - """Model for a key-value store record.""" +class KeyValueStoreRecordMetadata(BaseModel): + """Model for a key-value store record metadata.""" model_config = ConfigDict(populate_by_name=True) key: Annotated[str, Field(alias='key')] - value: Annotated[KvsValueType, Field(alias='value')] - content_type: Annotated[str | None, Field(alias='contentType', default=None)] - filename: Annotated[str | None, Field(alias='filename', default=None)] + """The key of the record. + A unique identifier for the record in the key-value store. + """ -@docs_group('Data structures') -class KeyValueStoreRecordMetadata(BaseModel): - """Model for a key-value store record metadata.""" + content_type: Annotated[str, Field(alias='contentType')] + """The MIME type of the record. - model_config = ConfigDict(populate_by_name=True) + Describe the format and type of data stored in the record, following the MIME specification. + """ - key: Annotated[str, Field(alias='key')] - content_type: Annotated[str, Field(alias='contentType')] + size: Annotated[int, Field(alias='size')] + """The size of the record in bytes.""" @docs_group('Data structures') -class KeyValueStoreKeyInfo(BaseModel): - """Model for a key-value store key info.""" +class KeyValueStoreRecord(KeyValueStoreRecordMetadata, Generic[KvsValueType]): + """Model for a key-value store record.""" model_config = ConfigDict(populate_by_name=True) - key: Annotated[str, Field(alias='key')] - size: Annotated[int, Field(alias='size')] + value: Annotated[KvsValueType, Field(alias='value')] + """The value of the record.""" @docs_group('Data structures') @@ -104,11 +119,22 @@ class KeyValueStoreListKeysPage(BaseModel): model_config = ConfigDict(populate_by_name=True) count: Annotated[int, Field(alias='count')] + """The number of keys returned on this page.""" + limit: Annotated[int, Field(alias='limit')] + """The maximum number of keys to return.""" + is_truncated: Annotated[bool, Field(alias='isTruncated')] - items: Annotated[list[KeyValueStoreKeyInfo], Field(alias='items', default_factory=list)] + """Indicates whether there are more keys to retrieve.""" + exclusive_start_key: Annotated[str | None, Field(alias='exclusiveStartKey', default=None)] + """The key from which to start this page of results.""" + next_exclusive_start_key: Annotated[str | None, Field(alias='nextExclusiveStartKey', default=None)] + """The key from which to start the next page of results.""" + + items: Annotated[list[KeyValueStoreRecordMetadata], Field(alias='items', default_factory=list)] + """The list of KVS items metadata returned on this page.""" @docs_group('Data structures') @@ -126,22 +152,31 @@ class RequestQueueHeadState(BaseModel): @docs_group('Data structures') class RequestQueueHead(BaseModel): - """Model for the request queue 
head.""" + """Model for request queue head. + + Represents a collection of requests retrieved from the beginning of a queue, + including metadata about the queue's state and lock information for the requests. + """ model_config = ConfigDict(populate_by_name=True) limit: Annotated[int | None, Field(alias='limit', default=None)] - had_multiple_clients: Annotated[bool, Field(alias='hadMultipleClients')] + """The maximum number of requests that were requested from the queue.""" + + had_multiple_clients: Annotated[bool, Field(alias='hadMultipleClients', default=False)] + """Indicates whether the queue has been accessed by multiple clients (consumers).""" + queue_modified_at: Annotated[datetime, Field(alias='queueModifiedAt')] - items: Annotated[list[Request], Field(alias='items', default_factory=list)] + """The timestamp when the queue was last modified.""" + lock_time: Annotated[timedelta | None, Field(alias='lockSecs', default=None)] + """The duration for which the returned requests are locked and cannot be processed by other clients.""" -@docs_group('Data structures') -class RequestQueueHeadWithLocks(RequestQueueHead): - """Model for request queue head with locks.""" + queue_has_locked_requests: Annotated[bool | None, Field(alias='queueHasLockedRequests', default=False)] + """Indicates whether the queue contains any locked requests.""" - lock_secs: Annotated[int, Field(alias='lockSecs')] - queue_has_locked_requests: Annotated[bool | None, Field(alias='queueHasLockedRequests')] = None + items: Annotated[list[Request], Field(alias='items', default_factory=list[Request])] + """The list of request objects retrieved from the beginning of the queue.""" class _ListPage(BaseModel): @@ -230,13 +265,22 @@ class UnprocessedRequest(BaseModel): @docs_group('Data structures') -class BatchRequestsOperationResponse(BaseModel): - """Response to batch request deletion calls.""" +class AddRequestsResponse(BaseModel): + """Model for a response to add requests to a queue. + + Contains detailed information about the processing results when adding multiple requests + to a queue. This includes which requests were successfully processed and which ones + encountered issues during processing. 
+ """ model_config = ConfigDict(populate_by_name=True) processed_requests: Annotated[list[ProcessedRequest], Field(alias='processedRequests')] + """Successfully processed requests, including information about whether they were + already present in the queue and whether they had been handled previously.""" + unprocessed_requests: Annotated[list[UnprocessedRequest], Field(alias='unprocessedRequests')] + """Requests that could not be processed, typically due to validation errors or other issues.""" class InternalRequest(BaseModel): @@ -275,3 +319,22 @@ def from_request(cls, request: Request, id: str, order_no: Decimal | None) -> In def to_request(self) -> Request: """Convert the internal request back to a `Request` object.""" return self.request + + +class CachedRequest(BaseModel): + """Pydantic model for cached request information.""" + + id: str + """The ID of the request.""" + + was_already_handled: bool + """Whether the request was already handled.""" + + hydrated: Request | None = None + """The hydrated request object (the original one).""" + + lock_expires_at: datetime | None = None + """The expiration time of the lock on the request.""" + + forefront: bool = False + """Whether the request was added to the forefront of the queue.""" diff --git a/src/crawlee/storages/_base.py b/src/crawlee/storages/_base.py index 08d2cbd7be..9216bf4569 100644 --- a/src/crawlee/storages/_base.py +++ b/src/crawlee/storages/_base.py @@ -6,7 +6,7 @@ if TYPE_CHECKING: from crawlee.configuration import Configuration from crawlee.storage_clients._base import StorageClient - from crawlee.storage_clients.models import StorageMetadata + from crawlee.storage_clients.models import DatasetMetadata, KeyValueStoreMetadata, RequestQueueMetadata class Storage(ABC): @@ -24,13 +24,8 @@ def name(self) -> str | None: @property @abstractmethod - def storage_object(self) -> StorageMetadata: - """Get the full storage object.""" - - @storage_object.setter - @abstractmethod - def storage_object(self, storage_object: StorageMetadata) -> None: - """Set the full storage object.""" + def metadata(self) -> DatasetMetadata | KeyValueStoreMetadata | RequestQueueMetadata: + """Get the storage metadata.""" @classmethod @abstractmethod @@ -55,3 +50,17 @@ async def open( @abstractmethod async def drop(self) -> None: """Drop the storage, removing it from the underlying storage client and clearing the cache.""" + + @classmethod + def compute_cache_key( + cls, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + storage_client: StorageClient | None = None, + ) -> str: + """Compute the cache key for the storage. + + The cache key computed based on the storage ID, name, configuration fields, and storage client class. 
+ """ + return f'{id}|{name}|{configuration}|{storage_client.__class__}' diff --git a/src/crawlee/storages/_creation_management.py b/src/crawlee/storages/_creation_management.py deleted file mode 100644 index 14d9b1719e..0000000000 --- a/src/crawlee/storages/_creation_management.py +++ /dev/null @@ -1,231 +0,0 @@ -from __future__ import annotations - -import asyncio -from typing import TYPE_CHECKING, TypeVar -from weakref import WeakKeyDictionary - -from crawlee.storage_clients import MemoryStorageClient - -from ._dataset import Dataset -from ._key_value_store import KeyValueStore -from ._request_queue import RequestQueue - -if TYPE_CHECKING: - from crawlee.configuration import Configuration - from crawlee.storage_clients._base import ResourceClient, ResourceCollectionClient, StorageClient - -TResource = TypeVar('TResource', Dataset, KeyValueStore, RequestQueue) - - -_creation_locks = WeakKeyDictionary[asyncio.AbstractEventLoop, asyncio.Lock]() -"""Locks for storage creation (we need a separate lock for every event loop so that tests work as expected).""" - -_cache_dataset_by_id: dict[str, Dataset] = {} -_cache_dataset_by_name: dict[str, Dataset] = {} -_cache_kvs_by_id: dict[str, KeyValueStore] = {} -_cache_kvs_by_name: dict[str, KeyValueStore] = {} -_cache_rq_by_id: dict[str, RequestQueue] = {} -_cache_rq_by_name: dict[str, RequestQueue] = {} - - -def _get_from_cache_by_name( - storage_class: type[TResource], - name: str, -) -> TResource | None: - """Try to restore storage from cache by name.""" - if issubclass(storage_class, Dataset): - return _cache_dataset_by_name.get(name) - if issubclass(storage_class, KeyValueStore): - return _cache_kvs_by_name.get(name) - if issubclass(storage_class, RequestQueue): - return _cache_rq_by_name.get(name) - raise ValueError(f'Unknown storage class: {storage_class.__name__}') - - -def _get_from_cache_by_id( - storage_class: type[TResource], - id: str, -) -> TResource | None: - """Try to restore storage from cache by ID.""" - if issubclass(storage_class, Dataset): - return _cache_dataset_by_id.get(id) - if issubclass(storage_class, KeyValueStore): - return _cache_kvs_by_id.get(id) - if issubclass(storage_class, RequestQueue): - return _cache_rq_by_id.get(id) - raise ValueError(f'Unknown storage: {storage_class.__name__}') - - -def _add_to_cache_by_name(name: str, storage: TResource) -> None: - """Add storage to cache by name.""" - if isinstance(storage, Dataset): - _cache_dataset_by_name[name] = storage - elif isinstance(storage, KeyValueStore): - _cache_kvs_by_name[name] = storage - elif isinstance(storage, RequestQueue): - _cache_rq_by_name[name] = storage - else: - raise TypeError(f'Unknown storage: {storage}') - - -def _add_to_cache_by_id(id: str, storage: TResource) -> None: - """Add storage to cache by ID.""" - if isinstance(storage, Dataset): - _cache_dataset_by_id[id] = storage - elif isinstance(storage, KeyValueStore): - _cache_kvs_by_id[id] = storage - elif isinstance(storage, RequestQueue): - _cache_rq_by_id[id] = storage - else: - raise TypeError(f'Unknown storage: {storage}') - - -def _rm_from_cache_by_id(storage_class: type, id: str) -> None: - """Remove a storage from cache by ID.""" - try: - if issubclass(storage_class, Dataset): - del _cache_dataset_by_id[id] - elif issubclass(storage_class, KeyValueStore): - del _cache_kvs_by_id[id] - elif issubclass(storage_class, RequestQueue): - del _cache_rq_by_id[id] - else: - raise TypeError(f'Unknown storage class: {storage_class.__name__}') - except KeyError as exc: - raise 
RuntimeError(f'Storage with provided ID was not found ({id}).') from exc - - -def _rm_from_cache_by_name(storage_class: type, name: str) -> None: - """Remove a storage from cache by name.""" - try: - if issubclass(storage_class, Dataset): - del _cache_dataset_by_name[name] - elif issubclass(storage_class, KeyValueStore): - del _cache_kvs_by_name[name] - elif issubclass(storage_class, RequestQueue): - del _cache_rq_by_name[name] - else: - raise TypeError(f'Unknown storage class: {storage_class.__name__}') - except KeyError as exc: - raise RuntimeError(f'Storage with provided name was not found ({name}).') from exc - - -def _get_default_storage_id(configuration: Configuration, storage_class: type[TResource]) -> str: - if issubclass(storage_class, Dataset): - return configuration.default_dataset_id - if issubclass(storage_class, KeyValueStore): - return configuration.default_key_value_store_id - if issubclass(storage_class, RequestQueue): - return configuration.default_request_queue_id - - raise TypeError(f'Unknown storage class: {storage_class.__name__}') - - -async def open_storage( - *, - storage_class: type[TResource], - id: str | None, - name: str | None, - configuration: Configuration, - storage_client: StorageClient, -) -> TResource: - """Open either a new storage or restore an existing one and return it.""" - # Try to restore the storage from cache by name - if name: - cached_storage = _get_from_cache_by_name(storage_class=storage_class, name=name) - if cached_storage: - return cached_storage - - default_id = _get_default_storage_id(configuration, storage_class) - - if not id and not name: - id = default_id - - # Find out if the storage is a default on memory storage - is_default_on_memory = id == default_id and isinstance(storage_client, MemoryStorageClient) - - # Try to restore storage from cache by ID - if id: - cached_storage = _get_from_cache_by_id(storage_class=storage_class, id=id) - if cached_storage: - return cached_storage - - # Purge on start if configured - if configuration.purge_on_start: - await storage_client.purge_on_start() - - # Lock and create new storage - loop = asyncio.get_running_loop() - if loop not in _creation_locks: - _creation_locks[loop] = asyncio.Lock() - - async with _creation_locks[loop]: - if id and not is_default_on_memory: - resource_client = _get_resource_client(storage_class, storage_client, id) - storage_object = await resource_client.get() - if not storage_object: - raise RuntimeError(f'{storage_class.__name__} with id "{id}" does not exist!') - - elif is_default_on_memory: - resource_collection_client = _get_resource_collection_client(storage_class, storage_client) - storage_object = await resource_collection_client.get_or_create(name=name, id=id) - - else: - resource_collection_client = _get_resource_collection_client(storage_class, storage_client) - storage_object = await resource_collection_client.get_or_create(name=name) - - storage = storage_class.from_storage_object(storage_client=storage_client, storage_object=storage_object) - - # Cache the storage by ID and name - _add_to_cache_by_id(storage.id, storage) - if storage.name is not None: - _add_to_cache_by_name(storage.name, storage) - - return storage - - -def remove_storage_from_cache( - *, - storage_class: type, - id: str | None = None, - name: str | None = None, -) -> None: - """Remove a storage from cache by ID or name.""" - if id: - _rm_from_cache_by_id(storage_class=storage_class, id=id) - - if name: - _rm_from_cache_by_name(storage_class=storage_class, name=name) - - -def 
_get_resource_client( - storage_class: type[TResource], - storage_client: StorageClient, - id: str, -) -> ResourceClient: - if issubclass(storage_class, Dataset): - return storage_client.dataset(id) - - if issubclass(storage_class, KeyValueStore): - return storage_client.key_value_store(id) - - if issubclass(storage_class, RequestQueue): - return storage_client.request_queue(id) - - raise ValueError(f'Unknown storage class label: {storage_class.__name__}') - - -def _get_resource_collection_client( - storage_class: type, - storage_client: StorageClient, -) -> ResourceCollectionClient: - if issubclass(storage_class, Dataset): - return storage_client.datasets() - - if issubclass(storage_class, KeyValueStore): - return storage_client.key_value_stores() - - if issubclass(storage_class, RequestQueue): - return storage_client.request_queues() - - raise ValueError(f'Unknown storage class: {storage_class.__name__}') diff --git a/src/crawlee/storages/_dataset.py b/src/crawlee/storages/_dataset.py index 7cb58ae817..04416a0098 100644 --- a/src/crawlee/storages/_dataset.py +++ b/src/crawlee/storages/_dataset.py @@ -1,243 +1,100 @@ from __future__ import annotations -import csv -import io -import json import logging -from datetime import datetime, timezone -from typing import TYPE_CHECKING, Literal, TextIO, TypedDict, cast +from io import StringIO +from typing import TYPE_CHECKING, overload -from typing_extensions import NotRequired, Required, Unpack, override +from typing_extensions import override from crawlee import service_locator -from crawlee._utils.byte_size import ByteSize from crawlee._utils.docs import docs_group -from crawlee._utils.file import json_dumps -from crawlee.storage_clients.models import DatasetMetadata, StorageMetadata +from crawlee._utils.file import export_csv_to_stream, export_json_to_stream from ._base import Storage from ._key_value_store import KeyValueStore if TYPE_CHECKING: - from collections.abc import AsyncIterator, Callable + from collections.abc import AsyncIterator + from typing import Any, ClassVar, Literal + + from typing_extensions import Unpack - from crawlee._types import JsonSerializable, PushDataKwargs from crawlee.configuration import Configuration from crawlee.storage_clients import StorageClient - from crawlee.storage_clients.models import DatasetItemsListPage - -logger = logging.getLogger(__name__) - - -class GetDataKwargs(TypedDict): - """Keyword arguments for dataset's `get_data` method.""" - - offset: NotRequired[int] - """Skip the specified number of items at the start.""" - - limit: NotRequired[int] - """The maximum number of items to retrieve. Unlimited if None.""" - - clean: NotRequired[bool] - """Return only non-empty items and excludes hidden fields. Shortcut for skip_hidden and skip_empty.""" - - desc: NotRequired[bool] - """Set to True to sort results in descending order.""" - - fields: NotRequired[list[str]] - """Fields to include in each item. 
Sorts fields as specified if provided.""" - - omit: NotRequired[list[str]] - """Fields to exclude from each item.""" - - unwind: NotRequired[str] - """Unwind items by a specified array field, turning each element into a separate item.""" - - skip_empty: NotRequired[bool] - """Exclude empty items from the results if True.""" - - skip_hidden: NotRequired[bool] - """Exclude fields starting with '#' if True.""" - - flatten: NotRequired[list[str]] - """Field to be flattened in returned items.""" - - view: NotRequired[str] - """Specify the dataset view to be used.""" - - -class ExportToKwargs(TypedDict): - """Keyword arguments for dataset's `export_to` method.""" - - key: Required[str] - """The key under which to save the data.""" - - content_type: NotRequired[Literal['json', 'csv']] - """The format in which to export the data. Either 'json' or 'csv'.""" - - to_key_value_store_id: NotRequired[str] - """ID of the key-value store to save the exported file.""" - - to_key_value_store_name: NotRequired[str] - """Name of the key-value store to save the exported file.""" - - -class ExportDataJsonKwargs(TypedDict): - """Keyword arguments for dataset's `export_data_json` method.""" - - skipkeys: NotRequired[bool] - """If True (default: False), dict keys that are not of a basic type (str, int, float, bool, None) will be skipped - instead of raising a `TypeError`.""" - - ensure_ascii: NotRequired[bool] - """Determines if non-ASCII characters should be escaped in the output JSON string.""" - - check_circular: NotRequired[bool] - """If False (default: True), skips the circular reference check for container types. A circular reference will - result in a `RecursionError` or worse if unchecked.""" - - allow_nan: NotRequired[bool] - """If False (default: True), raises a ValueError for out-of-range float values (nan, inf, -inf) to strictly comply - with the JSON specification. If True, uses their JavaScript equivalents (NaN, Infinity, -Infinity).""" - - cls: NotRequired[type[json.JSONEncoder]] - """Allows specifying a custom JSON encoder.""" - - indent: NotRequired[int] - """Specifies the number of spaces for indentation in the pretty-printed JSON output.""" - - separators: NotRequired[tuple[str, str]] - """A tuple of (item_separator, key_separator). The default is (', ', ': ') if indent is None and (',', ': ') - otherwise.""" - - default: NotRequired[Callable] - """A function called for objects that can't be serialized otherwise. It should return a JSON-encodable version - of the object or raise a `TypeError`.""" - - sort_keys: NotRequired[bool] - """Specifies whether the output JSON object should have keys sorted alphabetically.""" - - -class ExportDataCsvKwargs(TypedDict): - """Keyword arguments for dataset's `export_data_csv` method.""" - - dialect: NotRequired[str] - """Specifies a dialect to be used in CSV parsing and writing.""" - - delimiter: NotRequired[str] - """A one-character string used to separate fields. Defaults to ','.""" - - doublequote: NotRequired[bool] - """Controls how instances of `quotechar` inside a field should be quoted. When True, the character is doubled; - when False, the `escapechar` is used as a prefix. Defaults to True.""" - - escapechar: NotRequired[str] - """A one-character string used to escape the delimiter if `quoting` is set to `QUOTE_NONE` and the `quotechar` - if `doublequote` is False. Defaults to None, disabling escaping.""" - - lineterminator: NotRequired[str] - """The string used to terminate lines produced by the writer. 
Defaults to '\\r\\n'.""" - - quotechar: NotRequired[str] - """A one-character string used to quote fields containing special characters, like the delimiter or quotechar, - or fields containing new-line characters. Defaults to '\"'.""" + from crawlee.storage_clients._base import DatasetClient + from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata - quoting: NotRequired[int] - """Controls when quotes should be generated by the writer and recognized by the reader. Can take any of - the `QUOTE_*` constants, with a default of `QUOTE_MINIMAL`.""" + from ._types import ExportDataCsvKwargs, ExportDataJsonKwargs - skipinitialspace: NotRequired[bool] - """When True, spaces immediately following the delimiter are ignored. Defaults to False.""" - - strict: NotRequired[bool] - """When True, raises an exception on bad CSV input. Defaults to False.""" +logger = logging.getLogger(__name__) @docs_group('Classes') class Dataset(Storage): - """Represents an append-only structured storage, ideal for tabular data similar to database tables. - - The `Dataset` class is designed to store structured data, where each entry (row) maintains consistent attributes - (columns) across the dataset. It operates in an append-only mode, allowing new records to be added, but not - modified or deleted. This makes it particularly useful for storing results from web crawling operations. + """Dataset is a storage for managing structured tabular data. - Data can be stored either locally or in the cloud. It depends on the setup of underlying storage client. - By default a `MemoryStorageClient` is used, but it can be changed to a different one. + The dataset class provides a high-level interface for storing and retrieving structured data + with consistent schema, similar to database tables or spreadsheets. It abstracts the underlying + storage implementation details, offering a consistent API regardless of where the data is + physically stored. - By default, data is stored using the following path structure: - ``` - {CRAWLEE_STORAGE_DIR}/datasets/{DATASET_ID}/{INDEX}.json - ``` - - `{CRAWLEE_STORAGE_DIR}`: The root directory for all storage data specified by the environment variable. - - `{DATASET_ID}`: Specifies the dataset, either "default" or a custom dataset ID. - - `{INDEX}`: Represents the zero-based index of the record within the dataset. + Dataset operates in an append-only mode, allowing new records to be added but not modified + or deleted after creation. This makes it particularly suitable for storing crawling results + and other data that should be immutable once collected. - To open a dataset, use the `open` class method by specifying an `id`, `name`, or `configuration`. If none are - provided, the default dataset for the current crawler run is used. Attempting to open a dataset by `id` that does - not exist will raise an error; however, if accessed by `name`, the dataset will be created if it doesn't already - exist. + The class provides methods for adding data, retrieving data with various filtering options, + and exporting data to different formats. You can create a dataset using the `open` class method, + specifying either a name or ID. The underlying storage implementation is determined by + the configured storage client. 
### Usage ```python from crawlee.storages import Dataset + # Open a dataset dataset = await Dataset.open(name='my_dataset') - ``` - """ - _MAX_PAYLOAD_SIZE = ByteSize.from_mb(9) - """Maximum size for a single payload.""" + # Add data + await dataset.push_data({'title': 'Example Product', 'price': 99.99}) - _SAFETY_BUFFER_PERCENT = 0.01 / 100 # 0.01% - """Percentage buffer to reduce payload limit slightly for safety.""" + # Retrieve filtered data + results = await dataset.get_data(limit=10, desc=True) - _EFFECTIVE_LIMIT_SIZE = _MAX_PAYLOAD_SIZE - (_MAX_PAYLOAD_SIZE * _SAFETY_BUFFER_PERCENT) - """Calculated payload limit considering safety buffer.""" + # Export data + await dataset.export_to('results.json', content_type='json') + ``` + """ - def __init__(self, id: str, name: str | None, storage_client: StorageClient) -> None: - self._id = id - self._name = name - datetime_now = datetime.now(timezone.utc) - self._storage_object = StorageMetadata( - id=id, name=name, accessed_at=datetime_now, created_at=datetime_now, modified_at=datetime_now - ) + _cache: ClassVar[dict[str, Dataset]] = {} + """A dictionary to cache datasets.""" - # Get resource clients from the storage client. - self._resource_client = storage_client.dataset(self._id) - self._resource_collection_client = storage_client.datasets() + def __init__(self, client: DatasetClient, cache_key: str) -> None: + """Initialize a new instance. - @classmethod - def from_storage_object(cls, storage_client: StorageClient, storage_object: StorageMetadata) -> Dataset: - """Initialize a new instance of Dataset from a storage metadata object.""" - dataset = Dataset( - id=storage_object.id, - name=storage_object.name, - storage_client=storage_client, - ) + Preferably use the `Dataset.open` constructor to create a new instance. - dataset.storage_object = storage_object - return dataset + Args: + client: An instance of a dataset client. + cache_key: A unique key to identify the dataset in the cache. 
+ """ + self._client = client + self._cache_key = cache_key - @property @override + @property def id(self) -> str: - return self._id + return self._client.metadata.id - @property @override - def name(self) -> str | None: - return self._name - @property - @override - def storage_object(self) -> StorageMetadata: - return self._storage_object + def name(self) -> str | None: + return self._client.metadata.name - @storage_object.setter @override - def storage_object(self, storage_object: StorageMetadata) -> None: - self._storage_object = storage_object + @property + def metadata(self) -> DatasetMetadata: + return self._client.metadata @override @classmethod @@ -249,27 +106,41 @@ async def open( configuration: Configuration | None = None, storage_client: StorageClient | None = None, ) -> Dataset: - from crawlee.storages._creation_management import open_storage + if id and name: + raise ValueError('Only one of "id" or "name" can be specified, not both.') - configuration = configuration or service_locator.get_configuration() - storage_client = storage_client or service_locator.get_storage_client() + configuration = service_locator.get_configuration() if configuration is None else configuration + storage_client = service_locator.get_storage_client() if storage_client is None else storage_client - return await open_storage( - storage_class=cls, + cache_key = cls.compute_cache_key( id=id, name=name, configuration=configuration, storage_client=storage_client, ) + if cache_key in cls._cache: + return cls._cache[cache_key] + + client = await storage_client.open_dataset_client( + id=id, + name=name, + configuration=configuration, + ) + + dataset = cls(client, cache_key) + cls._cache[cache_key] = dataset + return dataset + @override async def drop(self) -> None: - from crawlee.storages._creation_management import remove_storage_from_cache + # Remove from cache before dropping + if self._cache_key in self._cache: + del self._cache[self._cache_key] - await self._resource_client.delete() - remove_storage_from_cache(storage_class=self.__class__, id=self._id, name=self._name) + await self._client.drop() - async def push_data(self, data: JsonSerializable, **kwargs: Unpack[PushDataKwargs]) -> None: + async def push_data(self, data: list[Any] | dict[str, Any]) -> None: """Store an object or an array of objects to the dataset. The size of the data is limited by the receiving API and therefore `push_data()` will only @@ -279,127 +150,65 @@ async def push_data(self, data: JsonSerializable, **kwargs: Unpack[PushDataKwarg Args: data: A JSON serializable data structure to be stored in the dataset. The JSON representation of each item must be smaller than 9MB. - kwargs: Keyword arguments for the storage client method. """ - # Handle singular items - if not isinstance(data, list): - items = await self.check_and_serialize(data) - return await self._resource_client.push_items(items, **kwargs) - - # Handle lists - payloads_generator = (await self.check_and_serialize(item, index) for index, item in enumerate(data)) - - # Invoke client in series to preserve the order of data - async for items in self._chunk_by_size(payloads_generator): - await self._resource_client.push_items(items, **kwargs) - - return None + await self._client.push_data(data=data) - async def get_data(self, **kwargs: Unpack[GetDataKwargs]) -> DatasetItemsListPage: - """Retrieve dataset items based on filtering, sorting, and pagination parameters. 
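# Illustrative sketch (not part of the patch) of the caching behaviour introduced above:
# `open()` builds a key via `Storage.compute_cache_key()` and reuses instances from the
# class-level `_cache`, while `drop()` evicts the entry. Assumes the default configuration
# and storage client; the dataset name 'my_dataset' is only a placeholder.
import asyncio

from crawlee.storages import Dataset


async def demo_dataset_caching() -> None:
    # Two opens with the same name (and implicit configuration/storage client)
    # resolve to the same cache key, so the same instance is returned.
    first = await Dataset.open(name='my_dataset')
    second = await Dataset.open(name='my_dataset')
    assert first is second

    # Dropping removes the entry from the cache and deletes the underlying
    # storage, so a subsequent open creates a fresh instance.
    await first.drop()
    reopened = await Dataset.open(name='my_dataset')
    assert reopened is not first


if __name__ == '__main__':
    asyncio.run(demo_dataset_caching())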
+ async def get_data( + self, + *, + offset: int = 0, + limit: int | None = 999_999_999_999, + clean: bool = False, + desc: bool = False, + fields: list[str] | None = None, + omit: list[str] | None = None, + unwind: str | None = None, + skip_empty: bool = False, + skip_hidden: bool = False, + flatten: list[str] | None = None, + view: str | None = None, + ) -> DatasetItemsListPage: + """Retrieve a paginated list of items from a dataset based on various filtering parameters. - This method allows customization of the data retrieval process from a dataset, supporting operations such as - field selection, ordering, and skipping specific records based on provided parameters. + This method provides the flexibility to filter, sort, and modify the appearance of dataset items + when listed. Each parameter modifies the result set according to its purpose. The method also + supports pagination through 'offset' and 'limit' parameters. Args: - kwargs: Keyword arguments for the storage client method. + offset: Skips the specified number of items at the start. + limit: The maximum number of items to retrieve. Unlimited if None. + clean: Return only non-empty items and excludes hidden fields. Shortcut for skip_hidden and skip_empty. + desc: Set to True to sort results in descending order. + fields: Fields to include in each item. Sorts fields as specified if provided. + omit: Fields to exclude from each item. + unwind: Unwinds items by a specified array field, turning each element into a separate item. + skip_empty: Excludes empty items from the results if True. + skip_hidden: Excludes fields starting with '#' if True. + flatten: Fields to be flattened in returned items. + view: Specifies the dataset view to be used. Returns: - List page containing filtered and paginated dataset items. + An object with filtered, sorted, and paginated dataset items plus pagination details. """ - return await self._resource_client.list_items(**kwargs) - - async def write_to_csv(self, destination: TextIO, **kwargs: Unpack[ExportDataCsvKwargs]) -> None: - """Export the entire dataset into an arbitrary stream. - - Args: - destination: The stream into which the dataset contents should be written. - kwargs: Additional keyword arguments for `csv.writer`. - """ - items: list[dict] = [] - limit = 1000 - offset = 0 - - while True: - list_items = await self._resource_client.list_items(limit=limit, offset=offset) - items.extend(list_items.items) - if list_items.total <= offset + list_items.count: - break - offset += list_items.count - - if items: - writer = csv.writer(destination, **kwargs) - writer.writerows([items[0].keys(), *[item.values() for item in items]]) - else: - logger.warning('Attempting to export an empty dataset - no file will be created') - - async def write_to_json(self, destination: TextIO, **kwargs: Unpack[ExportDataJsonKwargs]) -> None: - """Export the entire dataset into an arbitrary stream. - - Args: - destination: The stream into which the dataset contents should be written. - kwargs: Additional keyword arguments for `json.dump`. 
- """ - items: list[dict] = [] - limit = 1000 - offset = 0 - - while True: - list_items = await self._resource_client.list_items(limit=limit, offset=offset) - items.extend(list_items.items) - if list_items.total <= offset + list_items.count: - break - offset += list_items.count - - if items: - json.dump(items, destination, **kwargs) - else: - logger.warning('Attempting to export an empty dataset - no file will be created') - - async def export_to(self, **kwargs: Unpack[ExportToKwargs]) -> None: - """Export the entire dataset into a specified file stored under a key in a key-value store. - - This method consolidates all entries from a specified dataset into one file, which is then saved under a - given key in a key-value store. The format of the exported file is determined by the `content_type` parameter. - Either the dataset's ID or name should be specified, and similarly, either the target key-value store's ID or - name should be used. - - Args: - kwargs: Keyword arguments for the storage client method. - """ - key = cast('str', kwargs.get('key')) - content_type = kwargs.get('content_type', 'json') - to_key_value_store_id = kwargs.get('to_key_value_store_id') - to_key_value_store_name = kwargs.get('to_key_value_store_name') - - key_value_store = await KeyValueStore.open(id=to_key_value_store_id, name=to_key_value_store_name) - - output = io.StringIO() - if content_type == 'csv': - await self.write_to_csv(output) - elif content_type == 'json': - await self.write_to_json(output) - else: - raise ValueError('Unsupported content type, expecting CSV or JSON') - - if content_type == 'csv': - await key_value_store.set_value(key, output.getvalue(), 'text/csv') - - if content_type == 'json': - await key_value_store.set_value(key, output.getvalue(), 'application/json') - - async def get_info(self) -> DatasetMetadata | None: - """Get an object containing general information about the dataset.""" - metadata = await self._resource_client.get() - if isinstance(metadata, DatasetMetadata): - return metadata - return None + return await self._client.get_data( + offset=offset, + limit=limit, + clean=clean, + desc=desc, + fields=fields, + omit=omit, + unwind=unwind, + skip_empty=skip_empty, + skip_hidden=skip_hidden, + flatten=flatten, + view=view, + ) async def iterate_items( self, *, offset: int = 0, - limit: int | None = None, + limit: int | None = 999_999_999_999, clean: bool = False, desc: bool = False, fields: list[str] | None = None, @@ -408,27 +217,29 @@ async def iterate_items( skip_empty: bool = False, skip_hidden: bool = False, ) -> AsyncIterator[dict]: - """Iterate over dataset items, applying filtering, sorting, and pagination. + """Iterate over items in the dataset according to specified filters and sorting. - Retrieve dataset items incrementally, allowing fine-grained control over the data fetched. The function - supports various parameters to filter, sort, and limit the data returned, facilitating tailored dataset - queries. + This method allows for asynchronously iterating through dataset items while applying various filters such as + skipping empty items, hiding specific fields, and sorting. It supports pagination via `offset` and `limit` + parameters, and can modify the appearance of dataset items using `fields`, `omit`, `unwind`, `skip_empty`, and + `skip_hidden` parameters. Args: - offset: Initial number of items to skip. - limit: Max number of items to return. No limit if None. - clean: Filter out empty items and hidden fields if True. - desc: Return items in reverse order if True. 
- fields: Specific fields to include in each item. - omit: Fields to omit from each item. - unwind: Field name to unwind items by. - skip_empty: Omits empty items if True. + offset: Skips the specified number of items at the start. + limit: The maximum number of items to retrieve. Unlimited if None. + clean: Return only non-empty items and excludes hidden fields. Shortcut for skip_hidden and skip_empty. + desc: Set to True to sort results in descending order. + fields: Fields to include in each item. Sorts fields as specified if provided. + omit: Fields to exclude from each item. + unwind: Unwinds items by a specified array field, turning each element into a separate item. + skip_empty: Excludes empty items from the results if True. skip_hidden: Excludes fields starting with '#' if True. Yields: - Each item from the dataset as a dictionary. + An asynchronous iterator of dictionary objects, each representing a dataset item after applying + the specified filters and transformations. """ - async for item in self._resource_client.iterate_items( + async for item in self._client.iterate_items( offset=offset, limit=limit, clean=clean, @@ -441,59 +252,121 @@ async def iterate_items( ): yield item - @classmethod - async def check_and_serialize(cls, item: JsonSerializable, index: int | None = None) -> str: - """Serialize a given item to JSON, checks its serializability and size against a limit. + async def list_items( + self, + *, + offset: int = 0, + limit: int | None = 999_999_999_999, + clean: bool = False, + desc: bool = False, + fields: list[str] | None = None, + omit: list[str] | None = None, + unwind: str | None = None, + skip_empty: bool = False, + skip_hidden: bool = False, + ) -> list[dict]: + """Retrieve a list of all items from the dataset according to specified filters and sorting. + + This method collects all dataset items into a list while applying various filters such as + skipping empty items, hiding specific fields, and sorting. It supports pagination via `offset` and `limit` + parameters, and can modify the appearance of dataset items using `fields`, `omit`, `unwind`, `skip_empty`, and + `skip_hidden` parameters. Args: - item: The item to serialize. - index: Index of the item, used for error context. + offset: Skips the specified number of items at the start. + limit: The maximum number of items to retrieve. Unlimited if None. + clean: Return only non-empty items and excludes hidden fields. Shortcut for skip_hidden and skip_empty. + desc: Set to True to sort results in descending order. + fields: Fields to include in each item. Sorts fields as specified if provided. + omit: Fields to exclude from each item. + unwind: Unwinds items by a specified array field, turning each element into a separate item. + skip_empty: Excludes empty items from the results if True. + skip_hidden: Excludes fields starting with '#' if True. Returns: - Serialized JSON string. - - Raises: - ValueError: If item is not JSON serializable or exceeds size limit. + A list of dictionary objects, each representing a dataset item after applying + the specified filters and transformations. 
""" - s = ' ' if index is None else f' at index {index} ' - - try: - payload = await json_dumps(item) - except Exception as exc: - raise ValueError(f'Data item{s}is not serializable to JSON.') from exc - - payload_size = ByteSize(len(payload.encode('utf-8'))) - if payload_size > cls._EFFECTIVE_LIMIT_SIZE: - raise ValueError(f'Data item{s}is too large (size: {payload_size}, limit: {cls._EFFECTIVE_LIMIT_SIZE})') - - return payload - - async def _chunk_by_size(self, items: AsyncIterator[str]) -> AsyncIterator[str]: - """Yield chunks of JSON arrays composed of input strings, respecting a size limit. + return [ + item + async for item in self.iterate_items( + offset=offset, + limit=limit, + clean=clean, + desc=desc, + fields=fields, + omit=omit, + unwind=unwind, + skip_empty=skip_empty, + skip_hidden=skip_hidden, + ) + ] + + @overload + async def export_to( + self, + key: str, + content_type: Literal['json'], + to_kvs_id: str | None = None, + to_kvs_name: str | None = None, + to_kvs_storage_client: StorageClient | None = None, + to_kvs_configuration: Configuration | None = None, + **kwargs: Unpack[ExportDataJsonKwargs], + ) -> None: ... + + @overload + async def export_to( + self, + key: str, + content_type: Literal['csv'], + to_kvs_id: str | None = None, + to_kvs_name: str | None = None, + to_kvs_storage_client: StorageClient | None = None, + to_kvs_configuration: Configuration | None = None, + **kwargs: Unpack[ExportDataCsvKwargs], + ) -> None: ... + + async def export_to( + self, + key: str, + content_type: Literal['json', 'csv'] = 'json', + to_kvs_id: str | None = None, + to_kvs_name: str | None = None, + to_kvs_storage_client: StorageClient | None = None, + to_kvs_configuration: Configuration | None = None, + **kwargs: Any, + ) -> None: + """Export the entire dataset into a specified file stored under a key in a key-value store. - Groups an iterable of JSON string payloads into larger JSON arrays, ensuring the total size - of each array does not exceed `EFFECTIVE_LIMIT_SIZE`. Each output is a JSON array string that - contains as many payloads as possible without breaching the size threshold, maintaining the - order of the original payloads. Assumes individual items are below the size limit. + This method consolidates all entries from a specified dataset into one file, which is then saved under a + given key in a key-value store. The format of the exported file is determined by the `content_type` parameter. + Either the dataset's ID or name should be specified, and similarly, either the target key-value store's ID or + name should be used. Args: - items: Iterable of JSON string payloads. - - Yields: - Strings representing JSON arrays of payloads, each staying within the size limit. + key: The key under which to save the data in the key-value store. + content_type: The format in which to export the data. + to_kvs_id: ID of the key-value store to save the exported file. + Specify only one of ID or name. + to_kvs_name: Name of the key-value store to save the exported file. + Specify only one of ID or name. + to_kvs_storage_client: Storage client to use for the key-value store. + to_kvs_configuration: Configuration for the key-value store. + kwargs: Additional parameters for the export operation, specific to the chosen content type. """ - last_chunk_size = ByteSize(2) # Add 2 bytes for [] wrapper. 
- current_chunk = [] - - async for payload in items: - payload_size = ByteSize(len(payload.encode('utf-8'))) - - if last_chunk_size + payload_size <= self._EFFECTIVE_LIMIT_SIZE: - current_chunk.append(payload) - last_chunk_size += payload_size + ByteSize(1) # Add 1 byte for ',' separator. - else: - yield f'[{",".join(current_chunk)}]' - current_chunk = [payload] - last_chunk_size = payload_size + ByteSize(2) # Add 2 bytes for [] wrapper. + kvs = await KeyValueStore.open( + id=to_kvs_id, + name=to_kvs_name, + configuration=to_kvs_configuration, + storage_client=to_kvs_storage_client, + ) + dst = StringIO() - yield f'[{",".join(current_chunk)}]' + if content_type == 'csv': + await export_csv_to_stream(self.iterate_items(), dst, **kwargs) + await kvs.set_value(key, dst.getvalue(), 'text/csv') + elif content_type == 'json': + await export_json_to_stream(self.iterate_items(), dst, **kwargs) + await kvs.set_value(key, dst.getvalue(), 'application/json') + else: + raise ValueError('Unsupported content type, expecting CSV or JSON') diff --git a/src/crawlee/storages/_key_value_store.py b/src/crawlee/storages/_key_value_store.py index fc077726d1..094661af28 100644 --- a/src/crawlee/storages/_key_value_store.py +++ b/src/crawlee/storages/_key_value_store.py @@ -2,7 +2,6 @@ import asyncio from collections.abc import AsyncIterator -from datetime import datetime, timezone from logging import getLogger from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, overload @@ -12,7 +11,7 @@ from crawlee import service_locator from crawlee._types import JsonSerializable # noqa: TC001 from crawlee._utils.docs import docs_group -from crawlee.storage_clients.models import KeyValueStoreKeyInfo, KeyValueStoreMetadata, StorageMetadata +from crawlee.storage_clients.models import KeyValueStoreMetadata from ._base import Storage @@ -22,6 +21,8 @@ from crawlee._utils.recoverable_state import RecoverableState from crawlee.configuration import Configuration from crawlee.storage_clients import StorageClient + from crawlee.storage_clients._base import KeyValueStoreClient + from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecordMetadata T = TypeVar('T') @@ -34,39 +35,38 @@ class AutosavedValue(RootModel): @docs_group('Classes') class KeyValueStore(Storage): - """Represents a key-value based storage for reading and writing data records or files. + """Key-value store is a storage for reading and writing data records with unique key identifiers. - Each data record is identified by a unique key and associated with a specific MIME content type. This class is - commonly used in crawler runs to store inputs and outputs, typically in JSON format, but it also supports other - content types. + The key-value store class acts as a high-level interface for storing, retrieving, and managing data records + identified by unique string keys. It abstracts away the underlying storage implementation details, + allowing you to work with the same API regardless of whether data is stored in memory, on disk, + or in the cloud. - Data can be stored either locally or in the cloud. It depends on the setup of underlying storage client. - By default a `MemoryStorageClient` is used, but it can be changed to a different one. + Each data record is associated with a specific MIME content type, allowing storage of various + data formats such as JSON, text, images, HTML snapshots or any binary data. This class is + commonly used to store inputs, outputs, and other artifacts of crawler operations. 
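# Illustrative sketch (not part of the patch) of the reworked `Dataset` surface above:
# `push_data()` takes a dict or list of dicts, `list_items()` collects filtered items,
# and `export_to()` now takes `key`, `content_type` and `to_kvs_*` parameters directly
# instead of an `ExportToKwargs` dict. The names 'products' and 'exports' are placeholders.
import asyncio

from crawlee.storages import Dataset


async def demo_dataset_export() -> None:
    dataset = await Dataset.open(name='products')

    # Append a couple of rows to the append-only dataset.
    await dataset.push_data([
        {'title': 'Widget', 'price': 19.99},
        {'title': 'Gadget', 'price': 29.99},
    ])

    # Collect everything back as plain dictionaries, in descending order.
    items = await dataset.list_items(desc=True)
    print(f'Stored {len(items)} items')

    # Stream the whole dataset as CSV into the key-value store named 'exports',
    # under the key 'products.csv'.
    await dataset.export_to('products.csv', content_type='csv', to_kvs_name='exports')


if __name__ == '__main__':
    asyncio.run(demo_dataset_export())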
- By default, data is stored using the following path structure: - ``` - {CRAWLEE_STORAGE_DIR}/key_value_stores/{STORE_ID}/{KEY}.{EXT} - ``` - - `{CRAWLEE_STORAGE_DIR}`: The root directory for all storage data specified by the environment variable. - - `{STORE_ID}`: The identifier for the key-value store, either "default" or as specified by - `CRAWLEE_DEFAULT_KEY_VALUE_STORE_ID`. - - `{KEY}`: The unique key for the record. - - `{EXT}`: The file extension corresponding to the MIME type of the content. - - To open a key-value store, use the `open` class method, providing an `id`, `name`, or optional `configuration`. - If none are specified, the default store for the current crawler run is used. Attempting to open a store by `id` - that does not exist will raise an error; however, if accessed by `name`, the store will be created if it does not - already exist. + You can instantiate a key-value store using the `open` class method, which will create a store + with the specified name or id. The underlying storage implementation is determined by the configured + storage client. ### Usage ```python from crawlee.storages import KeyValueStore - kvs = await KeyValueStore.open(name='my_kvs') + # Open a named key-value store + kvs = await KeyValueStore.open(name='my-store') + + # Store and retrieve data + await kvs.set_value('product-1234.json', [{'name': 'Smartphone', 'price': 799.99}]) + product = await kvs.get_value('product-1234') ``` """ + _cache: ClassVar[dict[str, KeyValueStore]] = {} + """A dictionary to cache key-value stores.""" + # Cache for recoverable (auto-saved) values _autosaved_values: ClassVar[ dict[ @@ -75,53 +75,34 @@ class KeyValueStore(Storage): ] ] = {} - def __init__(self, id: str, name: str | None, storage_client: StorageClient) -> None: - self._id = id - self._name = name - datetime_now = datetime.now(timezone.utc) - self._storage_object = StorageMetadata( - id=id, name=name, accessed_at=datetime_now, created_at=datetime_now, modified_at=datetime_now - ) - - # Get resource clients from storage client - self._resource_client = storage_client.key_value_store(self._id) - self._autosave_lock = asyncio.Lock() + def __init__(self, client: KeyValueStoreClient, cache_key: str) -> None: + """Initialize a new instance. - @classmethod - def from_storage_object(cls, storage_client: StorageClient, storage_object: StorageMetadata) -> KeyValueStore: - """Initialize a new instance of KeyValueStore from a storage metadata object.""" - key_value_store = KeyValueStore( - id=storage_object.id, - name=storage_object.name, - storage_client=storage_client, - ) + Preferably use the `KeyValueStore.open` constructor to create a new instance. - key_value_store.storage_object = storage_object - return key_value_store + Args: + client: An instance of a key-value store client. + cache_key: A unique key to identify the key-value store in the cache. 
+ """ + self._client = client + self._cache_key = cache_key + self._autosave_lock = asyncio.Lock() + self._persist_state_event_started = False - @property @override + @property def id(self) -> str: - return self._id + return self._client.metadata.id - @property @override - def name(self) -> str | None: - return self._name - @property - @override - def storage_object(self) -> StorageMetadata: - return self._storage_object + def name(self) -> str | None: + return self._client.metadata.name - @storage_object.setter @override - def storage_object(self, storage_object: StorageMetadata) -> None: - self._storage_object = storage_object - - async def get_info(self) -> KeyValueStoreMetadata | None: - """Get an object containing general information about the key value store.""" - return await self._resource_client.get() + @property + def metadata(self) -> KeyValueStoreMetadata: + return self._client.metadata @override @classmethod @@ -133,26 +114,43 @@ async def open( configuration: Configuration | None = None, storage_client: StorageClient | None = None, ) -> KeyValueStore: - from crawlee.storages._creation_management import open_storage + if id and name: + raise ValueError('Only one of "id" or "name" can be specified, not both.') - configuration = configuration or service_locator.get_configuration() - storage_client = storage_client or service_locator.get_storage_client() + configuration = service_locator.get_configuration() if configuration is None else configuration + storage_client = service_locator.get_storage_client() if storage_client is None else storage_client - return await open_storage( - storage_class=cls, + cache_key = cls.compute_cache_key( id=id, name=name, configuration=configuration, storage_client=storage_client, ) + if cache_key in cls._cache: + return cls._cache[cache_key] + + client = await storage_client.open_key_value_store_client( + id=id, + name=name, + configuration=configuration, + ) + + kvs = cls(client, cache_key) + cls._cache[cache_key] = kvs + return kvs + @override async def drop(self) -> None: - from crawlee.storages._creation_management import remove_storage_from_cache + # Remove from cache before dropping + if self._cache_key in self._cache: + del self._cache[self._cache_key] - remove_storage_from_cache(storage_class=self.__class__, id=self._id, name=self._name) + # Clear cache with persistent values await self._clear_cache() - await self._resource_client.delete() + + # Drop the key-value store client + await self._client.drop() @overload async def get_value(self, key: str) -> Any: ... @@ -173,44 +171,75 @@ async def get_value(self, key: str, default_value: T | None = None) -> T | None: Returns: The value associated with the given key. `default_value` is used in case the record does not exist. """ - record = await self._resource_client.get_record(key) + record = await self._client.get_value(key=key) return record.value if record else default_value - async def iterate_keys(self, exclusive_start_key: str | None = None) -> AsyncIterator[KeyValueStoreKeyInfo]: + async def set_value( + self, + key: str, + value: Any, + content_type: str | None = None, + ) -> None: + """Set a value in the KVS. + + Args: + key: Key of the record to set. + value: Value to set. + content_type: The MIME content type string. + """ + await self._client.set_value(key=key, value=value, content_type=content_type) + + async def delete_value(self, key: str) -> None: + """Delete a value from the KVS. + + Args: + key: Key of the record to delete. 
+ """ + await self._client.delete_value(key=key) + + async def iterate_keys( + self, + exclusive_start_key: str | None = None, + limit: int | None = None, + ) -> AsyncIterator[KeyValueStoreRecordMetadata]: """Iterate over the existing keys in the KVS. Args: exclusive_start_key: Key to start the iteration from. + limit: Maximum number of keys to return. None means no limit. Yields: Information about the key. """ - while True: - list_keys = await self._resource_client.list_keys(exclusive_start_key=exclusive_start_key) - for item in list_keys.items: - yield KeyValueStoreKeyInfo(key=item.key, size=item.size) - - if not list_keys.is_truncated: - break - exclusive_start_key = list_keys.next_exclusive_start_key + async for item in self._client.iterate_keys( + exclusive_start_key=exclusive_start_key, + limit=limit, + ): + yield item - async def set_value( + async def list_keys( self, - key: str, - value: Any, - content_type: str | None = None, - ) -> None: - """Set a value in the KVS. + exclusive_start_key: str | None = None, + limit: int = 1000, + ) -> list[KeyValueStoreRecordMetadata]: + """List all the existing keys in the KVS. + + It uses client's `iterate_keys` method to get the keys. Args: - key: Key of the record to set. - value: Value to set. If `None`, the record is deleted. - content_type: Content type of the record. - """ - if value is None: - return await self._resource_client.delete_record(key) + exclusive_start_key: Key to start the iteration from. + limit: Maximum number of keys to return. - return await self._resource_client.set_record(key, value, content_type) + Returns: + A list of keys in the KVS. + """ + return [ + key + async for key in self._client.iterate_keys( + exclusive_start_key=exclusive_start_key, + limit=limit, + ) + ] async def get_public_url(self, key: str) -> str: """Get the public URL for the given key. @@ -221,7 +250,7 @@ async def get_public_url(self, key: str) -> str: Returns: The public URL for the given key. 
""" - return await self._resource_client.get_public_url(key) + return await self._client.get_public_url(key=key) async def get_auto_saved_value( self, @@ -242,7 +271,7 @@ async def get_auto_saved_value( default_value = {} if default_value is None else default_value async with self._autosave_lock: - cache = self._autosaved_values.setdefault(self._id, {}) + cache = self._autosaved_values.setdefault(self.id, {}) if key in cache: return cache[key].current_value.root @@ -250,7 +279,7 @@ async def get_auto_saved_value( cache[key] = recoverable_state = RecoverableState( default_state=AutosavedValue(default_value), persistence_enabled=True, - persist_state_kvs_id=self._id, + persist_state_kvs_name=self.name, # TODO: use id instead of name, once it's implemented persist_state_key=key, logger=logger, ) @@ -259,17 +288,17 @@ async def get_auto_saved_value( return recoverable_state.current_value.root - async def _clear_cache(self) -> None: - """Clear cache with autosaved values.""" + async def persist_autosaved_values(self) -> None: + """Force autosaved values to be saved without waiting for an event in Event Manager.""" if self.id in self._autosaved_values: cache = self._autosaved_values[self.id] for value in cache.values(): - await value.teardown() - cache.clear() + await value.persist_state() - async def persist_autosaved_values(self) -> None: - """Force autosaved values to be saved without waiting for an event in Event Manager.""" + async def _clear_cache(self) -> None: + """Clear cache with autosaved values.""" if self.id in self._autosaved_values: cache = self._autosaved_values[self.id] for value in cache.values(): - await value.persist_state() + await value.teardown() + cache.clear() diff --git a/src/crawlee/storages/_request_queue.py b/src/crawlee/storages/_request_queue.py index b3274ccc81..e152c5943a 100644 --- a/src/crawlee/storages/_request_queue.py +++ b/src/crawlee/storages/_request_queue.py @@ -1,23 +1,16 @@ from __future__ import annotations import asyncio -from collections import deque -from contextlib import suppress -from datetime import datetime, timedelta, timezone +from datetime import timedelta from logging import getLogger -from typing import TYPE_CHECKING, Any, TypedDict, TypeVar +from typing import TYPE_CHECKING, ClassVar, TypeVar -from cachetools import LRUCache from typing_extensions import override -from crawlee import service_locator -from crawlee._utils.crypto import crypto_random_object_id +from crawlee import Request, service_locator from crawlee._utils.docs import docs_group -from crawlee._utils.requests import unique_key_to_request_id from crawlee._utils.wait import wait_for_all_tasks_for_finish -from crawlee.events import Event from crawlee.request_loaders import RequestManager -from crawlee.storage_clients.models import ProcessedRequest, RequestQueueMetadata, StorageMetadata from ._base import Storage @@ -27,131 +20,98 @@ from crawlee import Request from crawlee.configuration import Configuration from crawlee.storage_clients import StorageClient + from crawlee.storage_clients._base import RequestQueueClient + from crawlee.storage_clients.models import ProcessedRequest, RequestQueueMetadata logger = getLogger(__name__) T = TypeVar('T') -class CachedRequest(TypedDict): - id: str - was_already_handled: bool - hydrated: Request | None - lock_expires_at: datetime | None - forefront: bool - - @docs_group('Classes') class RequestQueue(Storage, RequestManager): - """Represents a queue storage for managing HTTP requests in web crawling operations. 
+ """Request queue is a storage for managing HTTP requests. - The `RequestQueue` class handles a queue of HTTP requests, each identified by a unique URL, to facilitate structured - web crawling. It supports both breadth-first and depth-first crawling strategies, allowing for recursive crawling - starting from an initial set of URLs. Each URL in the queue is uniquely identified by a `unique_key`, which can be - customized to allow the same URL to be added multiple times under different keys. + The request queue class serves as a high-level interface for organizing and managing HTTP requests + during web crawling. It provides methods for adding, retrieving, and manipulating requests throughout + the crawling lifecycle, abstracting away the underlying storage implementation details. - Data can be stored either locally or in the cloud. It depends on the setup of underlying storage client. - By default a `MemoryStorageClient` is used, but it can be changed to a different one. + Request queue maintains the state of each URL to be crawled, tracking whether it has been processed, + is currently being handled, or is waiting in the queue. Each URL in the queue is uniquely identified + by a `unique_key` property, which prevents duplicate processing unless explicitly configured otherwise. - By default, data is stored using the following path structure: - ``` - {CRAWLEE_STORAGE_DIR}/request_queues/{QUEUE_ID}/{REQUEST_ID}.json - ``` - - `{CRAWLEE_STORAGE_DIR}`: The root directory for all storage data specified by the environment variable. - - `{QUEUE_ID}`: The identifier for the request queue, either "default" or as specified. - - `{REQUEST_ID}`: The unique identifier for each request in the queue. + The class supports both breadth-first and depth-first crawling strategies through its `forefront` parameter + when adding requests. It also provides mechanisms for error handling and request reclamation when + processing fails. - The `RequestQueue` supports both creating new queues and opening existing ones by `id` or `name`. Named queues - persist indefinitely, while unnamed queues expire after 7 days unless specified otherwise. The queue supports - mutable operations, allowing URLs to be added and removed as needed. + You can open a request queue using the `open` class method, specifying either a name or ID to identify + the queue. The underlying storage implementation is determined by the configured storage client. ### Usage ```python from crawlee.storages import RequestQueue - rq = await RequestQueue.open(name='my_rq') - ``` - """ + # Open a request queue + rq = await RequestQueue.open(name='my_queue') - _MAX_CACHED_REQUESTS = 1_000_000 - """Maximum number of requests that can be cached.""" + # Add a request + await rq.add_request('https://example.com') - def __init__( - self, - id: str, - name: str | None, - storage_client: StorageClient, - ) -> None: - config = service_locator.get_configuration() - event_manager = service_locator.get_event_manager() + # Process requests + request = await rq.fetch_next_request() + if request: + try: + # Process the request + # ... 
+ await rq.mark_request_as_handled(request) + except Exception: + await rq.reclaim_request(request) + ``` + """ - self._id = id - self._name = name + _cache: ClassVar[dict[str, RequestQueue]] = {} + """A dictionary to cache request queues.""" - datetime_now = datetime.now(timezone.utc) - self._storage_object = StorageMetadata( - id=id, name=name, accessed_at=datetime_now, created_at=datetime_now, modified_at=datetime_now - ) + def __init__(self, client: RequestQueueClient, cache_key: str) -> None: + """Initialize a new instance. - # Get resource clients from storage client - self._resource_client = storage_client.request_queue(self._id) - self._resource_collection_client = storage_client.request_queues() - - self._request_lock_time = timedelta(minutes=3) - self._queue_paused_for_migration = False - self._queue_has_locked_requests: bool | None = None - self._should_check_for_forefront_requests = False - - self._is_finished_log_throttle_counter = 0 - self._dequeued_request_count = 0 - - event_manager.on(event=Event.MIGRATING, listener=lambda _: setattr(self, '_queue_paused_for_migration', True)) - event_manager.on(event=Event.MIGRATING, listener=self._clear_possible_locks) - event_manager.on(event=Event.ABORTING, listener=self._clear_possible_locks) - - # Other internal attributes - self._tasks = list[asyncio.Task]() - self._client_key = crypto_random_object_id() - self._internal_timeout = config.internal_timeout or timedelta(minutes=5) - self._assumed_total_count = 0 - self._assumed_handled_count = 0 - self._queue_head = deque[str]() - self._list_head_and_lock_task: asyncio.Task | None = None - self._last_activity = datetime.now(timezone.utc) - self._requests_cache: LRUCache[str, CachedRequest] = LRUCache(maxsize=self._MAX_CACHED_REQUESTS) + Preferably use the `RequestQueue.open` constructor to create a new instance. - @classmethod - def from_storage_object(cls, storage_client: StorageClient, storage_object: StorageMetadata) -> RequestQueue: - """Initialize a new instance of RequestQueue from a storage metadata object.""" - request_queue = RequestQueue( - id=storage_object.id, - name=storage_object.name, - storage_client=storage_client, - ) + Args: + client: An instance of a request queue client. + cache_key: A unique key to identify the request queue in the cache. 
+ """ + self._client = client + self._cache_key = cache_key - request_queue.storage_object = storage_object - return request_queue + self._add_requests_tasks = list[asyncio.Task]() + """A list of tasks for adding requests to the queue.""" - @property @override + @property def id(self) -> str: - return self._id + return self._client.metadata.id - @property @override + @property def name(self) -> str | None: - return self._name + return self._client.metadata.name + @override @property + def metadata(self) -> RequestQueueMetadata: + return self._client.metadata + @override - def storage_object(self) -> StorageMetadata: - return self._storage_object + @property + async def handled_count(self) -> int: + return self._client.metadata.handled_request_count - @storage_object.setter @override - def storage_object(self, storage_object: StorageMetadata) -> None: - self._storage_object = storage_object + @property + async def total_count(self) -> int: + return self._client.metadata.total_request_count @override @classmethod @@ -163,29 +123,39 @@ async def open( configuration: Configuration | None = None, storage_client: StorageClient | None = None, ) -> RequestQueue: - from crawlee.storages._creation_management import open_storage + if id and name: + raise ValueError('Only one of "id" or "name" can be specified, not both.') - configuration = configuration or service_locator.get_configuration() - storage_client = storage_client or service_locator.get_storage_client() + configuration = service_locator.get_configuration() if configuration is None else configuration + storage_client = service_locator.get_storage_client() if storage_client is None else storage_client - return await open_storage( - storage_class=cls, + cache_key = cls.compute_cache_key( id=id, name=name, configuration=configuration, storage_client=storage_client, ) - @override - async def drop(self, *, timeout: timedelta | None = None) -> None: - from crawlee.storages._creation_management import remove_storage_from_cache + if cache_key in cls._cache: + return cls._cache[cache_key] - # Wait for all tasks to finish - await wait_for_all_tasks_for_finish(self._tasks, logger=logger, timeout=timeout) + client = await storage_client.open_request_queue_client( + id=id, + name=name, + configuration=configuration, + ) - # Delete the storage from the underlying client and remove it from the cache - await self._resource_client.delete() - remove_storage_from_cache(storage_class=self.__class__, id=self._id, name=self._name) + rq = cls(client, cache_key) + cls._cache[cache_key] = rq + return rq + + @override + async def drop(self) -> None: + # Remove from cache before dropping + if self._cache_key in self._cache: + del self._cache[self._cache_key] + + await self._client.drop() @override async def add_request( @@ -195,40 +165,15 @@ async def add_request( forefront: bool = False, ) -> ProcessedRequest: request = self._transform_request(request) - self._last_activity = datetime.now(timezone.utc) - - cache_key = unique_key_to_request_id(request.unique_key) - cached_info = self._requests_cache.get(cache_key) - - if cached_info: - request.id = cached_info['id'] - # We may assume that if request is in local cache then also the information if the request was already - # handled is there because just one client should be using one queue. 
- return ProcessedRequest( - id=request.id, - unique_key=request.unique_key, - was_already_present=True, - was_already_handled=cached_info['was_already_handled'], - ) - - processed_request = await self._resource_client.add_request(request, forefront=forefront) - processed_request.unique_key = request.unique_key - - self._cache_request(cache_key, processed_request, forefront=forefront) - - if not processed_request.was_already_present and forefront: - self._should_check_for_forefront_requests = True - - if request.handled_at is None and not processed_request.was_already_present: - self._assumed_total_count += 1 - - return processed_request + response = await self._client.add_batch_of_requests([request], forefront=forefront) + return response.processed_requests[0] @override - async def add_requests_batched( + async def add_requests( self, requests: Sequence[str | Request], *, + forefront: bool = False, batch_size: int = 1000, wait_time_between_batches: timedelta = timedelta(seconds=1), wait_for_all_requests_to_be_added: bool = False, @@ -240,21 +185,31 @@ async def add_requests_batched( # Wait for the first batch to be added first_batch = transformed_requests[:batch_size] if first_batch: - await self._process_batch(first_batch, base_retry_wait=wait_time_between_batches) + await self._process_batch( + first_batch, + base_retry_wait=wait_time_between_batches, + forefront=forefront, + ) async def _process_remaining_batches() -> None: for i in range(batch_size, len(transformed_requests), batch_size): batch = transformed_requests[i : i + batch_size] - await self._process_batch(batch, base_retry_wait=wait_time_between_batches) + await self._process_batch( + batch, + base_retry_wait=wait_time_between_batches, + forefront=forefront, + ) if i + batch_size < len(transformed_requests): await asyncio.sleep(wait_time_secs) # Create and start the task to process remaining batches in the background remaining_batches_task = asyncio.create_task( - _process_remaining_batches(), name='request_queue_process_remaining_batches_task' + _process_remaining_batches(), + name='request_queue_process_remaining_batches_task', ) - self._tasks.append(remaining_batches_task) - remaining_batches_task.add_done_callback(lambda _: self._tasks.remove(remaining_batches_task)) + + self._add_requests_tasks.append(remaining_batches_task) + remaining_batches_task.add_done_callback(lambda _: self._add_requests_tasks.remove(remaining_batches_task)) # Wait for all tasks to finish if requested if wait_for_all_requests_to_be_added: @@ -264,42 +219,6 @@ async def _process_remaining_batches() -> None: timeout=wait_for_all_requests_to_be_added_timeout, ) - async def _process_batch(self, batch: Sequence[Request], base_retry_wait: timedelta, attempt: int = 1) -> None: - max_attempts = 5 - response = await self._resource_client.batch_add_requests(batch) - - if response.unprocessed_requests: - logger.debug(f'Following requests were not processed: {response.unprocessed_requests}.') - if attempt > max_attempts: - logger.warning( - f'Following requests were not processed even after {max_attempts} attempts:\n' - f'{response.unprocessed_requests}' - ) - else: - logger.debug('Retry to add requests.') - unprocessed_requests_unique_keys = {request.unique_key for request in response.unprocessed_requests} - retry_batch = [request for request in batch if request.unique_key in unprocessed_requests_unique_keys] - await asyncio.sleep((base_retry_wait * attempt).total_seconds()) - await self._process_batch(retry_batch, base_retry_wait=base_retry_wait, 
attempt=attempt + 1) - - request_count = len(batch) - len(response.unprocessed_requests) - self._assumed_total_count += request_count - if request_count: - logger.debug( - f'Added {request_count} requests to the queue. Processed requests: {response.processed_requests}' - ) - - async def get_request(self, request_id: str) -> Request | None: - """Retrieve a request from the queue. - - Args: - request_id: ID of the request to retrieve. - - Returns: - The retrieved request, or `None`, if it does not exist. - """ - return await self._resource_client.get_request(request_id) - async def fetch_next_request(self) -> Request | None: """Return the next request in the queue to be processed. @@ -313,75 +232,35 @@ async def fetch_next_request(self) -> Request | None: instead. Returns: - The request or `None` if there are no more pending requests. + The next request to process, or `None` if there are no more pending requests. """ - self._last_activity = datetime.now(timezone.utc) - - await self._ensure_head_is_non_empty() + return await self._client.fetch_next_request() - # We are likely done at this point. - if len(self._queue_head) == 0: - return None - - next_request_id = self._queue_head.popleft() - request = await self._get_or_hydrate_request(next_request_id) - - # NOTE: It can happen that the queue head index is inconsistent with the main queue table. - # This can occur in two situations: + async def get_request(self, request_id: str) -> Request | None: + """Retrieve a specific request from the queue by its ID. - # 1) - # Queue head index is ahead of the main table and the request is not present in the main table yet - # (i.e. get_request() returned null). In this case, keep the request marked as in progress for a short while, - # so that is_finished() doesn't return true and _ensure_head_is_non_empty() doesn't not load the request into - # the queueHeadDict straight again. After the interval expires, fetch_next_request() will try to fetch this - # request again, until it eventually appears in the main table. - if request is None: - logger.debug( - 'Cannot find a request from the beginning of queue, will be retried later', - extra={'nextRequestId': next_request_id}, - ) - return None - - # 2) - # Queue head index is behind the main table and the underlying request was already handled (by some other - # client, since we keep the track of handled requests in recently_handled dictionary). We just add the request - # to the recently_handled dictionary so that next call to _ensure_head_is_non_empty() will not put the request - # again to queue_head_dict. - if request.handled_at is not None: - logger.debug( - 'Request fetched from the beginning of queue was already handled', - extra={'nextRequestId': next_request_id}, - ) - return None + Args: + request_id: The ID of the request to retrieve. - self._dequeued_request_count += 1 - return request + Returns: + The request with the specified ID, or `None` if no such request exists. + """ + return await self._client.get_request(request_id) async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: """Mark a request as handled after successful processing. - Handled requests will never again be returned by the `RequestQueue.fetch_next_request` method. + This method should be called after a request has been successfully processed. + Once marked as handled, the request will be removed from the queue and will + not be returned in subsequent calls to `fetch_next_request` method. Args: request: The request to mark as handled. 
Returns: - Information about the queue operation. `None` if the given request was not in progress. + Information about the queue operation. """ - self._last_activity = datetime.now(timezone.utc) - - if request.handled_at is None: - request.handled_at = datetime.now(timezone.utc) - - processed_request = await self._resource_client.update_request(request) - processed_request.unique_key = request.unique_key - self._dequeued_request_count -= 1 - - if not processed_request.was_already_handled: - self._assumed_handled_count += 1 - - self._cache_request(unique_key_to_request_id(request.unique_key), processed_request, forefront=False) - return processed_request + return await self._client.mark_request_as_handled(request) async def reclaim_request( self, @@ -389,325 +268,83 @@ async def reclaim_request( *, forefront: bool = False, ) -> ProcessedRequest | None: - """Reclaim a failed request back to the queue. + """Reclaim a failed request back to the queue for later processing. - The request will be returned for processing later again by another call to `RequestQueue.fetch_next_request`. + If a request fails during processing, this method can be used to return it to the queue. + The request will be returned for processing again in a subsequent call + to `RequestQueue.fetch_next_request`. Args: request: The request to return to the queue. - forefront: Whether to add the request to the head or the end of the queue. + forefront: If true, the request will be added to the beginning of the queue. + Otherwise, it will be added to the end. Returns: - Information about the queue operation. `None` if the given request was not in progress. + Information about the queue operation. """ - self._last_activity = datetime.now(timezone.utc) - - processed_request = await self._resource_client.update_request(request, forefront=forefront) - processed_request.unique_key = request.unique_key - self._cache_request(unique_key_to_request_id(request.unique_key), processed_request, forefront=forefront) - - if forefront: - self._should_check_for_forefront_requests = True - - if processed_request: - # Try to delete the request lock if possible - try: - await self._resource_client.delete_request_lock(request.id, forefront=forefront) - except Exception as err: - logger.debug(f'Failed to delete request lock for request {request.id}', exc_info=err) - - return processed_request + return await self._client.reclaim_request(request, forefront=forefront) async def is_empty(self) -> bool: - """Check whether the queue is empty. + """Check if the request queue is empty. + + An empty queue means that there are no requests currently in the queue, either pending or being processed. + However, this does not necessarily mean that the crawling operation is finished, as there still might be + tasks that could add additional requests to the queue. Returns: - bool: `True` if the next call to `RequestQueue.fetch_next_request` would return `None`, otherwise `False`. + True if the request queue is empty, False otherwise. """ - await self._ensure_head_is_non_empty() - return len(self._queue_head) == 0 + return await self._client.is_empty() async def is_finished(self) -> bool: - """Check whether the queue is finished. + """Check if the request queue is finished. - Due to the nature of distributed storage used by the queue, the function might occasionally return a false - negative, but it will never return a false positive. 
+ A finished queue means that all requests in the queue have been processed (the queue is empty) and there + are no more tasks that could add additional requests to the queue. This is the definitive way to check + if a crawling operation is complete. Returns: - bool: `True` if all requests were already handled and there are no more left. `False` otherwise. + True if the request queue is finished (empty and no pending add operations), False otherwise. """ - if self._tasks: - logger.debug('Background tasks are still in progress') - return False - - if self._queue_head: - logger.debug( - 'There are still ids in the queue head that are pending processing', - extra={ - 'queue_head_ids_pending': len(self._queue_head), - }, - ) - - return False - - await self._ensure_head_is_non_empty() - - if self._queue_head: - logger.debug('Queue head still returned requests that need to be processed') - + if self._add_requests_tasks: + logger.debug('Background add requests tasks are still in progress.') return False - # Could not lock any new requests - decide based on whether the queue contains requests locked by another client - if self._queue_has_locked_requests is not None: - if self._queue_has_locked_requests and self._dequeued_request_count == 0: - # The `% 25` was absolutely arbitrarily picked. It's just to not spam the logs too much. - if self._is_finished_log_throttle_counter % 25 == 0: - logger.info('The queue still contains requests locked by another client') - - self._is_finished_log_throttle_counter += 1 - - logger.debug( - f'Deciding if we are finished based on `queue_has_locked_requests` = {self._queue_has_locked_requests}' - ) - return not self._queue_has_locked_requests - - metadata = await self._resource_client.get() - if metadata is not None and not metadata.had_multiple_clients and not self._queue_head: - logger.debug('Queue head is empty and there are no other clients - we are finished') - + if await self.is_empty(): + logger.debug('The request queue is empty.') return True - # The following is a legacy algorithm for checking if the queue is finished. - # It is used only for request queue clients that do not provide the `queue_has_locked_requests` flag. 
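# ---------------------------------------------------------------------------
# Editor's note - illustrative usage sketch, not part of the diff above.
# The docstrings above describe manual queue consumption; this is a minimal,
# hedged sketch of the fetch / handle / reclaim cycle built only from the
# methods shown in this diff. The processing step is a placeholder.
# ---------------------------------------------------------------------------
import asyncio

from crawlee.storages import RequestQueue


async def consume_example() -> None:
    rq = await RequestQueue.open(name='my-queue')

    while not await rq.is_finished():
        request = await rq.fetch_next_request()

        if request is None:
            # Nothing is available right now (e.g. background add tasks may
            # still be running), so back off briefly before trying again.
            await asyncio.sleep(1)
            continue

        try:
            ...  # process `request` here
        except Exception:
            # Return the failed request to the queue so it can be retried.
            await rq.reclaim_request(request)
        else:
            # Prevent the request from being returned by future fetches.
            await rq.mark_request_as_handled(request)


if __name__ == '__main__':
    asyncio.run(consume_example())
# --------------------------- end of editor's note ---------------------------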
- current_head = await self._resource_client.list_head(limit=2) - - if current_head.items: - logger.debug('The queue still contains unfinished requests or requests locked by another client') - - return len(current_head.items) == 0 - - async def get_info(self) -> RequestQueueMetadata | None: - """Get an object containing general information about the request queue.""" - return await self._resource_client.get() - - @override - async def get_handled_count(self) -> int: - return self._assumed_handled_count + return False - @override - async def get_total_count(self) -> int: - return self._assumed_total_count - - async def _ensure_head_is_non_empty(self) -> None: - # Stop fetching if we are paused for migration - if self._queue_paused_for_migration: - return - - # We want to fetch ahead of time to minimize dead time - if len(self._queue_head) > 1 and not self._should_check_for_forefront_requests: - return - - if self._list_head_and_lock_task is None: - task = asyncio.create_task(self._list_head_and_lock(), name='request_queue_list_head_and_lock_task') - - def callback(_: Any) -> None: - self._list_head_and_lock_task = None - - task.add_done_callback(callback) - self._list_head_and_lock_task = task - - await self._list_head_and_lock_task - - async def _list_head_and_lock(self) -> None: - # Make a copy so that we can clear the flag only if the whole method executes after the flag was set - # (i.e, it was not set in the middle of the execution of the method) - should_check_for_forefront_requests = self._should_check_for_forefront_requests - - limit = 25 - - response = await self._resource_client.list_and_lock_head( - limit=limit, lock_secs=int(self._request_lock_time.total_seconds()) - ) - - self._queue_has_locked_requests = response.queue_has_locked_requests - - head_id_buffer = list[str]() - forefront_head_id_buffer = list[str]() + async def _process_batch( + self, + batch: Sequence[Request], + *, + base_retry_wait: timedelta, + attempt: int = 1, + forefront: bool = False, + ) -> None: + """Process a batch of requests with automatic retry mechanism.""" + max_attempts = 5 + response = await self._client.add_batch_of_requests(batch, forefront=forefront) - for request in response.items: - # Queue head index might be behind the main table, so ensure we don't recycle requests - if not request.id or not request.unique_key: - logger.debug( - 'Skipping request from queue head, already in progress or recently handled', - extra={ - 'id': request.id, - 'unique_key': request.unique_key, - }, + if response.unprocessed_requests: + logger.debug(f'Following requests were not processed: {response.unprocessed_requests}.') + if attempt > max_attempts: + logger.warning( + f'Following requests were not processed even after {max_attempts} attempts:\n' + f'{response.unprocessed_requests}' ) - - # Remove the lock from the request for now, so that it can be picked up later - # This may/may not succeed, but that's fine - with suppress(Exception): - await self._resource_client.delete_request_lock(request.id) - - continue - - # If we remember that we added the request ourselves and we added it to the forefront, - # we will put it to the beginning of the local queue head to preserve the expected order. - # If we do not remember that, we will enqueue it normally. 
- cached_request = self._requests_cache.get(unique_key_to_request_id(request.unique_key)) - forefront = cached_request['forefront'] if cached_request else False - - if forefront: - forefront_head_id_buffer.insert(0, request.id) else: - head_id_buffer.append(request.id) - - self._cache_request( - unique_key_to_request_id(request.unique_key), - ProcessedRequest( - id=request.id, - unique_key=request.unique_key, - was_already_present=True, - was_already_handled=False, - ), - forefront=forefront, - ) - - for request_id in head_id_buffer: - self._queue_head.append(request_id) - - for request_id in forefront_head_id_buffer: - self._queue_head.appendleft(request_id) - - # If the queue head became too big, unlock the excess requests - to_unlock = list[str]() - while len(self._queue_head) > limit: - to_unlock.append(self._queue_head.pop()) - - if to_unlock: - await asyncio.gather( - *[self._resource_client.delete_request_lock(request_id) for request_id in to_unlock], - return_exceptions=True, # Just ignore the exceptions - ) - - # Unset the should_check_for_forefront_requests flag - the check is finished - if should_check_for_forefront_requests: - self._should_check_for_forefront_requests = False - - def _reset(self) -> None: - self._queue_head.clear() - self._list_head_and_lock_task = None - self._assumed_total_count = 0 - self._assumed_handled_count = 0 - self._requests_cache.clear() - self._last_activity = datetime.now(timezone.utc) - - def _cache_request(self, cache_key: str, processed_request: ProcessedRequest, *, forefront: bool) -> None: - self._requests_cache[cache_key] = { - 'id': processed_request.id, - 'was_already_handled': processed_request.was_already_handled, - 'hydrated': None, - 'lock_expires_at': None, - 'forefront': forefront, - } - - async def _get_or_hydrate_request(self, request_id: str) -> Request | None: - cached_entry = self._requests_cache.get(request_id) - - if not cached_entry: - # 2.1. Attempt to prolong the request lock to see if we still own the request - prolong_result = await self._prolong_request_lock(request_id) - - if not prolong_result: - return None - - # 2.1.1. If successful, hydrate the request and return it - hydrated_request = await self.get_request(request_id) - - # Queue head index is ahead of the main table and the request is not present in the main table yet - # (i.e. get_request() returned null). - if not hydrated_request: - # Remove the lock from the request for now, so that it can be picked up later - # This may/may not succeed, but that's fine - with suppress(Exception): - await self._resource_client.delete_request_lock(request_id) - - return None - - self._requests_cache[request_id] = { - 'id': request_id, - 'hydrated': hydrated_request, - 'was_already_handled': hydrated_request.handled_at is not None, - 'lock_expires_at': prolong_result, - 'forefront': False, - } - - return hydrated_request - - # 1.1. If hydrated, prolong the lock more and return it - if cached_entry['hydrated']: - # 1.1.1. If the lock expired on the hydrated requests, try to prolong. If we fail, we lost the request - # (or it was handled already) - if cached_entry['lock_expires_at'] and cached_entry['lock_expires_at'] < datetime.now(timezone.utc): - prolonged = await self._prolong_request_lock(cached_entry['id']) - - if not prolonged: - return None - - cached_entry['lock_expires_at'] = prolonged - - return cached_entry['hydrated'] - - # 1.2. 
If not hydrated, try to prolong the lock first (to ensure we keep it in our queue), hydrate and return it - prolonged = await self._prolong_request_lock(cached_entry['id']) - - if not prolonged: - return None - - # This might still return null if the queue head is inconsistent with the main queue table. - hydrated_request = await self.get_request(cached_entry['id']) - - cached_entry['hydrated'] = hydrated_request - - # Queue head index is ahead of the main table and the request is not present in the main table yet - # (i.e. get_request() returned null). - if not hydrated_request: - # Remove the lock from the request for now, so that it can be picked up later - # This may/may not succeed, but that's fine - with suppress(Exception): - await self._resource_client.delete_request_lock(cached_entry['id']) - - return None + logger.debug('Retry to add requests.') + unprocessed_requests_unique_keys = {request.unique_key for request in response.unprocessed_requests} + retry_batch = [request for request in batch if request.unique_key in unprocessed_requests_unique_keys] + await asyncio.sleep((base_retry_wait * attempt).total_seconds()) + await self._process_batch(retry_batch, base_retry_wait=base_retry_wait, attempt=attempt + 1) - return hydrated_request + request_count = len(batch) - len(response.unprocessed_requests) - async def _prolong_request_lock(self, request_id: str) -> datetime | None: - try: - res = await self._resource_client.prolong_request_lock( - request_id, lock_secs=int(self._request_lock_time.total_seconds()) - ) - except Exception as err: - # Most likely we do not own the lock anymore - logger.warning( - f'Failed to prolong lock for cached request {request_id}, either lost the lock ' - 'or the request was already handled\n', - exc_info=err, + if request_count: + logger.debug( + f'Added {request_count} requests to the queue. Processed requests: {response.processed_requests}' ) - return None - else: - return res.lock_expires_at - - async def _clear_possible_locks(self) -> None: - self._queue_paused_for_migration = True - request_id: str | None = None - - while True: - try: - request_id = self._queue_head.pop() - except LookupError: - break - - with suppress(Exception): - await self._resource_client.delete_request_lock(request_id) - # If this fails, we don't have the lock, or the request was never locked. 
Either way it's fine diff --git a/src/crawlee/storages/_types.py b/src/crawlee/storages/_types.py new file mode 100644 index 0000000000..e8c1b135e0 --- /dev/null +++ b/src/crawlee/storages/_types.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal, TypedDict + +if TYPE_CHECKING: + import json + from collections.abc import Callable + from datetime import datetime + + from typing_extensions import NotRequired, Required + + from crawlee import Request + from crawlee.configuration import Configuration + from crawlee.storage_clients import StorageClient + + +class CachedRequest(TypedDict): + """Represent a cached request in the `RequestQueue`.""" + + id: str + """The ID of the request.""" + + was_already_handled: bool + """Indicates whether the request was already handled.""" + + hydrated: Request | None + """The hydrated request object.""" + + lock_expires_at: datetime | None + """The time at which the lock on the request expires.""" + + forefront: bool + """Indicates whether the request is at the forefront of the queue.""" + + +class IterateKwargs(TypedDict): + """Keyword arguments for dataset's `iterate` method.""" + + offset: NotRequired[int] + """Skips the specified number of items at the start.""" + + limit: NotRequired[int | None] + """The maximum number of items to retrieve. Unlimited if None.""" + + clean: NotRequired[bool] + """Return only non-empty items and excludes hidden fields. Shortcut for skip_hidden and skip_empty.""" + + desc: NotRequired[bool] + """Set to True to sort results in descending order.""" + + fields: NotRequired[list[str]] + """Fields to include in each item. Sorts fields as specified if provided.""" + + omit: NotRequired[list[str]] + """Fields to exclude from each item.""" + + unwind: NotRequired[str] + """Unwinds items by a specified array field, turning each element into a separate item.""" + + skip_empty: NotRequired[bool] + """Excludes empty items from the results if True.""" + + skip_hidden: NotRequired[bool] + """Excludes fields starting with '#' if True.""" + + +class GetDataKwargs(IterateKwargs): + """Keyword arguments for dataset's `get_data` method.""" + + flatten: NotRequired[list[str]] + """Fields to be flattened in returned items.""" + + view: NotRequired[str] + """Specifies the dataset view to be used.""" + + +class ExportToKwargs(TypedDict): + """Keyword arguments for dataset's `export_to` method.""" + + key: Required[str] + """The key under which to save the data.""" + + content_type: NotRequired[Literal['json', 'csv']] + """The format in which to export the data. 
Either 'json' or 'csv'.""" + + to_kvs_id: NotRequired[str] + """ID of the key-value store to save the exported file.""" + + to_kvs_name: NotRequired[str] + """Name of the key-value store to save the exported file.""" + + to_kvs_storage_client: NotRequired[StorageClient] + """The storage client to use for saving the exported file.""" + + to_kvs_configuration: NotRequired[Configuration] + """The configuration to use for saving the exported file.""" + + +class ExportDataJsonKwargs(TypedDict): + """Keyword arguments for dataset's `export_data_json` method.""" + + skipkeys: NotRequired[bool] + """If True (default: False), dict keys that are not of a basic type (str, int, float, bool, None) will be skipped + instead of raising a `TypeError`.""" + + ensure_ascii: NotRequired[bool] + """Determines if non-ASCII characters should be escaped in the output JSON string.""" + + check_circular: NotRequired[bool] + """If False (default: True), skips the circular reference check for container types. A circular reference will + result in a `RecursionError` or worse if unchecked.""" + + allow_nan: NotRequired[bool] + """If False (default: True), raises a ValueError for out-of-range float values (nan, inf, -inf) to strictly comply + with the JSON specification. If True, uses their JavaScript equivalents (NaN, Infinity, -Infinity).""" + + cls: NotRequired[type[json.JSONEncoder]] + """Allows specifying a custom JSON encoder.""" + + indent: NotRequired[int] + """Specifies the number of spaces for indentation in the pretty-printed JSON output.""" + + separators: NotRequired[tuple[str, str]] + """A tuple of (item_separator, key_separator). The default is (', ', ': ') if indent is None and (',', ': ') + otherwise.""" + + default: NotRequired[Callable] + """A function called for objects that can't be serialized otherwise. It should return a JSON-encodable version + of the object or raise a `TypeError`.""" + + sort_keys: NotRequired[bool] + """Specifies whether the output JSON object should have keys sorted alphabetically.""" + + +class ExportDataCsvKwargs(TypedDict): + """Keyword arguments for dataset's `export_data_csv` method.""" + + dialect: NotRequired[str] + """Specifies a dialect to be used in CSV parsing and writing.""" + + delimiter: NotRequired[str] + """A one-character string used to separate fields. Defaults to ','.""" + + doublequote: NotRequired[bool] + """Controls how instances of `quotechar` inside a field should be quoted. When True, the character is doubled; + when False, the `escapechar` is used as a prefix. Defaults to True.""" + + escapechar: NotRequired[str] + """A one-character string used to escape the delimiter if `quoting` is set to `QUOTE_NONE` and the `quotechar` + if `doublequote` is False. Defaults to None, disabling escaping.""" + + lineterminator: NotRequired[str] + """The string used to terminate lines produced by the writer. Defaults to '\\r\\n'.""" + + quotechar: NotRequired[str] + """A one-character string used to quote fields containing special characters, like the delimiter or quotechar, + or fields containing new-line characters. Defaults to '\"'.""" + + quoting: NotRequired[int] + """Controls when quotes should be generated by the writer and recognized by the reader. Can take any of + the `QUOTE_*` constants, with a default of `QUOTE_MINIMAL`.""" + + skipinitialspace: NotRequired[bool] + """When True, spaces immediately following the delimiter are ignored. Defaults to False.""" + + strict: NotRequired[bool] + """When True, raises an exception on bad CSV input. 
Defaults to False.""" diff --git a/tests/e2e/project_template/utils.py b/tests/e2e/project_template/utils.py index 3bc5be4ea6..685e8c45e8 100644 --- a/tests/e2e/project_template/utils.py +++ b/tests/e2e/project_template/utils.py @@ -20,23 +20,25 @@ def patch_crawlee_version_in_project( def _patch_crawlee_version_in_requirements_txt_based_project(project_path: Path, wheel_path: Path) -> None: # Get any extras - with open(project_path / 'requirements.txt') as f: + requirements_path = project_path / 'requirements.txt' + with requirements_path.open() as f: requirements = f.read() crawlee_extras = re.findall(r'crawlee(\[.*\])', requirements)[0] or '' # Modify requirements.txt to use crawlee from wheel file instead of from Pypi - with open(project_path / 'requirements.txt') as f: + with requirements_path.open() as f: modified_lines = [] for line in f: if 'crawlee' in line: modified_lines.append(f'./{wheel_path.name}{crawlee_extras}\n') else: modified_lines.append(line) - with open(project_path / 'requirements.txt', 'w') as f: + with requirements_path.open('w') as f: f.write(''.join(modified_lines)) # Patch the dockerfile to have wheel file available - with open(project_path / 'Dockerfile') as f: + dockerfile_path = project_path / 'Dockerfile' + with dockerfile_path.open() as f: modified_lines = [] for line in f: modified_lines.append(line) @@ -49,19 +51,21 @@ def _patch_crawlee_version_in_requirements_txt_based_project(project_path: Path, f'RUN pip install ./{wheel_path.name}{crawlee_extras} --force-reinstall\n', ] ) - with open(project_path / 'Dockerfile', 'w') as f: + with dockerfile_path.open('w') as f: f.write(''.join(modified_lines)) def _patch_crawlee_version_in_pyproject_toml_based_project(project_path: Path, wheel_path: Path) -> None: """Ensure that the test is using current version of the crawlee from the source and not from Pypi.""" # Get any extras - with open(project_path / 'pyproject.toml') as f: + pyproject_path = project_path / 'pyproject.toml' + with pyproject_path.open() as f: pyproject = f.read() crawlee_extras = re.findall(r'crawlee(\[.*\])', pyproject)[0] or '' # Inject crawlee wheel file to the docker image and update project to depend on it.""" - with open(project_path / 'Dockerfile') as f: + dockerfile_path = project_path / 'Dockerfile' + with dockerfile_path.open() as f: modified_lines = [] for line in f: modified_lines.append(line) @@ -94,5 +98,5 @@ def _patch_crawlee_version_in_pyproject_toml_based_project(project_path: Path, w f'RUN {package_manager} lock\n', ] ) - with open(project_path / 'Dockerfile', 'w') as f: + with dockerfile_path.open('w') as f: f.write(''.join(modified_lines)) diff --git a/tests/unit/_utils/test_file.py b/tests/unit/_utils/test_file.py index a86291b43f..0762e1d966 100644 --- a/tests/unit/_utils/test_file.py +++ b/tests/unit/_utils/test_file.py @@ -1,6 +1,5 @@ from __future__ import annotations -import io from datetime import datetime, timezone from pathlib import Path @@ -12,7 +11,6 @@ force_remove, force_rename, is_content_type, - is_file_or_bytes, json_dumps, ) @@ -25,15 +23,6 @@ async def test_json_dumps() -> None: assert await json_dumps(datetime(2022, 1, 1, tzinfo=timezone.utc)) == '"2022-01-01 00:00:00+00:00"' -def test_is_file_or_bytes() -> None: - assert is_file_or_bytes(b'bytes') is True - assert is_file_or_bytes(bytearray(b'bytearray')) is True - assert is_file_or_bytes(io.BytesIO(b'some bytes')) is True - assert is_file_or_bytes(io.StringIO('string')) is True - assert is_file_or_bytes('just a regular string') is False - assert 
is_file_or_bytes(12345) is False - - @pytest.mark.parametrize( ('content_type_enum', 'content_type', 'expected_result'), [ @@ -115,7 +104,7 @@ async def test_force_remove(tmp_path: Path) -> None: assert test_file_path.exists() is False # Remove the file if it exists - with open(test_file_path, 'a', encoding='utf-8'): # noqa: ASYNC230 + with test_file_path.open('a', encoding='utf-8'): pass assert test_file_path.exists() is True await force_remove(test_file_path) @@ -134,11 +123,11 @@ async def test_force_rename(tmp_path: Path) -> None: # Will remove dst_dir if it exists (also covers normal case) # Create the src_dir with a file in it src_dir.mkdir() - with open(src_file, 'a', encoding='utf-8'): # noqa: ASYNC230 + with src_file.open('a', encoding='utf-8'): pass # Create the dst_dir with a file in it dst_dir.mkdir() - with open(dst_file, 'a', encoding='utf-8'): # noqa: ASYNC230 + with dst_file.open('a', encoding='utf-8'): pass assert src_file.exists() is True assert dst_file.exists() is True diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index b7ac06d124..1b73df5743 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -13,12 +13,10 @@ from uvicorn.config import Config from crawlee import service_locator -from crawlee.configuration import Configuration from crawlee.fingerprint_suite._browserforge_adapter import get_available_header_network from crawlee.http_clients import CurlImpersonateHttpClient, HttpxHttpClient from crawlee.proxy_configuration import ProxyInfo -from crawlee.storage_clients import MemoryStorageClient -from crawlee.storages import KeyValueStore, _creation_management +from crawlee.storages import KeyValueStore from tests.unit.server import TestServer, app, serve_in_thread if TYPE_CHECKING: @@ -64,14 +62,6 @@ def _prepare_test_env() -> None: service_locator._event_manager = None service_locator._storage_client = None - # Clear creation-related caches to ensure no state is carried over between tests. - monkeypatch.setattr(_creation_management, '_cache_dataset_by_id', {}) - monkeypatch.setattr(_creation_management, '_cache_dataset_by_name', {}) - monkeypatch.setattr(_creation_management, '_cache_kvs_by_id', {}) - monkeypatch.setattr(_creation_management, '_cache_kvs_by_name', {}) - monkeypatch.setattr(_creation_management, '_cache_rq_by_id', {}) - monkeypatch.setattr(_creation_management, '_cache_rq_by_name', {}) - # Verify that the test environment was set up correctly. 
assert os.environ.get('CRAWLEE_STORAGE_DIR') == str(tmp_path) assert service_locator._configuration_was_retrieved is False @@ -149,18 +139,6 @@ async def disabled_proxy(proxy_info: ProxyInfo) -> AsyncGenerator[ProxyInfo, Non yield proxy_info -@pytest.fixture -def memory_storage_client(tmp_path: Path) -> MemoryStorageClient: - """A fixture for testing the memory storage client and its resource clients.""" - config = Configuration( - persist_storage=True, - write_metadata=True, - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - ) - - return MemoryStorageClient.from_config(config) - - @pytest.fixture(scope='session') def header_network() -> dict: return get_available_header_network() diff --git a/tests/unit/crawlers/_basic/test_basic_crawler.py b/tests/unit/crawlers/_basic/test_basic_crawler.py index 4f151ad621..e175031f78 100644 --- a/tests/unit/crawlers/_basic/test_basic_crawler.py +++ b/tests/unit/crawlers/_basic/test_basic_crawler.py @@ -10,7 +10,6 @@ from collections import Counter from dataclasses import dataclass from datetime import timedelta -from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, cast from unittest.mock import AsyncMock, Mock, call, patch @@ -32,16 +31,16 @@ if TYPE_CHECKING: from collections.abc import Callable, Sequence + from pathlib import Path from yarl import URL from crawlee._types import JsonSerializable - from crawlee.storage_clients._memory import DatasetClient async def test_processes_requests_from_explicit_queue() -> None: queue = await RequestQueue.open() - await queue.add_requests_batched(['http://a.com/', 'http://b.com/', 'http://c.com/']) + await queue.add_requests(['http://a.com/', 'http://b.com/', 'http://c.com/']) crawler = BasicCrawler(request_manager=queue) calls = list[str]() @@ -57,7 +56,7 @@ async def handler(context: BasicCrawlingContext) -> None: async def test_processes_requests_from_request_source_tandem() -> None: request_queue = await RequestQueue.open() - await request_queue.add_requests_batched(['http://a.com/', 'http://b.com/', 'http://c.com/']) + await request_queue.add_requests(['http://a.com/', 'http://b.com/', 'http://c.com/']) request_list = RequestList(['http://a.com/', 'http://d.com', 'http://e.com']) @@ -568,14 +567,14 @@ async def test_context_push_and_get_data() -> None: crawler = BasicCrawler() dataset = await Dataset.open() - await dataset.push_data('{"a": 1}') + await dataset.push_data({'a': 1}) assert (await crawler.get_data()).items == [{'a': 1}] @crawler.router.default_handler async def handler(context: BasicCrawlingContext) -> None: - await context.push_data('{"b": 2}') + await context.push_data({'b': 2}) - await dataset.push_data('{"c": 3}') + await dataset.push_data({'c': 3}) assert (await crawler.get_data()).items == [{'a': 1}, {'c': 3}] stats = await crawler.run(['http://test.io/1']) @@ -608,8 +607,8 @@ async def test_crawler_push_and_export_data(tmp_path: Path) -> None: await dataset.push_data([{'id': 0, 'test': 'test'}, {'id': 1, 'test': 'test'}]) await dataset.push_data({'id': 2, 'test': 'test'}) - await crawler.export_data_json(path=tmp_path / 'dataset.json') - await crawler.export_data_csv(path=tmp_path / 'dataset.csv') + await crawler.export_data(path=tmp_path / 'dataset.json') + await crawler.export_data(path=tmp_path / 'dataset.csv') assert json.load((tmp_path / 'dataset.json').open()) == [ {'id': 0, 'test': 'test'}, @@ -629,8 +628,8 @@ async def handler(context: BasicCrawlingContext) -> None: await crawler.run(['http://test.io/1']) - await 
crawler.export_data_json(path=tmp_path / 'dataset.json') - await crawler.export_data_csv(path=tmp_path / 'dataset.csv') + await crawler.export_data(path=tmp_path / 'dataset.json') + await crawler.export_data(path=tmp_path / 'dataset.csv') assert json.load((tmp_path / 'dataset.json').open()) == [ {'id': 0, 'test': 'test'}, @@ -641,45 +640,6 @@ async def handler(context: BasicCrawlingContext) -> None: assert (tmp_path / 'dataset.csv').read_bytes() == b'id,test\r\n0,test\r\n1,test\r\n2,test\r\n' -async def test_crawler_push_and_export_data_and_json_dump_parameter(tmp_path: Path) -> None: - crawler = BasicCrawler() - - @crawler.router.default_handler - async def handler(context: BasicCrawlingContext) -> None: - await context.push_data([{'id': 0, 'test': 'test'}, {'id': 1, 'test': 'test'}]) - await context.push_data({'id': 2, 'test': 'test'}) - - await crawler.run(['http://test.io/1']) - - await crawler.export_data_json(path=tmp_path / 'dataset.json', indent=3) - - with (tmp_path / 'dataset.json').open() as json_file: - exported_json_str = json_file.read() - - # Expected data in JSON format with 3 spaces indent - expected_data = [ - {'id': 0, 'test': 'test'}, - {'id': 1, 'test': 'test'}, - {'id': 2, 'test': 'test'}, - ] - expected_json_str = json.dumps(expected_data, indent=3) - - # Assert that the exported JSON string matches the expected JSON string - assert exported_json_str == expected_json_str - - -async def test_crawler_push_data_over_limit() -> None: - crawler = BasicCrawler() - - @crawler.router.default_handler - async def handler(context: BasicCrawlingContext) -> None: - # Push a roughly 15MB payload - this should be enough to break the 9MB limit - await context.push_data({'hello': 'world' * 3 * 1024 * 1024}) - - stats = await crawler.run(['http://example.tld/1']) - assert stats.requests_failed == 1 - - async def test_context_update_kv_store() -> None: crawler = BasicCrawler() @@ -694,7 +654,7 @@ async def handler(context: BasicCrawlingContext) -> None: assert (await store.get_value('foo')) == 'bar' -async def test_context_use_state(key_value_store: KeyValueStore) -> None: +async def test_context_use_state() -> None: crawler = BasicCrawler() @crawler.router.default_handler @@ -703,9 +663,10 @@ async def handler(context: BasicCrawlingContext) -> None: await crawler.run(['https://hello.world']) - store = await crawler.get_key_value_store() + kvs = await crawler.get_key_value_store() + value = await kvs.get_value(BasicCrawler._CRAWLEE_STATE_KEY) - assert (await store.get_value(BasicCrawler._CRAWLEE_STATE_KEY)) == {'hello': 'world'} + assert value == {'hello': 'world'} async def test_context_handlers_use_state(key_value_store: KeyValueStore) -> None: @@ -869,18 +830,6 @@ async def handler(context: BasicCrawlingContext) -> None: } -async def test_respects_no_persist_storage() -> None: - configuration = Configuration(persist_storage=False) - crawler = BasicCrawler(configuration=configuration) - - @crawler.router.default_handler - async def handler(context: BasicCrawlingContext) -> None: - await context.push_data({'something': 'something'}) - - datasets_path = Path(configuration.storage_dir) / 'datasets' / 'default' - assert not datasets_path.exists() or list(datasets_path.iterdir()) == [] - - @pytest.mark.skipif(os.name == 'nt' and 'CI' in os.environ, reason='Skipped in Windows CI') @pytest.mark.parametrize( ('statistics_log_format'), @@ -1020,9 +969,9 @@ async def handler(context: BasicCrawlingContext) -> None: async def test_sets_services() -> None: custom_configuration = Configuration() 
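# ---------------------------------------------------------------------------
# Editor's note - not part of the diff above. The test changes above replace
# `export_data_json()` / `export_data_csv()` with a single `export_data()`
# call; judging by the assertions, the output format appears to be inferred
# from the file extension of `path`. A hedged usage sketch:
#
#     await crawler.export_data(path='results.json')  # written as JSON
#     await crawler.export_data(path='results.csv')   # written as CSV
# ---------------------------------------------------------------------------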
custom_event_manager = LocalEventManager.from_config(custom_configuration) - custom_storage_client = MemoryStorageClient.from_config(custom_configuration) + custom_storage_client = MemoryStorageClient() - crawler = BasicCrawler( + _ = BasicCrawler( configuration=custom_configuration, event_manager=custom_event_manager, storage_client=custom_storage_client, @@ -1032,12 +981,9 @@ async def test_sets_services() -> None: assert service_locator.get_event_manager() is custom_event_manager assert service_locator.get_storage_client() is custom_storage_client - dataset = await crawler.get_dataset(name='test') - assert cast('DatasetClient', dataset._resource_client)._memory_storage_client is custom_storage_client - async def test_allows_storage_client_overwrite_before_run(monkeypatch: pytest.MonkeyPatch) -> None: - custom_storage_client = MemoryStorageClient.from_config() + custom_storage_client = MemoryStorageClient() crawler = BasicCrawler( storage_client=custom_storage_client, @@ -1047,7 +993,7 @@ async def test_allows_storage_client_overwrite_before_run(monkeypatch: pytest.Mo async def handler(context: BasicCrawlingContext) -> None: await context.push_data({'foo': 'bar'}) - other_storage_client = MemoryStorageClient.from_config() + other_storage_client = MemoryStorageClient() service_locator.set_storage_client(other_storage_client) with monkeypatch.context() as monkey: @@ -1057,8 +1003,6 @@ async def handler(context: BasicCrawlingContext) -> None: assert spy.call_count >= 1 dataset = await crawler.get_dataset() - assert cast('DatasetClient', dataset._resource_client)._memory_storage_client is other_storage_client - data = await dataset.get_data() assert data.items == [{'foo': 'bar'}] diff --git a/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py b/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py new file mode 100644 index 0000000000..45f583c878 --- /dev/null +++ b/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py @@ -0,0 +1,325 @@ +from __future__ import annotations + +import asyncio +import json +from datetime import datetime +from pathlib import Path +from typing import TYPE_CHECKING + +import pytest + +from crawlee._consts import METADATA_FILENAME +from crawlee.configuration import Configuration +from crawlee.storage_clients import FileSystemStorageClient +from crawlee.storage_clients._file_system import FileSystemDatasetClient +from crawlee.storage_clients.models import DatasetItemsListPage + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + + +@pytest.fixture +def configuration(tmp_path: Path) -> Configuration: + return Configuration( + crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] + ) + + +@pytest.fixture +async def dataset_client(configuration: Configuration) -> AsyncGenerator[FileSystemDatasetClient, None]: + """A fixture for a file system dataset client.""" + client = await FileSystemStorageClient().open_dataset_client( + name='test_dataset', + configuration=configuration, + ) + yield client + await client.drop() + + +async def test_open_creates_new_dataset(configuration: Configuration) -> None: + """Test that open() creates a new dataset with proper metadata when it doesn't exist.""" + client = await FileSystemStorageClient().open_dataset_client( + name='new_dataset', + configuration=configuration, + ) + + # Verify correct client type and properties + assert isinstance(client, FileSystemDatasetClient) + assert client.metadata.id is not None + assert client.metadata.name == 'new_dataset' + assert 
client.metadata.item_count == 0 + assert isinstance(client.metadata.created_at, datetime) + assert isinstance(client.metadata.accessed_at, datetime) + assert isinstance(client.metadata.modified_at, datetime) + + # Verify files were created + assert client.path_to_dataset.exists() + assert client.path_to_metadata.exists() + + # Verify metadata content + with client.path_to_metadata.open() as f: + metadata = json.load(f) + assert metadata['id'] == client.metadata.id + assert metadata['name'] == 'new_dataset' + assert metadata['item_count'] == 0 + + +async def test_dataset_client_purge_on_start(configuration: Configuration) -> None: + """Test that purge_on_start=True clears existing data in the dataset.""" + configuration.purge_on_start = True + + # Create dataset and add data + dataset_client1 = await FileSystemStorageClient().open_dataset_client( + name='test-purge-dataset', + configuration=configuration, + ) + await dataset_client1.push_data({'item': 'initial data'}) + + # Verify data was added + items = await dataset_client1.get_data() + assert len(items.items) == 1 + + # Reopen + dataset_client2 = await FileSystemStorageClient().open_dataset_client( + name='test-purge-dataset', + configuration=configuration, + ) + + # Verify data was purged + items = await dataset_client2.get_data() + assert len(items.items) == 0 + + +async def test_dataset_client_no_purge_on_start(configuration: Configuration) -> None: + """Test that purge_on_start=False keeps existing data in the dataset.""" + configuration.purge_on_start = False + + # Create dataset and add data + dataset_client1 = await FileSystemStorageClient().open_dataset_client( + name='test-no-purge-dataset', + configuration=configuration, + ) + await dataset_client1.push_data({'item': 'preserved data'}) + + # Reopen + dataset_client2 = await FileSystemStorageClient().open_dataset_client( + name='test-no-purge-dataset', + configuration=configuration, + ) + + # Verify data was preserved + items = await dataset_client2.get_data() + assert len(items.items) == 1 + assert items.items[0]['item'] == 'preserved data' + + +async def test_open_with_id_raises_error(configuration: Configuration) -> None: + """Test that open() raises an error when an ID is provided.""" + with pytest.raises(ValueError, match='not supported for file system storage client'): + await FileSystemStorageClient().open_dataset_client(id='some-id', configuration=configuration) + + +async def test_push_data_single_item(dataset_client: FileSystemDatasetClient) -> None: + """Test pushing a single item to the dataset.""" + item = {'key': 'value', 'number': 42} + await dataset_client.push_data(item) + + # Verify item count was updated + assert dataset_client.metadata.item_count == 1 + + all_files = list(dataset_client.path_to_dataset.glob('*.json')) + assert len(all_files) == 2 # 1 data file + 1 metadata file + + # Verify item was persisted + data_files = [item for item in all_files if item.name != METADATA_FILENAME] + assert len(data_files) == 1 + + # Verify file content + with Path(data_files[0]).open() as f: + saved_item = json.load(f) + assert saved_item == item + + +async def test_push_data_multiple_items(dataset_client: FileSystemDatasetClient) -> None: + """Test pushing multiple items to the dataset.""" + items = [{'id': 1, 'name': 'Item 1'}, {'id': 2, 'name': 'Item 2'}, {'id': 3, 'name': 'Item 3'}] + await dataset_client.push_data(items) + + # Verify item count was updated + assert dataset_client.metadata.item_count == 3 + + all_files = 
list(dataset_client.path_to_dataset.glob('*.json')) + assert len(all_files) == 4 # 3 data files + 1 metadata file + + # Verify items were saved to files + data_files = [f for f in all_files if f.name != METADATA_FILENAME] + assert len(data_files) == 3 + + +async def test_get_data_empty_dataset(dataset_client: FileSystemDatasetClient) -> None: + """Test getting data from an empty dataset returns empty list.""" + result = await dataset_client.get_data() + + assert isinstance(result, DatasetItemsListPage) + assert result.count == 0 + assert result.total == 0 + assert result.items == [] + + +async def test_get_data_with_items(dataset_client: FileSystemDatasetClient) -> None: + """Test getting data from a dataset returns all items in order with correct properties.""" + # Add some items + items = [{'id': 1, 'name': 'Item 1'}, {'id': 2, 'name': 'Item 2'}, {'id': 3, 'name': 'Item 3'}] + await dataset_client.push_data(items) + + # Get all items + result = await dataset_client.get_data() + + assert result.count == 3 + assert result.total == 3 + assert len(result.items) == 3 + assert result.items[0]['id'] == 1 + assert result.items[1]['id'] == 2 + assert result.items[2]['id'] == 3 + + +async def test_get_data_with_pagination(dataset_client: FileSystemDatasetClient) -> None: + """Test getting data with offset and limit parameters for pagination implementation.""" + # Add some items + items = [{'id': i} for i in range(1, 11)] # 10 items + await dataset_client.push_data(items) + + # Test offset + result = await dataset_client.get_data(offset=3) + assert result.count == 7 + assert result.offset == 3 + assert result.items[0]['id'] == 4 + + # Test limit + result = await dataset_client.get_data(limit=5) + assert result.count == 5 + assert result.limit == 5 + assert result.items[-1]['id'] == 5 + + # Test both offset and limit + result = await dataset_client.get_data(offset=2, limit=3) + assert result.count == 3 + assert result.offset == 2 + assert result.limit == 3 + assert result.items[0]['id'] == 3 + assert result.items[-1]['id'] == 5 + + +async def test_get_data_descending_order(dataset_client: FileSystemDatasetClient) -> None: + """Test getting data in descending order reverses the item order.""" + # Add some items + items = [{'id': i} for i in range(1, 6)] # 5 items + await dataset_client.push_data(items) + + # Get items in descending order + result = await dataset_client.get_data(desc=True) + + assert result.desc is True + assert result.items[0]['id'] == 5 + assert result.items[-1]['id'] == 1 + + +async def test_get_data_skip_empty(dataset_client: FileSystemDatasetClient) -> None: + """Test getting data with skip_empty option filters out empty items when True.""" + # Add some items including an empty one + items = [ + {'id': 1, 'name': 'Item 1'}, + {}, # Empty item + {'id': 3, 'name': 'Item 3'}, + ] + await dataset_client.push_data(items) + + # Get all items + result = await dataset_client.get_data() + assert result.count == 3 + + # Get non-empty items + result = await dataset_client.get_data(skip_empty=True) + assert result.count == 2 + assert all(item != {} for item in result.items) + + +async def test_iterate(dataset_client: FileSystemDatasetClient) -> None: + """Test iterating over dataset items yields each item in the original order.""" + # Add some items + items = [{'id': i} for i in range(1, 6)] # 5 items + await dataset_client.push_data(items) + + # Iterate over all items + collected_items = [item async for item in dataset_client.iterate_items()] + + assert len(collected_items) == 5 + assert 
collected_items[0]['id'] == 1 + assert collected_items[-1]['id'] == 5 + + +async def test_iterate_with_options(dataset_client: FileSystemDatasetClient) -> None: + """Test iterating with offset, limit and desc parameters works the same as with get_data().""" + # Add some items + items = [{'id': i} for i in range(1, 11)] # 10 items + await dataset_client.push_data(items) + + # Test with offset and limit + collected_items = [item async for item in dataset_client.iterate_items(offset=3, limit=3)] + + assert len(collected_items) == 3 + assert collected_items[0]['id'] == 4 + assert collected_items[-1]['id'] == 6 + + # Test with descending order + collected_items = [] + async for item in dataset_client.iterate_items(desc=True, limit=3): + collected_items.append(item) + + assert len(collected_items) == 3 + assert collected_items[0]['id'] == 10 + assert collected_items[-1]['id'] == 8 + + +async def test_drop(dataset_client: FileSystemDatasetClient) -> None: + """Test dropping a dataset removes the entire dataset directory from disk.""" + await dataset_client.push_data({'test': 'data'}) + + assert dataset_client.path_to_dataset.exists() + + # Drop the dataset + await dataset_client.drop() + + assert not dataset_client.path_to_dataset.exists() + + +async def test_metadata_updates(dataset_client: FileSystemDatasetClient) -> None: + """Test that metadata timestamps are updated correctly after read and write operations.""" + # Record initial timestamps + initial_created = dataset_client.metadata.created_at + initial_accessed = dataset_client.metadata.accessed_at + initial_modified = dataset_client.metadata.modified_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates accessed_at + await dataset_client.get_data() + + # Verify timestamps + assert dataset_client.metadata.created_at == initial_created + assert dataset_client.metadata.accessed_at > initial_accessed + assert dataset_client.metadata.modified_at == initial_modified + + accessed_after_get = dataset_client.metadata.accessed_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates modified_at + await dataset_client.push_data({'new': 'item'}) + + # Verify timestamps again + assert dataset_client.metadata.created_at == initial_created + assert dataset_client.metadata.modified_at > initial_modified + assert dataset_client.metadata.accessed_at > accessed_after_get diff --git a/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py b/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py new file mode 100644 index 0000000000..1c1d4d3de6 --- /dev/null +++ b/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py @@ -0,0 +1,368 @@ +from __future__ import annotations + +import asyncio +import json +from datetime import datetime +from typing import TYPE_CHECKING + +import pytest + +from crawlee._consts import METADATA_FILENAME +from crawlee.configuration import Configuration +from crawlee.storage_clients import FileSystemStorageClient +from crawlee.storage_clients._file_system import FileSystemKeyValueStoreClient + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + from pathlib import Path + + +@pytest.fixture +def configuration(tmp_path: Path) -> Configuration: + return Configuration( + crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] + ) + + +@pytest.fixture +async def kvs_client(configuration: Configuration) -> AsyncGenerator[FileSystemKeyValueStoreClient, None]: + """A 
fixture for a file system key-value store client.""" + client = await FileSystemStorageClient().open_key_value_store_client( + name='test_kvs', + configuration=configuration, + ) + yield client + await client.drop() + + +async def test_open_creates_new_kvs(configuration: Configuration) -> None: + """Test that open() creates a new key-value store with proper metadata and files on disk.""" + client = await FileSystemStorageClient().open_key_value_store_client( + name='new_kvs', + configuration=configuration, + ) + + # Verify correct client type and properties + assert isinstance(client, FileSystemKeyValueStoreClient) + assert client.metadata.id is not None + assert client.metadata.name == 'new_kvs' + assert isinstance(client.metadata.created_at, datetime) + assert isinstance(client.metadata.accessed_at, datetime) + assert isinstance(client.metadata.modified_at, datetime) + + # Verify files were created + assert client.path_to_kvs.exists() + assert client.path_to_metadata.exists() + + # Verify metadata content + with client.path_to_metadata.open() as f: + metadata = json.load(f) + assert metadata['id'] == client.metadata.id + assert metadata['name'] == 'new_kvs' + + +async def test_kvs_client_purge_on_start(configuration: Configuration) -> None: + """Test that purge_on_start=True clears existing data in the key-value store.""" + configuration.purge_on_start = True + + # Create KVS and add data + kvs_client1 = await FileSystemStorageClient().open_key_value_store_client( + name='test-purge-kvs', + configuration=configuration, + ) + await kvs_client1.set_value(key='test-key', value='initial value') + + # Verify value was set + record = await kvs_client1.get_value(key='test-key') + assert record is not None + assert record.value == 'initial value' + + # Reopen + kvs_client2 = await FileSystemStorageClient().open_key_value_store_client( + name='test-purge-kvs', + configuration=configuration, + ) + + # Verify value was purged + record = await kvs_client2.get_value(key='test-key') + assert record is None + + +async def test_kvs_client_no_purge_on_start(configuration: Configuration) -> None: + """Test that purge_on_start=False keeps existing data in the key-value store.""" + configuration.purge_on_start = False + + # Create KVS and add data + kvs_client1 = await FileSystemStorageClient().open_key_value_store_client( + name='test-no-purge-kvs', + configuration=configuration, + ) + await kvs_client1.set_value(key='test-key', value='preserved value') + + # Reopen + kvs_client2 = await FileSystemStorageClient().open_key_value_store_client( + name='test-no-purge-kvs', + configuration=configuration, + ) + + # Verify value was preserved + record = await kvs_client2.get_value(key='test-key') + assert record is not None + assert record.value == 'preserved value' + + +async def test_open_with_id_raises_error(configuration: Configuration) -> None: + """Test that open() raises an error when an ID is provided (unsupported for file system client).""" + with pytest.raises(ValueError, match='not supported for file system storage client'): + await FileSystemStorageClient().open_key_value_store_client(id='some-id', configuration=configuration) + + +async def test_set_get_value_string(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test setting and getting a string value with correct file creation and metadata.""" + # Set a value + test_key = 'test-key' + test_value = 'Hello, world!' 
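# ---------------------------------------------------------------------------
# Editor's note - not part of the diff above. The purge tests above suggest
# that whether existing data survives a reopen is controlled by
# `Configuration.purge_on_start` rather than by the storage client itself.
# A hedged sketch of direct client usage under that assumption:
#
#     configuration = Configuration()
#     configuration.purge_on_start = False
#     kvs_client = await FileSystemStorageClient().open_key_value_store_client(
#         name='my-kvs',
#         configuration=configuration,
#     )
#     await kvs_client.set_value(key='greeting', value='hello')
#     record = await kvs_client.get_value(key='greeting')  # record.value == 'hello'
# ---------------------------------------------------------------------------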
+ await kvs_client.set_value(key=test_key, value=test_value) + + # Check if the file was created + key_path = kvs_client.path_to_kvs / test_key + key_metadata_path = kvs_client.path_to_kvs / f'{test_key}.{METADATA_FILENAME}' + assert key_path.exists() + assert key_metadata_path.exists() + + # Check file content + content = key_path.read_text(encoding='utf-8') + assert content == test_value + + # Check record metadata + with key_metadata_path.open() as f: + metadata = json.load(f) + assert metadata['key'] == test_key + assert metadata['content_type'] == 'text/plain; charset=utf-8' + assert metadata['size'] == len(test_value.encode('utf-8')) + + # Get the value + record = await kvs_client.get_value(key=test_key) + assert record is not None + assert record.key == test_key + assert record.value == test_value + assert record.content_type == 'text/plain; charset=utf-8' + assert record.size == len(test_value.encode('utf-8')) + + +async def test_set_get_value_json(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test setting and getting a JSON value with correct serialization and deserialization.""" + # Set a value + test_key = 'test-json' + test_value = {'name': 'John', 'age': 30, 'items': [1, 2, 3]} + await kvs_client.set_value(key=test_key, value=test_value) + + # Get the value + record = await kvs_client.get_value(key=test_key) + assert record is not None + assert record.key == test_key + assert record.value == test_value + assert 'application/json' in record.content_type + + +async def test_set_get_value_bytes(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test setting and getting binary data without corruption and with correct content type.""" + # Set a value + test_key = 'test-binary' + test_value = b'\x00\x01\x02\x03\x04' + await kvs_client.set_value(key=test_key, value=test_value) + + # Get the value + record = await kvs_client.get_value(key=test_key) + assert record is not None + assert record.key == test_key + assert record.value == test_value + assert record.content_type == 'application/octet-stream' + assert record.size == len(test_value) + + +async def test_set_value_explicit_content_type(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that an explicitly provided content type overrides the automatically inferred one.""" + test_key = 'test-explicit-content-type' + test_value = 'Hello, world!' 
+    explicit_content_type = 'text/html; charset=utf-8'
+
+    await kvs_client.set_value(key=test_key, value=test_value, content_type=explicit_content_type)
+
+    record = await kvs_client.get_value(key=test_key)
+    assert record is not None
+    assert record.content_type == explicit_content_type
+
+
+async def test_get_nonexistent_value(kvs_client: FileSystemKeyValueStoreClient) -> None:
+    """Test that attempting to get a non-existent key returns None."""
+    record = await kvs_client.get_value(key='nonexistent-key')
+    assert record is None
+
+
+async def test_overwrite_value(kvs_client: FileSystemKeyValueStoreClient) -> None:
+    """Test that an existing value can be overwritten and the updated value is retrieved correctly."""
+    test_key = 'test-overwrite'
+
+    # Set initial value
+    initial_value = 'Initial value'
+    await kvs_client.set_value(key=test_key, value=initial_value)
+
+    # Overwrite with new value
+    new_value = 'New value'
+    await kvs_client.set_value(key=test_key, value=new_value)
+
+    # Verify the updated value
+    record = await kvs_client.get_value(key=test_key)
+    assert record is not None
+    assert record.value == new_value
+
+
+async def test_delete_value(kvs_client: FileSystemKeyValueStoreClient) -> None:
+    """Test that deleting a value removes its files from disk and makes it irretrievable."""
+    test_key = 'test-delete'
+    test_value = 'Delete me'
+
+    # Set a value
+    await kvs_client.set_value(key=test_key, value=test_value)
+
+    # Verify it exists
+    key_path = kvs_client.path_to_kvs / test_key
+    metadata_path = kvs_client.path_to_kvs / f'{test_key}.{METADATA_FILENAME}'
+    assert key_path.exists()
+    assert metadata_path.exists()
+
+    # Delete the value
+    await kvs_client.delete_value(key=test_key)
+
+    # Verify files were deleted
+    assert not key_path.exists()
+    assert not metadata_path.exists()
+
+    # Verify value is no longer retrievable
+    record = await kvs_client.get_value(key=test_key)
+    assert record is None
+
+
+async def test_delete_nonexistent_value(kvs_client: FileSystemKeyValueStoreClient) -> None:
+    """Test that attempting to delete a non-existent key is a no-op and doesn't raise errors."""
+    # Should not raise an error
+    await kvs_client.delete_value(key='nonexistent-key')
+
+
+async def test_iterate_keys_empty_store(kvs_client: FileSystemKeyValueStoreClient) -> None:
+    """Test that iterating over an empty store yields no keys."""
+    keys = [key async for key in kvs_client.iterate_keys()]
+    assert len(keys) == 0
+
+
+async def test_iterate_keys(kvs_client: FileSystemKeyValueStoreClient) -> None:
+    """Test that all added keys can be iterated over."""
+    # Add some values
+    await kvs_client.set_value(key='key1', value='value1')
+    await kvs_client.set_value(key='key2', value='value2')
+    await kvs_client.set_value(key='key3', value='value3')
+
+    # Iterate over keys
+    keys = [key.key async for key in kvs_client.iterate_keys()]
+    assert len(keys) == 3
+    assert sorted(keys) == ['key1', 'key2', 'key3']
+
+
+async def test_iterate_keys_with_limit(kvs_client: FileSystemKeyValueStoreClient) -> None:
+    """Test that the limit parameter returns only the specified number of keys."""
+    # Add some values
+    await kvs_client.set_value(key='key1', value='value1')
+    await kvs_client.set_value(key='key2', value='value2')
+    await kvs_client.set_value(key='key3', value='value3')
+
+    # Iterate with limit
+    keys = [key.key async for key in kvs_client.iterate_keys(limit=2)]
+    assert len(keys) == 2
+
+
+async def test_iterate_keys_with_exclusive_start_key(kvs_client:
FileSystemKeyValueStoreClient) -> None: + """Test that exclusive_start_key parameter returns only keys after it alphabetically.""" + # Add some values with alphabetical keys + await kvs_client.set_value(key='a-key', value='value-a') + await kvs_client.set_value(key='b-key', value='value-b') + await kvs_client.set_value(key='c-key', value='value-c') + await kvs_client.set_value(key='d-key', value='value-d') + + # Iterate with exclusive start key + keys = [key.key async for key in kvs_client.iterate_keys(exclusive_start_key='b-key')] + assert len(keys) == 2 + assert 'c-key' in keys + assert 'd-key' in keys + assert 'a-key' not in keys + assert 'b-key' not in keys + + +async def test_drop(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that drop removes the entire store directory from disk.""" + await kvs_client.set_value(key='test', value='test-value') + + assert kvs_client.path_to_kvs.exists() + + # Drop the store + await kvs_client.drop() + + assert not kvs_client.path_to_kvs.exists() + + +async def test_metadata_updates(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that read/write operations properly update accessed_at and modified_at timestamps.""" + # Record initial timestamps + initial_created = kvs_client.metadata.created_at + initial_accessed = kvs_client.metadata.accessed_at + initial_modified = kvs_client.metadata.modified_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates accessed_at + await kvs_client.get_value(key='nonexistent') + + # Verify timestamps + assert kvs_client.metadata.created_at == initial_created + assert kvs_client.metadata.accessed_at > initial_accessed + assert kvs_client.metadata.modified_at == initial_modified + + accessed_after_get = kvs_client.metadata.accessed_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates modified_at + await kvs_client.set_value(key='new-key', value='new-value') + + # Verify timestamps again + assert kvs_client.metadata.created_at == initial_created + assert kvs_client.metadata.modified_at > initial_modified + assert kvs_client.metadata.accessed_at > accessed_after_get + + +async def test_get_public_url_not_supported(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that get_public_url raises NotImplementedError for the file system implementation.""" + with pytest.raises(NotImplementedError, match='Public URLs are not supported'): + await kvs_client.get_public_url(key='any-key') + + +async def test_concurrent_operations(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that multiple concurrent set operations can be performed safely with correct results.""" + + # Create multiple tasks to set different values concurrently + async def set_value(key: str, value: str) -> None: + await kvs_client.set_value(key=key, value=value) + + tasks = [asyncio.create_task(set_value(f'concurrent-key-{i}', f'value-{i}')) for i in range(10)] + + # Wait for all tasks to complete + await asyncio.gather(*tasks) + + # Verify all values were set correctly + for i in range(10): + key = f'concurrent-key-{i}' + record = await kvs_client.get_value(key=key) + assert record is not None + assert record.value == f'value-{i}' diff --git a/tests/unit/storage_clients/_file_system/test_fs_rq_client.py b/tests/unit/storage_clients/_file_system/test_fs_rq_client.py new file mode 100644 index 0000000000..92991bc9b3 --- /dev/null +++ 
b/tests/unit/storage_clients/_file_system/test_fs_rq_client.py @@ -0,0 +1,466 @@ +from __future__ import annotations + +import asyncio +import json +from datetime import datetime +from typing import TYPE_CHECKING + +import pytest + +from crawlee import Request +from crawlee._consts import METADATA_FILENAME +from crawlee.configuration import Configuration +from crawlee.storage_clients import FileSystemStorageClient +from crawlee.storage_clients._file_system import FileSystemRequestQueueClient + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + from pathlib import Path + + +@pytest.fixture +def configuration(tmp_path: Path) -> Configuration: + return Configuration( + crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] + ) + + +@pytest.fixture +async def rq_client(configuration: Configuration) -> AsyncGenerator[FileSystemRequestQueueClient, None]: + """A fixture for a file system request queue client.""" + client = await FileSystemStorageClient().open_request_queue_client( + name='test_request_queue', + configuration=configuration, + ) + yield client + await client.drop() + + +async def test_open_creates_new_rq(configuration: Configuration) -> None: + """Test that open() creates a new request queue with proper metadata and files on disk.""" + client = await FileSystemStorageClient().open_request_queue_client( + name='new_request_queue', + configuration=configuration, + ) + + # Verify correct client type and properties + assert isinstance(client, FileSystemRequestQueueClient) + assert client.metadata.id is not None + assert client.metadata.name == 'new_request_queue' + assert client.metadata.handled_request_count == 0 + assert client.metadata.pending_request_count == 0 + assert client.metadata.total_request_count == 0 + assert isinstance(client.metadata.created_at, datetime) + assert isinstance(client.metadata.accessed_at, datetime) + assert isinstance(client.metadata.modified_at, datetime) + + # Verify files were created + assert client.path_to_rq.exists() + assert client.path_to_metadata.exists() + + # Verify metadata content + with client.path_to_metadata.open() as f: + metadata = json.load(f) + assert metadata['id'] == client.metadata.id + assert metadata['name'] == 'new_request_queue' + + +async def test_rq_client_purge_on_start(configuration: Configuration) -> None: + """Test that purge_on_start=True clears existing data in the request queue.""" + configuration.purge_on_start = True + + # Create request queue and add data + rq_client1 = await FileSystemStorageClient().open_request_queue_client( + name='test-purge-rq', + configuration=configuration, + ) + await rq_client1.add_batch_of_requests([Request.from_url('https://example.com')]) + + # Verify request was added + assert rq_client1.metadata.total_request_count == 1 + + # Reopen + rq_client2 = await FileSystemStorageClient().open_request_queue_client( + name='test-purge-rq', + configuration=configuration, + ) + + # Verify data was purged + assert rq_client2.metadata.total_request_count == 0 + + +async def test_rq_client_no_purge_on_start(configuration: Configuration) -> None: + """Test that purge_on_start=False keeps existing data in the request queue.""" + configuration.purge_on_start = False + + # Create request queue and add data + rq_client1 = await FileSystemStorageClient().open_request_queue_client( + name='test-no-purge-rq', + configuration=configuration, + ) + await rq_client1.add_batch_of_requests([Request.from_url('https://example.com')]) + + # Reopen + rq_client2 = await 
FileSystemStorageClient().open_request_queue_client( + name='test-no-purge-rq', + configuration=configuration, + ) + + # Verify data was preserved + assert rq_client2.metadata.total_request_count == 1 + + +async def test_open_with_id_raises_error(configuration: Configuration) -> None: + """Test that open() raises an error when an ID is provided.""" + with pytest.raises(ValueError, match='not supported for file system storage client'): + await FileSystemStorageClient().open_request_queue_client(id='some-id', configuration=configuration) + + +@pytest.fixture +def rq_path(rq_client: FileSystemRequestQueueClient) -> Path: + """Return the path to the request queue directory.""" + return rq_client.path_to_rq + + +async def test_add_requests(rq_client: FileSystemRequestQueueClient) -> None: + """Test adding requests creates proper files in the filesystem.""" + # Add a batch of requests + requests = [ + Request.from_url('https://example.com/1'), + Request.from_url('https://example.com/2'), + Request.from_url('https://example.com/3'), + ] + + response = await rq_client.add_batch_of_requests(requests) + + # Verify response + assert len(response.processed_requests) == 3 + for i, processed_request in enumerate(response.processed_requests): + assert processed_request.unique_key == f'https://example.com/{i + 1}' + assert processed_request.was_already_present is False + assert processed_request.was_already_handled is False + + # Verify request files were created + request_files = list(rq_client.path_to_rq.glob('*.json')) + assert len(request_files) == 4 # 3 requests + metadata file + assert rq_client.path_to_metadata in request_files + + # Verify metadata was updated + assert rq_client.metadata.total_request_count == 3 + assert rq_client.metadata.pending_request_count == 3 + + # Verify content of the request files + for req_file in [f for f in request_files if f != rq_client.path_to_metadata]: + with req_file.open() as f: + content = json.load(f) + assert 'url' in content + assert content['url'].startswith('https://example.com/') + assert 'id' in content + assert 'handled_at' not in content # Not yet handled + + +async def test_add_duplicate_request(rq_client: FileSystemRequestQueueClient) -> None: + """Test adding a duplicate request.""" + request = Request.from_url('https://example.com') + + # Add the request the first time + await rq_client.add_batch_of_requests([request]) + + # Add the same request again + second_response = await rq_client.add_batch_of_requests([request]) + + # Verify response indicates it was already present + assert second_response.processed_requests[0].was_already_present is True + + # Verify only one request file exists + request_files = [f for f in rq_client.path_to_rq.glob('*.json') if f.name != METADATA_FILENAME] + assert len(request_files) == 1 + + # Verify metadata counts weren't incremented + assert rq_client.metadata.total_request_count == 1 + assert rq_client.metadata.pending_request_count == 1 + + +async def test_fetch_next_request(rq_client: FileSystemRequestQueueClient) -> None: + """Test fetching the next request from the queue.""" + # Add requests + requests = [ + Request.from_url('https://example.com/1'), + Request.from_url('https://example.com/2'), + ] + await rq_client.add_batch_of_requests(requests) + + # Fetch the first request + first_request = await rq_client.fetch_next_request() + assert first_request is not None + assert first_request.url == 'https://example.com/1' + + # Check that it's marked as in-progress + assert first_request.id in 
rq_client._in_progress + + # Fetch the second request + second_request = await rq_client.fetch_next_request() + assert second_request is not None + assert second_request.url == 'https://example.com/2' + + # There should be no more requests + empty_request = await rq_client.fetch_next_request() + assert empty_request is None + + +async def test_fetch_forefront_requests(rq_client: FileSystemRequestQueueClient) -> None: + """Test that forefront requests are fetched first.""" + # Add regular requests + await rq_client.add_batch_of_requests( + [ + Request.from_url('https://example.com/regular1'), + Request.from_url('https://example.com/regular2'), + ] + ) + + # Add forefront requests + await rq_client.add_batch_of_requests( + [ + Request.from_url('https://example.com/priority1'), + Request.from_url('https://example.com/priority2'), + ], + forefront=True, + ) + + # Fetch requests - they should come in priority order first + next_request1 = await rq_client.fetch_next_request() + assert next_request1 is not None + assert next_request1.url.startswith('https://example.com/priority') + + next_request2 = await rq_client.fetch_next_request() + assert next_request2 is not None + assert next_request2.url.startswith('https://example.com/priority') + + next_request3 = await rq_client.fetch_next_request() + assert next_request3 is not None + assert next_request3.url.startswith('https://example.com/regular') + + next_request4 = await rq_client.fetch_next_request() + assert next_request4 is not None + assert next_request4.url.startswith('https://example.com/regular') + + +async def test_mark_request_as_handled(rq_client: FileSystemRequestQueueClient) -> None: + """Test marking a request as handled.""" + # Add and fetch a request + await rq_client.add_batch_of_requests([Request.from_url('https://example.com')]) + request = await rq_client.fetch_next_request() + assert request is not None + + # Mark it as handled + result = await rq_client.mark_request_as_handled(request) + assert result is not None + assert result.was_already_handled is True + + # Verify it's no longer in-progress + assert request.id not in rq_client._in_progress + + # Verify metadata was updated + assert rq_client.metadata.handled_request_count == 1 + assert rq_client.metadata.pending_request_count == 0 + + # Verify the file was updated with handled_at timestamp + request_files = [f for f in rq_client.path_to_rq.glob('*.json') if f.name != METADATA_FILENAME] + assert len(request_files) == 1 + + with request_files[0].open() as f: + content = json.load(f) + assert 'handled_at' in content + assert content['handled_at'] is not None + + +async def test_reclaim_request(rq_client: FileSystemRequestQueueClient) -> None: + """Test reclaiming a request that failed processing.""" + # Add and fetch a request + await rq_client.add_batch_of_requests([Request.from_url('https://example.com')]) + request = await rq_client.fetch_next_request() + assert request is not None + + # Reclaim the request + result = await rq_client.reclaim_request(request) + assert result is not None + assert result.was_already_handled is False + + # Verify it's no longer in-progress + assert request.id not in rq_client._in_progress + + # Should be able to fetch it again + reclaimed_request = await rq_client.fetch_next_request() + assert reclaimed_request is not None + assert reclaimed_request.id == request.id + + +async def test_reclaim_request_with_forefront(rq_client: FileSystemRequestQueueClient) -> None: + """Test reclaiming a request with forefront priority.""" + # Add requests 
+    await rq_client.add_batch_of_requests(
+        [
+            Request.from_url('https://example.com/first'),
+            Request.from_url('https://example.com/second'),
+        ]
+    )
+
+    # Fetch the first request
+    first_request = await rq_client.fetch_next_request()
+    assert first_request is not None
+    assert first_request.url == 'https://example.com/first'
+
+    # Reclaim it with forefront priority
+    await rq_client.reclaim_request(first_request, forefront=True)
+
+    # Verify it's in the forefront set
+    assert first_request.id in rq_client._forefront_requests
+
+    # It should be returned before the second request
+    reclaimed_request = await rq_client.fetch_next_request()
+    assert reclaimed_request is not None
+    assert reclaimed_request.url == 'https://example.com/first'
+
+
+async def test_is_empty(rq_client: FileSystemRequestQueueClient) -> None:
+    """Test checking if a queue is empty."""
+    # Queue should start empty
+    assert await rq_client.is_empty() is True
+
+    # Add a request
+    await rq_client.add_batch_of_requests([Request.from_url('https://example.com')])
+    assert await rq_client.is_empty() is False
+
+    # Fetch and handle the request
+    request = await rq_client.fetch_next_request()
+    assert request is not None
+    await rq_client.mark_request_as_handled(request)
+
+    # Queue should be empty again
+    assert await rq_client.is_empty() is True
+
+
+async def test_get_request(rq_client: FileSystemRequestQueueClient) -> None:
+    """Test getting a request by ID."""
+    # Add a request
+    response = await rq_client.add_batch_of_requests([Request.from_url('https://example.com')])
+    request_id = response.processed_requests[0].id
+
+    # Get the request by ID
+    request = await rq_client.get_request(request_id)
+    assert request is not None
+    assert request.id == request_id
+    assert request.url == 'https://example.com'
+
+    # Try to get a non-existent request
+    not_found = await rq_client.get_request('non-existent-id')
+    assert not_found is None
+
+
+async def test_drop(configuration: Configuration) -> None:
+    """Test dropping the queue removes files from the filesystem."""
+    client = await FileSystemStorageClient().open_request_queue_client(
+        name='drop_test',
+        configuration=configuration,
+    )
+
+    # Add requests to create files
+    await client.add_batch_of_requests(
+        [
+            Request.from_url('https://example.com/1'),
+            Request.from_url('https://example.com/2'),
+        ]
+    )
+
+    # Verify the directory exists
+    rq_path = client.path_to_rq
+    assert rq_path.exists()
+
+    # Drop the client
+    await client.drop()
+
+    # Verify the directory was removed
+    assert not rq_path.exists()
+
+
+async def test_file_persistence(configuration: Configuration) -> None:
+    """Test that requests are persisted to files and can be recovered after a 'restart'."""
+    # Explicitly set purge_on_start to False to ensure files aren't deleted
+    configuration.purge_on_start = False
+
+    # Create a client and add requests
+    client1 = await FileSystemStorageClient().open_request_queue_client(
+        name='persistence_test',
+        configuration=configuration,
+    )
+
+    await client1.add_batch_of_requests(
+        [
+            Request.from_url('https://example.com/1'),
+            Request.from_url('https://example.com/2'),
+        ]
+    )
+
+    # Fetch and handle one request
+    request = await client1.fetch_next_request()
+    assert request is not None
+    await client1.mark_request_as_handled(request)
+
+    # Get the storage directory path before opening a new client
+    storage_path = client1.path_to_rq
+    assert storage_path.exists(), 'Request queue directory should exist'
+
+    # Verify files exist
+    request_files =
list(storage_path.glob('*.json')) + assert len(request_files) > 0, 'Request files should exist' + + # Create a new client with same name (which will load from files) + client2 = await FileSystemStorageClient().open_request_queue_client( + name='persistence_test', + configuration=configuration, + ) + + # Verify state was recovered + assert client2.metadata.total_request_count == 2 + assert client2.metadata.handled_request_count == 1 + assert client2.metadata.pending_request_count == 1 + + # Should be able to fetch the remaining request + remaining_request = await client2.fetch_next_request() + assert remaining_request is not None + assert remaining_request.url == 'https://example.com/2' + + # Clean up + await client2.drop() + + +async def test_metadata_updates(rq_client: FileSystemRequestQueueClient) -> None: + """Test that metadata timestamps are updated correctly after operations.""" + # Record initial timestamps + initial_created = rq_client.metadata.created_at + initial_accessed = rq_client.metadata.accessed_at + initial_modified = rq_client.metadata.modified_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates accessed_at + await rq_client.is_empty() + + # Verify timestamps + assert rq_client.metadata.created_at == initial_created + assert rq_client.metadata.accessed_at > initial_accessed + assert rq_client.metadata.modified_at == initial_modified + + accessed_after_get = rq_client.metadata.accessed_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates modified_at + await rq_client.add_batch_of_requests([Request.from_url('https://example.com')]) + + # Verify timestamps again + assert rq_client.metadata.created_at == initial_created + assert rq_client.metadata.modified_at > initial_modified + assert rq_client.metadata.accessed_at > accessed_after_get diff --git a/tests/unit/storage_clients/_memory/test_creation_management.py b/tests/unit/storage_clients/_memory/test_creation_management.py deleted file mode 100644 index 88a5e9e283..0000000000 --- a/tests/unit/storage_clients/_memory/test_creation_management.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -import json -from pathlib import Path -from unittest.mock import AsyncMock, patch - -import pytest - -from crawlee._consts import METADATA_FILENAME -from crawlee.storage_clients._memory._creation_management import persist_metadata_if_enabled - - -async def test_persist_metadata_skips_when_disabled(tmp_path: Path) -> None: - await persist_metadata_if_enabled(data={'key': 'value'}, entity_directory=str(tmp_path), write_metadata=False) - assert not list(tmp_path.iterdir()) # The directory should be empty since write_metadata is False - - -async def test_persist_metadata_creates_files_and_directories_when_enabled(tmp_path: Path) -> None: - data = {'key': 'value'} - entity_directory = Path(tmp_path, 'new_dir') - await persist_metadata_if_enabled(data=data, entity_directory=str(entity_directory), write_metadata=True) - assert entity_directory.exists() is True # Check if directory was created - assert (entity_directory / METADATA_FILENAME).is_file() # Check if file was created - - -async def test_persist_metadata_correctly_writes_data(tmp_path: Path) -> None: - data = {'key': 'value'} - entity_directory = Path(tmp_path, 'data_dir') - await persist_metadata_if_enabled(data=data, entity_directory=str(entity_directory), write_metadata=True) - metadata_path = entity_directory / 
METADATA_FILENAME - with open(metadata_path) as f: # noqa: ASYNC230 - content = f.read() - assert json.loads(content) == data # Check if correct data was written - - -async def test_persist_metadata_rewrites_data_with_error(tmp_path: Path) -> None: - init_data = {'key': 'very_long_value'} - update_data = {'key': 'short_value'} - error_data = {'key': 'error'} - - entity_directory = Path(tmp_path, 'data_dir') - metadata_path = entity_directory / METADATA_FILENAME - - # write metadata with init_data - await persist_metadata_if_enabled(data=init_data, entity_directory=str(entity_directory), write_metadata=True) - - # rewrite metadata with new_data - await persist_metadata_if_enabled(data=update_data, entity_directory=str(entity_directory), write_metadata=True) - with open(metadata_path) as f: # noqa: ASYNC230 - content = f.read() - assert json.loads(content) == update_data # Check if correct data was rewritten - - # raise interrupt between opening a file and writing - module_for_patch = 'crawlee.storage_clients._memory._creation_management.json_dumps' - with patch(module_for_patch, AsyncMock(side_effect=KeyboardInterrupt())), pytest.raises(KeyboardInterrupt): - await persist_metadata_if_enabled(data=error_data, entity_directory=str(entity_directory), write_metadata=True) - with open(metadata_path) as f: # noqa: ASYNC230 - content = f.read() - assert content == '' # The file is empty after an error diff --git a/tests/unit/storage_clients/_memory/test_dataset_client.py b/tests/unit/storage_clients/_memory/test_dataset_client.py deleted file mode 100644 index 472d11a8b3..0000000000 --- a/tests/unit/storage_clients/_memory/test_dataset_client.py +++ /dev/null @@ -1,148 +0,0 @@ -from __future__ import annotations - -import asyncio -from pathlib import Path -from typing import TYPE_CHECKING - -import pytest - -if TYPE_CHECKING: - from crawlee.storage_clients import MemoryStorageClient - from crawlee.storage_clients._memory import DatasetClient - - -@pytest.fixture -async def dataset_client(memory_storage_client: MemoryStorageClient) -> DatasetClient: - datasets_client = memory_storage_client.datasets() - dataset_info = await datasets_client.get_or_create(name='test') - return memory_storage_client.dataset(dataset_info.id) - - -async def test_nonexistent(memory_storage_client: MemoryStorageClient) -> None: - dataset_client = memory_storage_client.dataset(id='nonexistent-id') - assert await dataset_client.get() is None - with pytest.raises(ValueError, match='Dataset with id "nonexistent-id" does not exist.'): - await dataset_client.update(name='test-update') - - with pytest.raises(ValueError, match='Dataset with id "nonexistent-id" does not exist.'): - await dataset_client.list_items() - - with pytest.raises(ValueError, match='Dataset with id "nonexistent-id" does not exist.'): - await dataset_client.push_items([{'abc': 123}]) - await dataset_client.delete() - - -async def test_not_implemented(dataset_client: DatasetClient) -> None: - with pytest.raises(NotImplementedError, match='This method is not supported in memory storage.'): - await dataset_client.stream_items() - with pytest.raises(NotImplementedError, match='This method is not supported in memory storage.'): - await dataset_client.get_items_as_bytes() - - -async def test_get(dataset_client: DatasetClient) -> None: - await asyncio.sleep(0.1) - info = await dataset_client.get() - assert info is not None - assert info.id == dataset_client.id - assert info.accessed_at != info.created_at - - -async def test_update(dataset_client: DatasetClient) -> 
None: - new_dataset_name = 'test-update' - await dataset_client.push_items({'abc': 123}) - - old_dataset_info = await dataset_client.get() - assert old_dataset_info is not None - old_dataset_directory = Path(dataset_client._memory_storage_client.datasets_directory, old_dataset_info.name or '') - new_dataset_directory = Path(dataset_client._memory_storage_client.datasets_directory, new_dataset_name) - assert (old_dataset_directory / '000000001.json').exists() is True - assert (new_dataset_directory / '000000001.json').exists() is False - - await asyncio.sleep(0.1) - updated_dataset_info = await dataset_client.update(name=new_dataset_name) - assert (old_dataset_directory / '000000001.json').exists() is False - assert (new_dataset_directory / '000000001.json').exists() is True - # Only modified_at and accessed_at should be different - assert old_dataset_info.created_at == updated_dataset_info.created_at - assert old_dataset_info.modified_at != updated_dataset_info.modified_at - assert old_dataset_info.accessed_at != updated_dataset_info.accessed_at - - # Should fail with the same name - with pytest.raises(ValueError, match='Dataset with name "test-update" already exists.'): - await dataset_client.update(name=new_dataset_name) - - -async def test_delete(dataset_client: DatasetClient) -> None: - await dataset_client.push_items({'abc': 123}) - dataset_info = await dataset_client.get() - assert dataset_info is not None - dataset_directory = Path(dataset_client._memory_storage_client.datasets_directory, dataset_info.name or '') - assert (dataset_directory / '000000001.json').exists() is True - await dataset_client.delete() - assert (dataset_directory / '000000001.json').exists() is False - # Does not crash when called again - await dataset_client.delete() - - -async def test_push_items(dataset_client: DatasetClient) -> None: - await dataset_client.push_items('{"test": "JSON from a string"}') - await dataset_client.push_items({'abc': {'def': {'ghi': '123'}}}) - await dataset_client.push_items(['{"test-json-parse": "JSON from a string"}' for _ in range(10)]) - await dataset_client.push_items([{'test-dict': i} for i in range(10)]) - - list_page = await dataset_client.list_items() - assert list_page.items[0]['test'] == 'JSON from a string' - assert list_page.items[1]['abc']['def']['ghi'] == '123' - assert list_page.items[11]['test-json-parse'] == 'JSON from a string' - assert list_page.items[21]['test-dict'] == 9 - assert list_page.count == 22 - - -async def test_list_items(dataset_client: DatasetClient) -> None: - item_count = 100 - used_offset = 10 - used_limit = 50 - await dataset_client.push_items([{'id': i} for i in range(item_count)]) - # Test without any parameters - list_default = await dataset_client.list_items() - assert list_default.count == item_count - assert list_default.offset == 0 - assert list_default.items[0]['id'] == 0 - assert list_default.desc is False - # Test offset - list_offset_10 = await dataset_client.list_items(offset=used_offset) - assert list_offset_10.count == item_count - used_offset - assert list_offset_10.offset == used_offset - assert list_offset_10.total == item_count - assert list_offset_10.items[0]['id'] == used_offset - # Test limit - list_limit_50 = await dataset_client.list_items(limit=used_limit) - assert list_limit_50.count == used_limit - assert list_limit_50.limit == used_limit - assert list_limit_50.total == item_count - # Test desc - list_desc_true = await dataset_client.list_items(desc=True) - assert list_desc_true.items[0]['id'] == 99 - assert 
list_desc_true.desc is True - - -async def test_iterate_items(dataset_client: DatasetClient) -> None: - item_count = 100 - await dataset_client.push_items([{'id': i} for i in range(item_count)]) - actual_items = [] - async for item in dataset_client.iterate_items(): - assert 'id' in item - actual_items.append(item) - assert len(actual_items) == item_count - assert actual_items[0]['id'] == 0 - assert actual_items[99]['id'] == 99 - - -async def test_reuse_dataset(dataset_client: DatasetClient, memory_storage_client: MemoryStorageClient) -> None: - item_count = 10 - await dataset_client.push_items([{'id': i} for i in range(item_count)]) - - memory_storage_client.datasets_handled = [] # purge datasets loaded to test create_dataset_from_directory - datasets_client = memory_storage_client.datasets() - dataset_info = await datasets_client.get_or_create(name='test') - assert dataset_info.item_count == item_count diff --git a/tests/unit/storage_clients/_memory/test_dataset_collection_client.py b/tests/unit/storage_clients/_memory/test_dataset_collection_client.py deleted file mode 100644 index d71b7e8f68..0000000000 --- a/tests/unit/storage_clients/_memory/test_dataset_collection_client.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -if TYPE_CHECKING: - from crawlee.storage_clients import MemoryStorageClient - from crawlee.storage_clients._memory import DatasetCollectionClient - - -@pytest.fixture -def datasets_client(memory_storage_client: MemoryStorageClient) -> DatasetCollectionClient: - return memory_storage_client.datasets() - - -async def test_get_or_create(datasets_client: DatasetCollectionClient) -> None: - dataset_name = 'test' - # A new dataset gets created - dataset_info = await datasets_client.get_or_create(name=dataset_name) - assert dataset_info.name == dataset_name - - # Another get_or_create call returns the same dataset - dataset_info_existing = await datasets_client.get_or_create(name=dataset_name) - assert dataset_info.id == dataset_info_existing.id - assert dataset_info.name == dataset_info_existing.name - assert dataset_info.created_at == dataset_info_existing.created_at - - -async def test_list(datasets_client: DatasetCollectionClient) -> None: - dataset_list_1 = await datasets_client.list() - assert dataset_list_1.count == 0 - - dataset_info = await datasets_client.get_or_create(name='dataset') - dataset_list_2 = await datasets_client.list() - - assert dataset_list_2.count == 1 - assert dataset_list_2.items[0].name == dataset_info.name - - # Test sorting behavior - newer_dataset_info = await datasets_client.get_or_create(name='newer-dataset') - dataset_list_sorting = await datasets_client.list() - assert dataset_list_sorting.count == 2 - assert dataset_list_sorting.items[0].name == dataset_info.name - assert dataset_list_sorting.items[1].name == newer_dataset_info.name diff --git a/tests/unit/storage_clients/_memory/test_key_value_store_client.py b/tests/unit/storage_clients/_memory/test_key_value_store_client.py deleted file mode 100644 index 26d1f8f974..0000000000 --- a/tests/unit/storage_clients/_memory/test_key_value_store_client.py +++ /dev/null @@ -1,443 +0,0 @@ -from __future__ import annotations - -import asyncio -import base64 -import json -from datetime import datetime, timezone -from pathlib import Path -from typing import TYPE_CHECKING - -import pytest - -from crawlee._consts import METADATA_FILENAME -from crawlee._utils.crypto import crypto_random_object_id -from 
crawlee._utils.data_processing import maybe_parse_body -from crawlee._utils.file import json_dumps -from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecordMetadata - -if TYPE_CHECKING: - from crawlee.storage_clients import MemoryStorageClient - from crawlee.storage_clients._memory import KeyValueStoreClient - -TINY_PNG = base64.b64decode( - s='iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVQYV2NgYAAAAAMAAWgmWQ0AAAAASUVORK5CYII=', -) -TINY_BYTES = b'\x12\x34\x56\x78\x90\xab\xcd\xef' -TINY_DATA = {'a': 'b'} -TINY_TEXT = 'abcd' - - -@pytest.fixture -async def key_value_store_client(memory_storage_client: MemoryStorageClient) -> KeyValueStoreClient: - key_value_stores_client = memory_storage_client.key_value_stores() - kvs_info = await key_value_stores_client.get_or_create(name='test') - return memory_storage_client.key_value_store(kvs_info.id) - - -async def test_nonexistent(memory_storage_client: MemoryStorageClient) -> None: - kvs_client = memory_storage_client.key_value_store(id='nonexistent-id') - assert await kvs_client.get() is None - - with pytest.raises(ValueError, match='Key-value store with id "nonexistent-id" does not exist.'): - await kvs_client.update(name='test-update') - - with pytest.raises(ValueError, match='Key-value store with id "nonexistent-id" does not exist.'): - await kvs_client.list_keys() - - with pytest.raises(ValueError, match='Key-value store with id "nonexistent-id" does not exist.'): - await kvs_client.set_record('test', {'abc': 123}) - - with pytest.raises(ValueError, match='Key-value store with id "nonexistent-id" does not exist.'): - await kvs_client.get_record('test') - - with pytest.raises(ValueError, match='Key-value store with id "nonexistent-id" does not exist.'): - await kvs_client.get_record_as_bytes('test') - - with pytest.raises(ValueError, match='Key-value store with id "nonexistent-id" does not exist.'): - await kvs_client.delete_record('test') - - await kvs_client.delete() - - -async def test_not_implemented(key_value_store_client: KeyValueStoreClient) -> None: - with pytest.raises(NotImplementedError, match='This method is not supported in memory storage.'): - await key_value_store_client.stream_record('test') - - -async def test_get(key_value_store_client: KeyValueStoreClient) -> None: - await asyncio.sleep(0.1) - info = await key_value_store_client.get() - assert info is not None - assert info.id == key_value_store_client.id - assert info.accessed_at != info.created_at - - -async def test_update(key_value_store_client: KeyValueStoreClient) -> None: - new_kvs_name = 'test-update' - await key_value_store_client.set_record('test', {'abc': 123}) - old_kvs_info = await key_value_store_client.get() - assert old_kvs_info is not None - old_kvs_directory = Path( - key_value_store_client._memory_storage_client.key_value_stores_directory, old_kvs_info.name or '' - ) - new_kvs_directory = Path(key_value_store_client._memory_storage_client.key_value_stores_directory, new_kvs_name) - assert (old_kvs_directory / 'test.json').exists() is True - assert (new_kvs_directory / 'test.json').exists() is False - - await asyncio.sleep(0.1) - updated_kvs_info = await key_value_store_client.update(name=new_kvs_name) - assert (old_kvs_directory / 'test.json').exists() is False - assert (new_kvs_directory / 'test.json').exists() is True - # Only modified_at and accessed_at should be different - assert old_kvs_info.created_at == updated_kvs_info.created_at - assert old_kvs_info.modified_at != updated_kvs_info.modified_at - assert 
old_kvs_info.accessed_at != updated_kvs_info.accessed_at - - # Should fail with the same name - with pytest.raises(ValueError, match='Key-value store with name "test-update" already exists.'): - await key_value_store_client.update(name=new_kvs_name) - - -async def test_delete(key_value_store_client: KeyValueStoreClient) -> None: - await key_value_store_client.set_record('test', {'abc': 123}) - kvs_info = await key_value_store_client.get() - assert kvs_info is not None - kvs_directory = Path(key_value_store_client._memory_storage_client.key_value_stores_directory, kvs_info.name or '') - assert (kvs_directory / 'test.json').exists() is True - await key_value_store_client.delete() - assert (kvs_directory / 'test.json').exists() is False - # Does not crash when called again - await key_value_store_client.delete() - - -async def test_list_keys_empty(key_value_store_client: KeyValueStoreClient) -> None: - keys = await key_value_store_client.list_keys() - assert len(keys.items) == 0 - assert keys.count == 0 - assert keys.is_truncated is False - - -async def test_list_keys(key_value_store_client: KeyValueStoreClient) -> None: - record_count = 4 - used_limit = 2 - used_exclusive_start_key = 'a' - await key_value_store_client.set_record('b', 'test') - await key_value_store_client.set_record('a', 'test') - await key_value_store_client.set_record('d', 'test') - await key_value_store_client.set_record('c', 'test') - - # Default settings - keys = await key_value_store_client.list_keys() - assert keys.items[0].key == 'a' - assert keys.items[3].key == 'd' - assert keys.count == record_count - assert keys.is_truncated is False - # Test limit - keys_limit_2 = await key_value_store_client.list_keys(limit=used_limit) - assert keys_limit_2.count == record_count - assert keys_limit_2.limit == used_limit - assert keys_limit_2.items[1].key == 'b' - # Test exclusive start key - keys_exclusive_start = await key_value_store_client.list_keys(exclusive_start_key=used_exclusive_start_key, limit=2) - assert keys_exclusive_start.exclusive_start_key == used_exclusive_start_key - assert keys_exclusive_start.is_truncated is True - assert keys_exclusive_start.next_exclusive_start_key == 'c' - assert keys_exclusive_start.items[0].key == 'b' - assert keys_exclusive_start.items[-1].key == keys_exclusive_start.next_exclusive_start_key - - -async def test_get_and_set_record(tmp_path: Path, key_value_store_client: KeyValueStoreClient) -> None: - # Test setting dict record - dict_record_key = 'test-dict' - await key_value_store_client.set_record(dict_record_key, {'test': 123}) - dict_record_info = await key_value_store_client.get_record(dict_record_key) - assert dict_record_info is not None - assert 'application/json' in str(dict_record_info.content_type) - assert dict_record_info.value['test'] == 123 - - # Test setting str record - str_record_key = 'test-str' - await key_value_store_client.set_record(str_record_key, 'test') - str_record_info = await key_value_store_client.get_record(str_record_key) - assert str_record_info is not None - assert 'text/plain' in str(str_record_info.content_type) - assert str_record_info.value == 'test' - - # Test setting explicit json record but use str as value, i.e. 
json dumps is skipped - explicit_json_key = 'test-json' - await key_value_store_client.set_record(explicit_json_key, '{"test": "explicit string"}', 'application/json') - bytes_record_info = await key_value_store_client.get_record(explicit_json_key) - assert bytes_record_info is not None - assert 'application/json' in str(bytes_record_info.content_type) - assert bytes_record_info.value['test'] == 'explicit string' - - # Test using bytes - bytes_key = 'test-json' - bytes_value = b'testing bytes set_record' - await key_value_store_client.set_record(bytes_key, bytes_value, 'unknown') - bytes_record_info = await key_value_store_client.get_record(bytes_key) - assert bytes_record_info is not None - assert 'unknown' in str(bytes_record_info.content_type) - assert bytes_record_info.value == bytes_value - assert bytes_record_info.value.decode('utf-8') == bytes_value.decode('utf-8') - - # Test using file descriptor - with open(tmp_path / 'test.json', 'w+', encoding='utf-8') as f: # noqa: ASYNC230 - f.write('Test') - with pytest.raises(NotImplementedError, match='File-like values are not supported in local memory storage'): - await key_value_store_client.set_record('file', f) - - -async def test_get_record_as_bytes(key_value_store_client: KeyValueStoreClient) -> None: - record_key = 'test' - record_value = 'testing' - await key_value_store_client.set_record(record_key, record_value) - record_info = await key_value_store_client.get_record_as_bytes(record_key) - assert record_info is not None - assert record_info.value == record_value.encode('utf-8') - - -async def test_delete_record(key_value_store_client: KeyValueStoreClient) -> None: - record_key = 'test' - await key_value_store_client.set_record(record_key, 'test') - await key_value_store_client.delete_record(record_key) - # Does not crash when called again - await key_value_store_client.delete_record(record_key) - - -@pytest.mark.parametrize( - ('input_data', 'expected_output'), - [ - ( - {'key': 'image', 'value': TINY_PNG, 'contentType': None}, - {'filename': 'image', 'key': 'image', 'contentType': 'application/octet-stream'}, - ), - ( - {'key': 'image', 'value': TINY_PNG, 'contentType': 'image/png'}, - {'filename': 'image.png', 'key': 'image', 'contentType': 'image/png'}, - ), - ( - {'key': 'image.png', 'value': TINY_PNG, 'contentType': None}, - {'filename': 'image.png', 'key': 'image.png', 'contentType': 'application/octet-stream'}, - ), - ( - {'key': 'image.png', 'value': TINY_PNG, 'contentType': 'image/png'}, - {'filename': 'image.png', 'key': 'image.png', 'contentType': 'image/png'}, - ), - ( - {'key': 'data', 'value': TINY_DATA, 'contentType': None}, - {'filename': 'data.json', 'key': 'data', 'contentType': 'application/json'}, - ), - ( - {'key': 'data', 'value': TINY_DATA, 'contentType': 'application/json'}, - {'filename': 'data.json', 'key': 'data', 'contentType': 'application/json'}, - ), - ( - {'key': 'data.json', 'value': TINY_DATA, 'contentType': None}, - {'filename': 'data.json', 'key': 'data.json', 'contentType': 'application/json'}, - ), - ( - {'key': 'data.json', 'value': TINY_DATA, 'contentType': 'application/json'}, - {'filename': 'data.json', 'key': 'data.json', 'contentType': 'application/json'}, - ), - ( - {'key': 'text', 'value': TINY_TEXT, 'contentType': None}, - {'filename': 'text.txt', 'key': 'text', 'contentType': 'text/plain'}, - ), - ( - {'key': 'text', 'value': TINY_TEXT, 'contentType': 'text/plain'}, - {'filename': 'text.txt', 'key': 'text', 'contentType': 'text/plain'}, - ), - ( - {'key': 'text.txt', 'value': 
TINY_TEXT, 'contentType': None}, - {'filename': 'text.txt', 'key': 'text.txt', 'contentType': 'text/plain'}, - ), - ( - {'key': 'text.txt', 'value': TINY_TEXT, 'contentType': 'text/plain'}, - {'filename': 'text.txt', 'key': 'text.txt', 'contentType': 'text/plain'}, - ), - ], -) -async def test_writes_correct_metadata( - memory_storage_client: MemoryStorageClient, - input_data: dict, - expected_output: dict, -) -> None: - key_value_store_name = crypto_random_object_id() - - # Get KVS client - kvs_info = await memory_storage_client.key_value_stores().get_or_create(name=key_value_store_name) - kvs_client = memory_storage_client.key_value_store(kvs_info.id) - - # Write the test input item to the store - await kvs_client.set_record( - key=input_data['key'], - value=input_data['value'], - content_type=input_data['contentType'], - ) - - # Check that everything was written correctly, both the data and metadata - storage_path = Path(memory_storage_client.key_value_stores_directory, key_value_store_name) - item_path = Path(storage_path, expected_output['filename']) - item_metadata_path = storage_path / f'{expected_output["filename"]}.__metadata__.json' - - assert item_path.exists() - assert item_metadata_path.exists() - - # Test the actual value of the item - with open(item_path, 'rb') as item_file: # noqa: ASYNC230 - actual_value = maybe_parse_body(item_file.read(), expected_output['contentType']) - assert actual_value == input_data['value'] - - # Test the actual metadata of the item - with open(item_metadata_path, encoding='utf-8') as metadata_file: # noqa: ASYNC230 - json_content = json.load(metadata_file) - metadata = KeyValueStoreRecordMetadata(**json_content) - assert metadata.key == expected_output['key'] - assert expected_output['contentType'] in metadata.content_type - - -@pytest.mark.parametrize( - ('input_data', 'expected_output'), - [ - ( - {'filename': 'image', 'value': TINY_PNG, 'metadata': None}, - {'key': 'image', 'filename': 'image', 'contentType': 'application/octet-stream'}, - ), - ( - {'filename': 'image.png', 'value': TINY_PNG, 'metadata': None}, - {'key': 'image', 'filename': 'image.png', 'contentType': 'image/png'}, - ), - ( - { - 'filename': 'image', - 'value': TINY_PNG, - 'metadata': {'key': 'image', 'contentType': 'application/octet-stream'}, - }, - {'key': 'image', 'contentType': 'application/octet-stream'}, - ), - ( - {'filename': 'image', 'value': TINY_PNG, 'metadata': {'key': 'image', 'contentType': 'image/png'}}, - {'key': 'image', 'filename': 'image', 'contentType': 'image/png'}, - ), - ( - { - 'filename': 'image.png', - 'value': TINY_PNG, - 'metadata': {'key': 'image.png', 'contentType': 'application/octet-stream'}, - }, - {'key': 'image.png', 'contentType': 'application/octet-stream'}, - ), - ( - {'filename': 'image.png', 'value': TINY_PNG, 'metadata': {'key': 'image.png', 'contentType': 'image/png'}}, - {'key': 'image.png', 'contentType': 'image/png'}, - ), - ( - {'filename': 'image.png', 'value': TINY_PNG, 'metadata': {'key': 'image', 'contentType': 'image/png'}}, - {'key': 'image', 'contentType': 'image/png'}, - ), - ( - {'filename': 'input', 'value': TINY_BYTES, 'metadata': None}, - {'key': 'input', 'contentType': 'application/octet-stream'}, - ), - ( - {'filename': 'input.json', 'value': TINY_DATA, 'metadata': None}, - {'key': 'input', 'contentType': 'application/json'}, - ), - ( - {'filename': 'input.txt', 'value': TINY_TEXT, 'metadata': None}, - {'key': 'input', 'contentType': 'text/plain'}, - ), - ( - {'filename': 'input.bin', 'value': TINY_BYTES, 
'metadata': None}, - {'key': 'input', 'contentType': 'application/octet-stream'}, - ), - ( - { - 'filename': 'input', - 'value': TINY_BYTES, - 'metadata': {'key': 'input', 'contentType': 'application/octet-stream'}, - }, - {'key': 'input', 'contentType': 'application/octet-stream'}, - ), - ( - { - 'filename': 'input.json', - 'value': TINY_DATA, - 'metadata': {'key': 'input', 'contentType': 'application/json'}, - }, - {'key': 'input', 'contentType': 'application/json'}, - ), - ( - {'filename': 'input.txt', 'value': TINY_TEXT, 'metadata': {'key': 'input', 'contentType': 'text/plain'}}, - {'key': 'input', 'contentType': 'text/plain'}, - ), - ( - { - 'filename': 'input.bin', - 'value': TINY_BYTES, - 'metadata': {'key': 'input', 'contentType': 'application/octet-stream'}, - }, - {'key': 'input', 'contentType': 'application/octet-stream'}, - ), - ], -) -async def test_reads_correct_metadata( - memory_storage_client: MemoryStorageClient, - input_data: dict, - expected_output: dict, -) -> None: - key_value_store_name = crypto_random_object_id() - - # Ensure the directory for the store exists - storage_path = Path(memory_storage_client.key_value_stores_directory, key_value_store_name) - storage_path.mkdir(exist_ok=True, parents=True) - - store_metadata = KeyValueStoreMetadata( - id=crypto_random_object_id(), - name='', - accessed_at=datetime.now(timezone.utc), - created_at=datetime.now(timezone.utc), - modified_at=datetime.now(timezone.utc), - user_id='1', - ) - - # Write the store metadata to disk - storage_metadata_path = storage_path / METADATA_FILENAME - with open(storage_metadata_path, mode='wb') as f: # noqa: ASYNC230 - f.write(store_metadata.model_dump_json().encode('utf-8')) - - # Write the test input item to the disk - item_path = storage_path / input_data['filename'] - with open(item_path, 'wb') as item_file: # noqa: ASYNC230 - if isinstance(input_data['value'], bytes): - item_file.write(input_data['value']) - elif isinstance(input_data['value'], str): - item_file.write(input_data['value'].encode('utf-8')) - else: - s = await json_dumps(input_data['value']) - item_file.write(s.encode('utf-8')) - - # Optionally write the metadata to disk if there is some - if input_data['metadata'] is not None: - storage_metadata_path = storage_path / f'{input_data["filename"]}.__metadata__.json' - with open(storage_metadata_path, 'w', encoding='utf-8') as metadata_file: # noqa: ASYNC230 - s = await json_dumps( - { - 'key': input_data['metadata']['key'], - 'contentType': input_data['metadata']['contentType'], - } - ) - metadata_file.write(s) - - # Create the key-value store client to load the items from disk - store_details = await memory_storage_client.key_value_stores().get_or_create(name=key_value_store_name) - key_value_store_client = memory_storage_client.key_value_store(store_details.id) - - # Read the item from the store and check if it is as expected - actual_record = await key_value_store_client.get_record(expected_output['key']) - assert actual_record is not None - - assert actual_record.key == expected_output['key'] - assert actual_record.content_type == expected_output['contentType'] - assert actual_record.value == input_data['value'] diff --git a/tests/unit/storage_clients/_memory/test_key_value_store_collection_client.py b/tests/unit/storage_clients/_memory/test_key_value_store_collection_client.py deleted file mode 100644 index 41b289eb06..0000000000 --- a/tests/unit/storage_clients/_memory/test_key_value_store_collection_client.py +++ /dev/null @@ -1,42 +0,0 @@ -from __future__ import 
annotations - -from typing import TYPE_CHECKING - -import pytest - -if TYPE_CHECKING: - from crawlee.storage_clients import MemoryStorageClient - from crawlee.storage_clients._memory import KeyValueStoreCollectionClient - - -@pytest.fixture -def key_value_stores_client(memory_storage_client: MemoryStorageClient) -> KeyValueStoreCollectionClient: - return memory_storage_client.key_value_stores() - - -async def test_get_or_create(key_value_stores_client: KeyValueStoreCollectionClient) -> None: - kvs_name = 'test' - # A new kvs gets created - kvs_info = await key_value_stores_client.get_or_create(name=kvs_name) - assert kvs_info.name == kvs_name - - # Another get_or_create call returns the same kvs - kvs_info_existing = await key_value_stores_client.get_or_create(name=kvs_name) - assert kvs_info.id == kvs_info_existing.id - assert kvs_info.name == kvs_info_existing.name - assert kvs_info.created_at == kvs_info_existing.created_at - - -async def test_list(key_value_stores_client: KeyValueStoreCollectionClient) -> None: - assert (await key_value_stores_client.list()).count == 0 - kvs_info = await key_value_stores_client.get_or_create(name='kvs') - kvs_list = await key_value_stores_client.list() - assert kvs_list.count == 1 - assert kvs_list.items[0].name == kvs_info.name - - # Test sorting behavior - newer_kvs_info = await key_value_stores_client.get_or_create(name='newer-kvs') - kvs_list_sorting = await key_value_stores_client.list() - assert kvs_list_sorting.count == 2 - assert kvs_list_sorting.items[0].name == kvs_info.name - assert kvs_list_sorting.items[1].name == newer_kvs_info.name diff --git a/tests/unit/storage_clients/_memory/test_memory_dataset_client.py b/tests/unit/storage_clients/_memory/test_memory_dataset_client.py new file mode 100644 index 0000000000..c25074e5c0 --- /dev/null +++ b/tests/unit/storage_clients/_memory/test_memory_dataset_client.py @@ -0,0 +1,279 @@ +from __future__ import annotations + +import asyncio +from datetime import datetime +from typing import TYPE_CHECKING + +import pytest + +from crawlee.configuration import Configuration +from crawlee.storage_clients import MemoryStorageClient +from crawlee.storage_clients._memory import MemoryDatasetClient +from crawlee.storage_clients.models import DatasetItemsListPage + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + + +@pytest.fixture +async def dataset_client() -> AsyncGenerator[MemoryDatasetClient, None]: + """Fixture that provides a fresh memory dataset client for each test.""" + client = await MemoryStorageClient().open_dataset_client(name='test_dataset') + yield client + await client.drop() + + +async def test_open_creates_new_dataset() -> None: + """Test that open() creates a new dataset with proper metadata and adds it to the cache.""" + client = await MemoryStorageClient().open_dataset_client(name='new_dataset') + + # Verify correct client type and properties + assert isinstance(client, MemoryDatasetClient) + assert client.metadata.id is not None + assert client.metadata.name == 'new_dataset' + assert client.metadata.item_count == 0 + assert isinstance(client.metadata.created_at, datetime) + assert isinstance(client.metadata.accessed_at, datetime) + assert isinstance(client.metadata.modified_at, datetime) + + +async def test_dataset_client_purge_on_start() -> None: + """Test that purge_on_start=True clears existing data in the dataset.""" + configuration = Configuration(purge_on_start=True) + + # Create dataset and add data + dataset_client1 = await 
MemoryStorageClient().open_dataset_client( + name='test_purge_dataset', + configuration=configuration, + ) + await dataset_client1.push_data({'item': 'initial data'}) + + # Verify data was added + items = await dataset_client1.get_data() + assert len(items.items) == 1 + + # Reopen + dataset_client2 = await MemoryStorageClient().open_dataset_client( + name='test_purge_dataset', + configuration=configuration, + ) + + # Verify data was purged + items = await dataset_client2.get_data() + assert len(items.items) == 0 + + +async def test_open_with_id_and_name() -> None: + """Test that open() can be used with both id and name parameters.""" + client = await MemoryStorageClient().open_dataset_client( + id='some-id', + name='some-name', + ) + assert client.metadata.id == 'some-id' + assert client.metadata.name == 'some-name' + + +async def test_push_data_single_item(dataset_client: MemoryDatasetClient) -> None: + """Test pushing a single item to the dataset and verifying it was stored correctly.""" + item = {'key': 'value', 'number': 42} + await dataset_client.push_data(item) + + # Verify item count was updated + assert dataset_client.metadata.item_count == 1 + + # Verify item was stored + result = await dataset_client.get_data() + assert result.count == 1 + assert result.items[0] == item + + +async def test_push_data_multiple_items(dataset_client: MemoryDatasetClient) -> None: + """Test pushing multiple items to the dataset and verifying they were stored correctly.""" + items = [ + {'id': 1, 'name': 'Item 1'}, + {'id': 2, 'name': 'Item 2'}, + {'id': 3, 'name': 'Item 3'}, + ] + await dataset_client.push_data(items) + + # Verify item count was updated + assert dataset_client.metadata.item_count == 3 + + # Verify items were stored + result = await dataset_client.get_data() + assert result.count == 3 + assert result.items == items + + +async def test_get_data_empty_dataset(dataset_client: MemoryDatasetClient) -> None: + """Test that getting data from an empty dataset returns empty results with correct metadata.""" + result = await dataset_client.get_data() + + assert isinstance(result, DatasetItemsListPage) + assert result.count == 0 + assert result.total == 0 + assert result.items == [] + + +async def test_get_data_with_items(dataset_client: MemoryDatasetClient) -> None: + """Test that all items pushed to the dataset can be retrieved with correct metadata.""" + # Add some items + items = [ + {'id': 1, 'name': 'Item 1'}, + {'id': 2, 'name': 'Item 2'}, + {'id': 3, 'name': 'Item 3'}, + ] + await dataset_client.push_data(items) + + # Get all items + result = await dataset_client.get_data() + + assert result.count == 3 + assert result.total == 3 + assert len(result.items) == 3 + assert result.items[0]['id'] == 1 + assert result.items[1]['id'] == 2 + assert result.items[2]['id'] == 3 + + +async def test_get_data_with_pagination(dataset_client: MemoryDatasetClient) -> None: + """Test that offset and limit parameters work correctly for dataset pagination.""" + # Add some items + items = [{'id': i} for i in range(1, 11)] # 10 items + await dataset_client.push_data(items) + + # Test offset + result = await dataset_client.get_data(offset=3) + assert result.count == 7 + assert result.offset == 3 + assert result.items[0]['id'] == 4 + + # Test limit + result = await dataset_client.get_data(limit=5) + assert result.count == 5 + assert result.limit == 5 + assert result.items[-1]['id'] == 5 + + # Test both offset and limit + result = await dataset_client.get_data(offset=2, limit=3) + assert result.count == 3 + assert 
result.offset == 2 + assert result.limit == 3 + assert result.items[0]['id'] == 3 + assert result.items[-1]['id'] == 5 + + +async def test_get_data_descending_order(dataset_client: MemoryDatasetClient) -> None: + """Test that the desc parameter correctly reverses the order of returned items.""" + # Add some items + items = [{'id': i} for i in range(1, 6)] # 5 items + await dataset_client.push_data(items) + + # Get items in descending order + result = await dataset_client.get_data(desc=True) + + assert result.desc is True + assert result.items[0]['id'] == 5 + assert result.items[-1]['id'] == 1 + + +async def test_get_data_skip_empty(dataset_client: MemoryDatasetClient) -> None: + """Test that the skip_empty parameter correctly filters out empty items.""" + # Add some items including an empty one + items = [ + {'id': 1, 'name': 'Item 1'}, + {}, # Empty item + {'id': 3, 'name': 'Item 3'}, + ] + await dataset_client.push_data(items) + + # Get all items + result = await dataset_client.get_data() + assert result.count == 3 + + # Get non-empty items + result = await dataset_client.get_data(skip_empty=True) + assert result.count == 2 + assert all(item != {} for item in result.items) + + +async def test_iterate(dataset_client: MemoryDatasetClient) -> None: + """Test that iterate_items yields each item in the dataset in the correct order.""" + # Add some items + items = [{'id': i} for i in range(1, 6)] # 5 items + await dataset_client.push_data(items) + + # Iterate over all items + collected_items = [item async for item in dataset_client.iterate_items()] + + assert len(collected_items) == 5 + assert collected_items[0]['id'] == 1 + assert collected_items[-1]['id'] == 5 + + +async def test_iterate_with_options(dataset_client: MemoryDatasetClient) -> None: + """Test that iterate_items respects offset, limit, and desc parameters.""" + # Add some items + items = [{'id': i} for i in range(1, 11)] # 10 items + await dataset_client.push_data(items) + + # Test with offset and limit + collected_items = [item async for item in dataset_client.iterate_items(offset=3, limit=3)] + + assert len(collected_items) == 3 + assert collected_items[0]['id'] == 4 + assert collected_items[-1]['id'] == 6 + + # Test with descending order + collected_items = [] + async for item in dataset_client.iterate_items(desc=True, limit=3): + collected_items.append(item) + + assert len(collected_items) == 3 + assert collected_items[0]['id'] == 10 + assert collected_items[-1]['id'] == 8 + + +async def test_drop(dataset_client: MemoryDatasetClient) -> None: + """Test that drop removes the dataset from cache and resets its state.""" + await dataset_client.push_data({'test': 'data'}) + + # Drop the dataset + await dataset_client.drop() + + # Verify the dataset is empty + assert dataset_client.metadata.item_count == 0 + result = await dataset_client.get_data() + assert result.count == 0 + + +async def test_metadata_updates(dataset_client: MemoryDatasetClient) -> None: + """Test that read/write operations properly update accessed_at and modified_at timestamps.""" + # Record initial timestamps + initial_created = dataset_client.metadata.created_at + initial_accessed = dataset_client.metadata.accessed_at + initial_modified = dataset_client.metadata.modified_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates accessed_at + await dataset_client.get_data() + + # Verify timestamps + assert dataset_client.metadata.created_at == initial_created + assert 
dataset_client.metadata.accessed_at > initial_accessed + assert dataset_client.metadata.modified_at == initial_modified + + accessed_after_get = dataset_client.metadata.accessed_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates modified_at + await dataset_client.push_data({'new': 'item'}) + + # Verify timestamps again + assert dataset_client.metadata.created_at == initial_created + assert dataset_client.metadata.modified_at > initial_modified + assert dataset_client.metadata.accessed_at > accessed_after_get diff --git a/tests/unit/storage_clients/_memory/test_memory_kvs_client.py b/tests/unit/storage_clients/_memory/test_memory_kvs_client.py new file mode 100644 index 0000000000..5d8789f6c3 --- /dev/null +++ b/tests/unit/storage_clients/_memory/test_memory_kvs_client.py @@ -0,0 +1,243 @@ +from __future__ import annotations + +import asyncio +from datetime import datetime +from typing import TYPE_CHECKING, Any + +import pytest + +from crawlee.configuration import Configuration +from crawlee.storage_clients import MemoryStorageClient +from crawlee.storage_clients._memory import MemoryKeyValueStoreClient +from crawlee.storage_clients.models import KeyValueStoreRecordMetadata + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + + +@pytest.fixture +async def kvs_client() -> AsyncGenerator[MemoryKeyValueStoreClient, None]: + """Fixture that provides a fresh memory key-value store client for each test.""" + client = await MemoryStorageClient().open_key_value_store_client(name='test_kvs') + yield client + await client.drop() + + +async def test_open_creates_new_kvs() -> None: + """Test that open() creates a new key-value store with proper metadata and adds it to the cache.""" + client = await MemoryStorageClient().open_key_value_store_client(name='new_kvs') + + # Verify correct client type and properties + assert isinstance(client, MemoryKeyValueStoreClient) + assert client.metadata.id is not None + assert client.metadata.name == 'new_kvs' + assert isinstance(client.metadata.created_at, datetime) + assert isinstance(client.metadata.accessed_at, datetime) + assert isinstance(client.metadata.modified_at, datetime) + + +async def test_kvs_client_purge_on_start() -> None: + """Test that purge_on_start=True clears existing data in the KVS.""" + configuration = Configuration(purge_on_start=True) + + # Create KVS and add data + kvs_client1 = await MemoryStorageClient().open_key_value_store_client( + name='test_purge_kvs', + configuration=configuration, + ) + await kvs_client1.set_value(key='test-key', value='initial value') + + # Verify value was set + record = await kvs_client1.get_value(key='test-key') + assert record is not None + assert record.value == 'initial value' + + # Reopen + kvs_client2 = await MemoryStorageClient().open_key_value_store_client( + name='test_purge_kvs', + configuration=configuration, + ) + + # Verify value was purged + record = await kvs_client2.get_value(key='test-key') + assert record is None + + +async def test_open_with_id_and_name() -> None: + """Test that open() can be used with both id and name parameters.""" + client = await MemoryStorageClient().open_key_value_store_client( + id='some-id', + name='some-name', + ) + assert client.metadata.id == 'some-id' + assert client.metadata.name == 'some-name' + + +@pytest.mark.parametrize( + ('key', 'value', 'expected_content_type'), + [ + pytest.param('string_key', 'string value', 'text/plain; charset=utf-8', id='string'), + 
pytest.param('dict_key', {'name': 'test', 'value': 42}, 'application/json; charset=utf-8', id='dictionary'), + pytest.param('list_key', [1, 2, 3], 'application/json; charset=utf-8', id='list'), + pytest.param('bytes_key', b'binary data', 'application/octet-stream', id='bytes'), + ], +) +async def test_set_get_value( + kvs_client: MemoryKeyValueStoreClient, + key: str, + value: Any, + expected_content_type: str, +) -> None: + """Test storing and retrieving different types of values with correct content types.""" + # Set value + await kvs_client.set_value(key=key, value=value) + + # Get and verify value + record = await kvs_client.get_value(key=key) + assert record is not None + assert record.key == key + assert record.value == value + assert record.content_type == expected_content_type + + +async def test_get_nonexistent_value(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that attempting to get a non-existent key returns None.""" + record = await kvs_client.get_value(key='nonexistent') + assert record is None + + +async def test_set_value_with_explicit_content_type(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that an explicitly provided content type overrides the automatically inferred one.""" + value = 'This could be XML' + content_type = 'application/xml' + + await kvs_client.set_value(key='xml_key', value=value, content_type=content_type) + + record = await kvs_client.get_value(key='xml_key') + assert record is not None + assert record.value == value + assert record.content_type == content_type + + +async def test_delete_value(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that a stored value can be deleted and is no longer retrievable after deletion.""" + # Set a value + await kvs_client.set_value(key='delete_me', value='to be deleted') + + # Verify it exists + record = await kvs_client.get_value(key='delete_me') + assert record is not None + + # Delete it + await kvs_client.delete_value(key='delete_me') + + # Verify it's gone + record = await kvs_client.get_value(key='delete_me') + assert record is None + + +async def test_delete_nonexistent_value(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that attempting to delete a non-existent key is a no-op and doesn't raise errors.""" + # Should not raise an error + await kvs_client.delete_value(key='nonexistent') + + +async def test_iterate_keys(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that all keys can be iterated over and are returned in sorted order with correct metadata.""" + # Set some values + items = { + 'a_key': 'value A', + 'b_key': 'value B', + 'c_key': 'value C', + 'd_key': 'value D', + } + + for key, value in items.items(): + await kvs_client.set_value(key=key, value=value) + + # Get all keys + metadata_list = [metadata async for metadata in kvs_client.iterate_keys()] + + # Verify keys are returned in sorted order + assert len(metadata_list) == 4 + assert [m.key for m in metadata_list] == sorted(items.keys()) + assert all(isinstance(m, KeyValueStoreRecordMetadata) for m in metadata_list) + + +async def test_iterate_keys_with_exclusive_start_key(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that exclusive_start_key parameter returns only keys after it alphabetically.""" + # Set some values + for key in ['a_key', 'b_key', 'c_key', 'd_key', 'e_key']: + await kvs_client.set_value(key=key, value=f'value for {key}') + + # Get keys starting after 'b_key' + metadata_list = [metadata async for metadata in kvs_client.iterate_keys(exclusive_start_key='b_key')] + + # 
Verify only keys after 'b_key' are returned + assert len(metadata_list) == 3 + assert [m.key for m in metadata_list] == ['c_key', 'd_key', 'e_key'] + + +async def test_iterate_keys_with_limit(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that the limit parameter returns only the specified number of keys.""" + # Set some values + for key in ['a_key', 'b_key', 'c_key', 'd_key', 'e_key']: + await kvs_client.set_value(key=key, value=f'value for {key}') + + # Get first 3 keys + metadata_list = [metadata async for metadata in kvs_client.iterate_keys(limit=3)] + + # Verify only the first 3 keys are returned + assert len(metadata_list) == 3 + assert [m.key for m in metadata_list] == ['a_key', 'b_key', 'c_key'] + + +async def test_drop(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that drop removes the store from cache and clears all data.""" + # Add some values to the store + await kvs_client.set_value(key='test', value='data') + + # Drop the store + await kvs_client.drop() + + # Verify the store is empty + record = await kvs_client.get_value(key='test') + assert record is None + + +async def test_get_public_url(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that get_public_url raises NotImplementedError for the memory implementation.""" + with pytest.raises(NotImplementedError): + await kvs_client.get_public_url(key='any-key') + + +async def test_metadata_updates(kvs_client: MemoryKeyValueStoreClient) -> None: + """Test that read/write operations properly update accessed_at and modified_at timestamps.""" + # Record initial timestamps + initial_created = kvs_client.metadata.created_at + initial_accessed = kvs_client.metadata.accessed_at + initial_modified = kvs_client.metadata.modified_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates accessed_at + await kvs_client.get_value(key='nonexistent') + + # Verify timestamps + assert kvs_client.metadata.created_at == initial_created + assert kvs_client.metadata.accessed_at > initial_accessed + assert kvs_client.metadata.modified_at == initial_modified + + accessed_after_get = kvs_client.metadata.accessed_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates modified_at and accessed_at + await kvs_client.set_value(key='new_key', value='new value') + + # Verify timestamps again + assert kvs_client.metadata.created_at == initial_created + assert kvs_client.metadata.modified_at > initial_modified + assert kvs_client.metadata.accessed_at > accessed_after_get diff --git a/tests/unit/storage_clients/_memory/test_memory_rq_client.py b/tests/unit/storage_clients/_memory/test_memory_rq_client.py new file mode 100644 index 0000000000..028c53ccd2 --- /dev/null +++ b/tests/unit/storage_clients/_memory/test_memory_rq_client.py @@ -0,0 +1,442 @@ +from __future__ import annotations + +import asyncio +from datetime import datetime +from typing import TYPE_CHECKING + +import pytest + +from crawlee import Request +from crawlee.configuration import Configuration +from crawlee.storage_clients import MemoryStorageClient +from crawlee.storage_clients._memory import MemoryRequestQueueClient + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + + +@pytest.fixture +async def rq_client() -> AsyncGenerator[MemoryRequestQueueClient, None]: + """Fixture that provides a fresh memory request queue client for each test.""" + client = await 
MemoryStorageClient().open_request_queue_client(name='test_rq') + yield client + await client.drop() + + +async def test_open_creates_new_rq() -> None: + """Test that open() creates a new request queue with proper metadata and adds it to the cache.""" + client = await MemoryStorageClient().open_request_queue_client(name='new_rq') + + # Verify correct client type and properties + assert isinstance(client, MemoryRequestQueueClient) + assert client.metadata.id is not None + assert client.metadata.name == 'new_rq' + assert isinstance(client.metadata.created_at, datetime) + assert isinstance(client.metadata.accessed_at, datetime) + assert isinstance(client.metadata.modified_at, datetime) + assert client.metadata.handled_request_count == 0 + assert client.metadata.pending_request_count == 0 + assert client.metadata.total_request_count == 0 + assert client.metadata.had_multiple_clients is False + + +async def test_rq_client_purge_on_start() -> None: + """Test that purge_on_start=True clears existing data in the RQ.""" + configuration = Configuration(purge_on_start=True) + + # Create RQ and add data + rq_client1 = await MemoryStorageClient().open_request_queue_client( + name='test_purge_rq', + configuration=configuration, + ) + request = Request.from_url(url='https://example.com/initial') + await rq_client1.add_batch_of_requests([request]) + + # Verify request was added + assert await rq_client1.is_empty() is False + + # Reopen + rq_client2 = await MemoryStorageClient().open_request_queue_client( + name='test_purge_rq', + configuration=configuration, + ) + + # Verify queue was purged + assert await rq_client2.is_empty() is True + + +async def test_open_with_id_and_name() -> None: + """Test that open() can be used with both id and name parameters.""" + client = await MemoryStorageClient().open_request_queue_client( + id='some-id', + name='some-name', + ) + assert client.metadata.id is not None # ID is always auto-generated + assert client.metadata.name == 'some-name' + + +async def test_add_batch_of_requests(rq_client: MemoryRequestQueueClient) -> None: + """Test adding a batch of requests to the queue.""" + requests = [ + Request.from_url(url='https://example.com/1'), + Request.from_url(url='https://example.com/2'), + Request.from_url(url='https://example.com/3'), + ] + + response = await rq_client.add_batch_of_requests(requests) + + # Verify correct response + assert len(response.processed_requests) == 3 + assert len(response.unprocessed_requests) == 0 + + # Verify each request was processed correctly + for i, req in enumerate(requests): + assert response.processed_requests[i].id == req.id + assert response.processed_requests[i].unique_key == req.unique_key + assert response.processed_requests[i].was_already_present is False + assert response.processed_requests[i].was_already_handled is False + + # Verify metadata was updated + assert rq_client.metadata.total_request_count == 3 + assert rq_client.metadata.pending_request_count == 3 + + +async def test_add_batch_of_requests_with_duplicates(rq_client: MemoryRequestQueueClient) -> None: + """Test adding requests with duplicate unique keys.""" + # Add initial requests + initial_requests = [ + Request.from_url(url='https://example.com/1', unique_key='key1'), + Request.from_url(url='https://example.com/2', unique_key='key2'), + ] + await rq_client.add_batch_of_requests(initial_requests) + + # Mark first request as handled + req1 = await rq_client.fetch_next_request() + assert req1 is not None + await rq_client.mark_request_as_handled(req1) + + # Add 
duplicate requests + duplicate_requests = [ + Request.from_url(url='https://example.com/1-dup', unique_key='key1'), # Same as first (handled) + Request.from_url(url='https://example.com/2-dup', unique_key='key2'), # Same as second (not handled) + Request.from_url(url='https://example.com/3', unique_key='key3'), # New request + ] + response = await rq_client.add_batch_of_requests(duplicate_requests) + + # Verify response + assert len(response.processed_requests) == 3 + + # First request should be marked as already handled + assert response.processed_requests[0].was_already_present is True + assert response.processed_requests[0].was_already_handled is True + + # Second request should be marked as already present but not handled + assert response.processed_requests[1].was_already_present is True + assert response.processed_requests[1].was_already_handled is False + + # Third request should be new + assert response.processed_requests[2].was_already_present is False + assert response.processed_requests[2].was_already_handled is False + + +async def test_add_batch_of_requests_to_forefront(rq_client: MemoryRequestQueueClient) -> None: + """Test adding requests to the forefront of the queue.""" + # Add initial requests + initial_requests = [ + Request.from_url(url='https://example.com/1'), + Request.from_url(url='https://example.com/2'), + ] + await rq_client.add_batch_of_requests(initial_requests) + + # Add new requests to forefront + forefront_requests = [ + Request.from_url(url='https://example.com/priority'), + ] + await rq_client.add_batch_of_requests(forefront_requests, forefront=True) + + # The priority request should be fetched first + next_request = await rq_client.fetch_next_request() + assert next_request is not None + assert next_request.url == 'https://example.com/priority' + + +async def test_fetch_next_request(rq_client: MemoryRequestQueueClient) -> None: + """Test fetching the next request from the queue.""" + # Add some requests + requests = [ + Request.from_url(url='https://example.com/1'), + Request.from_url(url='https://example.com/2'), + ] + await rq_client.add_batch_of_requests(requests) + + # Fetch first request + request1 = await rq_client.fetch_next_request() + assert request1 is not None + assert request1.url == 'https://example.com/1' + + # Fetch second request + request2 = await rq_client.fetch_next_request() + assert request2 is not None + assert request2.url == 'https://example.com/2' + + # No more requests + request3 = await rq_client.fetch_next_request() + assert request3 is None + + +async def test_fetch_skips_handled_requests(rq_client: MemoryRequestQueueClient) -> None: + """Test that fetch_next_request skips handled requests.""" + # Add requests + requests = [ + Request.from_url(url='https://example.com/1'), + Request.from_url(url='https://example.com/2'), + ] + await rq_client.add_batch_of_requests(requests) + + # Fetch and handle first request + request1 = await rq_client.fetch_next_request() + assert request1 is not None + await rq_client.mark_request_as_handled(request1) + + # Next fetch should return second request, not the handled one + request = await rq_client.fetch_next_request() + assert request is not None + assert request.url == 'https://example.com/2' + + +async def test_fetch_skips_in_progress_requests(rq_client: MemoryRequestQueueClient) -> None: + """Test that fetch_next_request skips requests that are already in progress.""" + # Add requests + requests = [ + Request.from_url(url='https://example.com/1'), + 
Request.from_url(url='https://example.com/2'), + ] + await rq_client.add_batch_of_requests(requests) + + # Fetch first request (it should be in progress now) + request1 = await rq_client.fetch_next_request() + assert request1 is not None + + # Next fetch should return second request, not the in-progress one + request2 = await rq_client.fetch_next_request() + assert request2 is not None + assert request2.url == 'https://example.com/2' + + # Third fetch should return None as all requests are in progress + request3 = await rq_client.fetch_next_request() + assert request3 is None + + +async def test_get_request(rq_client: MemoryRequestQueueClient) -> None: + """Test getting a request by ID.""" + # Add a request + request = Request.from_url(url='https://example.com/test') + await rq_client.add_batch_of_requests([request]) + + # Get the request by ID + retrieved_request = await rq_client.get_request(request.id) + assert retrieved_request is not None + assert retrieved_request.id == request.id + assert retrieved_request.url == request.url + + # Try to get a non-existent request + nonexistent = await rq_client.get_request('nonexistent-id') + assert nonexistent is None + + +async def test_get_in_progress_request(rq_client: MemoryRequestQueueClient) -> None: + """Test getting an in-progress request by ID.""" + # Add a request + request = Request.from_url(url='https://example.com/test') + await rq_client.add_batch_of_requests([request]) + + # Fetch the request to make it in-progress + fetched = await rq_client.fetch_next_request() + assert fetched is not None + + # Get the request by ID + retrieved = await rq_client.get_request(request.id) + assert retrieved is not None + assert retrieved.id == request.id + assert retrieved.url == request.url + + +async def test_mark_request_as_handled(rq_client: MemoryRequestQueueClient) -> None: + """Test marking a request as handled.""" + # Add a request + request = Request.from_url(url='https://example.com/test') + await rq_client.add_batch_of_requests([request]) + + # Fetch the request to make it in-progress + fetched = await rq_client.fetch_next_request() + assert fetched is not None + + # Mark as handled + result = await rq_client.mark_request_as_handled(fetched) + assert result is not None + assert result.id == fetched.id + assert result.was_already_handled is True + + # Check that metadata was updated + assert rq_client.metadata.handled_request_count == 1 + assert rq_client.metadata.pending_request_count == 0 + + # Try to mark again (should fail as it's no longer in-progress) + result = await rq_client.mark_request_as_handled(fetched) + assert result is None + + +async def test_reclaim_request(rq_client: MemoryRequestQueueClient) -> None: + """Test reclaiming a request back to the queue.""" + # Add a request + request = Request.from_url(url='https://example.com/test') + await rq_client.add_batch_of_requests([request]) + + # Fetch the request to make it in-progress + fetched = await rq_client.fetch_next_request() + assert fetched is not None + + # Reclaim the request + result = await rq_client.reclaim_request(fetched) + assert result is not None + assert result.id == fetched.id + assert result.was_already_handled is False + + # It should be available to fetch again + reclaimed = await rq_client.fetch_next_request() + assert reclaimed is not None + assert reclaimed.id == fetched.id + + +async def test_reclaim_request_to_forefront(rq_client: MemoryRequestQueueClient) -> None: + """Test reclaiming a request to the forefront of the queue.""" + # Add requests + 
requests = [ + Request.from_url(url='https://example.com/1'), + Request.from_url(url='https://example.com/2'), + ] + await rq_client.add_batch_of_requests(requests) + + # Fetch the second request to make it in-progress + await rq_client.fetch_next_request() # Skip the first one + request2 = await rq_client.fetch_next_request() + assert request2 is not None + assert request2.url == 'https://example.com/2' + + # Reclaim the request to forefront + await rq_client.reclaim_request(request2, forefront=True) + + # It should now be the first in the queue + next_request = await rq_client.fetch_next_request() + assert next_request is not None + assert next_request.url == 'https://example.com/2' + + +async def test_is_empty(rq_client: MemoryRequestQueueClient) -> None: + """Test checking if the queue is empty.""" + # Initially empty + assert await rq_client.is_empty() is True + + # Add a request + request = Request.from_url(url='https://example.com/test') + await rq_client.add_batch_of_requests([request]) + + # Not empty now + assert await rq_client.is_empty() is False + + # Fetch and handle + fetched = await rq_client.fetch_next_request() + assert fetched is not None + await rq_client.mark_request_as_handled(fetched) + + # Empty again (all requests handled) + assert await rq_client.is_empty() is True + + +async def test_is_empty_with_in_progress(rq_client: MemoryRequestQueueClient) -> None: + """Test that in-progress requests don't affect is_empty.""" + # Add a request + request = Request.from_url(url='https://example.com/test') + await rq_client.add_batch_of_requests([request]) + + # Fetch but don't handle + await rq_client.fetch_next_request() + + # Queue should still be considered non-empty + # This is because the request hasn't been handled yet + assert await rq_client.is_empty() is False + + +async def test_drop(rq_client: MemoryRequestQueueClient) -> None: + """Test that drop removes the queue from cache and clears all data.""" + # Add a request + request = Request.from_url(url='https://example.com/test') + await rq_client.add_batch_of_requests([request]) + + # Drop the queue + await rq_client.drop() + + # Verify the queue is empty + assert await rq_client.is_empty() is True + + +async def test_metadata_updates(rq_client: MemoryRequestQueueClient) -> None: + """Test that operations properly update metadata timestamps.""" + # Record initial timestamps + initial_created = rq_client.metadata.created_at + initial_accessed = rq_client.metadata.accessed_at + initial_modified = rq_client.metadata.modified_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates modified_at and accessed_at + request = Request.from_url(url='https://example.com/test') + await rq_client.add_batch_of_requests([request]) + + # Verify timestamps + assert rq_client.metadata.created_at == initial_created + assert rq_client.metadata.modified_at > initial_modified + assert rq_client.metadata.accessed_at > initial_accessed + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Record timestamps after add + accessed_after_add = rq_client.metadata.accessed_at + modified_after_add = rq_client.metadata.modified_at + + # Check is_empty (should only update accessed_at) + await rq_client.is_empty() + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Verify only accessed_at changed + assert rq_client.metadata.modified_at == modified_after_add + assert rq_client.metadata.accessed_at > accessed_after_add + + 
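+async def test_fetch_and_handle_all_requests(rq_client: MemoryRequestQueueClient) -> None:
+    """Illustrative sketch (editor addition, not part of the original suite): drain the queue
+    with the typical fetch/handle consumer loop and check the final counters.
+
+    It only combines calls already exercised above (add_batch_of_requests, fetch_next_request,
+    mark_request_as_handled, is_empty), so the assertions follow from the behaviour those
+    tests pin down; the test name and the handled_urls variable are the editor's own.
+    """
+    # Add a few requests
+    requests = [
+        Request.from_url(url='https://example.com/1'),
+        Request.from_url(url='https://example.com/2'),
+        Request.from_url(url='https://example.com/3'),
+    ]
+    await rq_client.add_batch_of_requests(requests)
+
+    # Consume the queue: fetch each pending request and mark it as handled
+    handled_urls = []
+    while (request := await rq_client.fetch_next_request()) is not None:
+        await rq_client.mark_request_as_handled(request)
+        handled_urls.append(request.url)
+
+    # All requests were processed in insertion order and the queue is now empty
+    assert handled_urls == [req.url for req in requests]
+    assert await rq_client.is_empty() is True
+    assert rq_client.metadata.handled_request_count == 3
+    assert rq_client.metadata.pending_request_count == 0
+
+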
+async def test_unique_key_generation(rq_client: MemoryRequestQueueClient) -> None: + """Test that unique keys are auto-generated if not provided.""" + # Add requests without explicit unique keys + requests = [ + Request.from_url(url='https://example.com/1'), + Request.from_url(url='https://example.com/1', always_enqueue=True), + ] + response = await rq_client.add_batch_of_requests(requests) + + # Both should be added as their auto-generated unique keys will differ + assert len(response.processed_requests) == 2 + assert all(not pr.was_already_present for pr in response.processed_requests) + + # Add a request with explicit unique key + request = Request.from_url(url='https://example.com/2', unique_key='explicit-key') + await rq_client.add_batch_of_requests([request]) + + # Add duplicate with same unique key + duplicate = Request.from_url(url='https://example.com/different', unique_key='explicit-key') + duplicate_response = await rq_client.add_batch_of_requests([duplicate]) + + # Should be marked as already present + assert duplicate_response.processed_requests[0].was_already_present is True diff --git a/tests/unit/storage_clients/_memory/test_memory_storage_client.py b/tests/unit/storage_clients/_memory/test_memory_storage_client.py deleted file mode 100644 index 0d043322ae..0000000000 --- a/tests/unit/storage_clients/_memory/test_memory_storage_client.py +++ /dev/null @@ -1,288 +0,0 @@ -# TODO: Update crawlee_storage_dir args once the Pydantic bug is fixed -# https://github.com/apify/crawlee-python/issues/146 - -from __future__ import annotations - -from pathlib import Path - -import pytest - -from crawlee import Request, service_locator -from crawlee._consts import METADATA_FILENAME -from crawlee.configuration import Configuration -from crawlee.storage_clients import MemoryStorageClient -from crawlee.storage_clients.models import BatchRequestsOperationResponse - - -async def test_write_metadata(tmp_path: Path) -> None: - dataset_name = 'test' - dataset_no_metadata_name = 'test-no-metadata' - ms = MemoryStorageClient.from_config( - Configuration( - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - write_metadata=True, - ), - ) - ms_no_metadata = MemoryStorageClient.from_config( - Configuration( - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - write_metadata=False, - ) - ) - datasets_client = ms.datasets() - datasets_no_metadata_client = ms_no_metadata.datasets() - await datasets_client.get_or_create(name=dataset_name) - await datasets_no_metadata_client.get_or_create(name=dataset_no_metadata_name) - assert Path(ms.datasets_directory, dataset_name, METADATA_FILENAME).exists() is True - assert Path(ms_no_metadata.datasets_directory, dataset_no_metadata_name, METADATA_FILENAME).exists() is False - - -@pytest.mark.parametrize( - 'persist_storage', - [ - True, - False, - ], -) -async def test_persist_storage(persist_storage: bool, tmp_path: Path) -> None: # noqa: FBT001 - ms = MemoryStorageClient.from_config( - Configuration( - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - persist_storage=persist_storage, - ) - ) - - # Key value stores - kvs_client = ms.key_value_stores() - kvs_info = await kvs_client.get_or_create(name='kvs') - await ms.key_value_store(kvs_info.id).set_record('test', {'x': 1}, 'application/json') - - path = Path(ms.key_value_stores_directory) / (kvs_info.name or '') / 'test.json' - assert path.exists() is persist_storage - - # Request queues - rq_client = ms.request_queues() - rq_info = await rq_client.get_or_create(name='rq') - - 
request = Request.from_url('http://lorem.com') - await ms.request_queue(rq_info.id).add_request(request) - - path = Path(ms.request_queues_directory) / (rq_info.name or '') / f'{request.id}.json' - assert path.exists() is persist_storage - - # Datasets - ds_client = ms.datasets() - ds_info = await ds_client.get_or_create(name='ds') - - await ms.dataset(ds_info.id).push_items([{'foo': 'bar'}]) - - -def test_persist_storage_set_to_false_via_string_env_var(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: - monkeypatch.setenv('CRAWLEE_PERSIST_STORAGE', 'false') - ms = MemoryStorageClient.from_config( - Configuration(crawlee_storage_dir=str(tmp_path)), # type: ignore[call-arg] - ) - assert ms.persist_storage is False - - -def test_persist_storage_set_to_false_via_numeric_env_var(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: - monkeypatch.setenv('CRAWLEE_PERSIST_STORAGE', '0') - ms = MemoryStorageClient.from_config(Configuration(crawlee_storage_dir=str(tmp_path))) # type: ignore[call-arg] - assert ms.persist_storage is False - - -def test_persist_storage_true_via_constructor_arg(tmp_path: Path) -> None: - ms = MemoryStorageClient.from_config( - Configuration( - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - persist_storage=True, - ) - ) - assert ms.persist_storage is True - - -def test_default_write_metadata_behavior(tmp_path: Path) -> None: - # Default behavior - ms = MemoryStorageClient.from_config( - Configuration(crawlee_storage_dir=str(tmp_path)), # type: ignore[call-arg] - ) - assert ms.write_metadata is True - - -def test_write_metadata_set_to_false_via_env_var(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: - # Test if env var changes write_metadata to False - monkeypatch.setenv('CRAWLEE_WRITE_METADATA', 'false') - ms = MemoryStorageClient.from_config( - Configuration(crawlee_storage_dir=str(tmp_path)), # type: ignore[call-arg] - ) - assert ms.write_metadata is False - - -def test_write_metadata_false_via_constructor_arg_overrides_env_var(tmp_path: Path) -> None: - # Test if constructor arg takes precedence over env var value - ms = MemoryStorageClient.from_config( - Configuration( - write_metadata=False, - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - ) - ) - assert ms.write_metadata is False - - -async def test_purge_datasets(tmp_path: Path) -> None: - ms = MemoryStorageClient.from_config( - Configuration( - write_metadata=True, - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - ) - ) - # Create default and non-default datasets - datasets_client = ms.datasets() - default_dataset_info = await datasets_client.get_or_create(name='default') - non_default_dataset_info = await datasets_client.get_or_create(name='non-default') - - # Check all folders inside datasets directory before and after purge - assert default_dataset_info.name is not None - assert non_default_dataset_info.name is not None - - default_path = Path(ms.datasets_directory, default_dataset_info.name) - non_default_path = Path(ms.datasets_directory, non_default_dataset_info.name) - - assert default_path.exists() is True - assert non_default_path.exists() is True - - await ms._purge_default_storages() - - assert default_path.exists() is False - assert non_default_path.exists() is True - - -async def test_purge_key_value_stores(tmp_path: Path) -> None: - ms = MemoryStorageClient.from_config( - Configuration( - write_metadata=True, - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - ) - ) - - # Create default and non-default key-value stores - 
kvs_client = ms.key_value_stores() - default_kvs_info = await kvs_client.get_or_create(name='default') - non_default_kvs_info = await kvs_client.get_or_create(name='non-default') - default_kvs_client = ms.key_value_store(default_kvs_info.id) - # INPUT.json should be kept - await default_kvs_client.set_record('INPUT', {'abc': 123}, 'application/json') - # test.json should not be kept - await default_kvs_client.set_record('test', {'abc': 123}, 'application/json') - - # Check all folders and files inside kvs directory before and after purge - assert default_kvs_info.name is not None - assert non_default_kvs_info.name is not None - - default_kvs_path = Path(ms.key_value_stores_directory, default_kvs_info.name) - non_default_kvs_path = Path(ms.key_value_stores_directory, non_default_kvs_info.name) - kvs_directory = Path(ms.key_value_stores_directory, 'default') - - assert default_kvs_path.exists() is True - assert non_default_kvs_path.exists() is True - - assert (kvs_directory / 'INPUT.json').exists() is True - assert (kvs_directory / 'test.json').exists() is True - - await ms._purge_default_storages() - - assert default_kvs_path.exists() is True - assert non_default_kvs_path.exists() is True - - assert (kvs_directory / 'INPUT.json').exists() is True - assert (kvs_directory / 'test.json').exists() is False - - -async def test_purge_request_queues(tmp_path: Path) -> None: - ms = MemoryStorageClient.from_config( - Configuration( - write_metadata=True, - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - ) - ) - # Create default and non-default request queues - rq_client = ms.request_queues() - default_rq_info = await rq_client.get_or_create(name='default') - non_default_rq_info = await rq_client.get_or_create(name='non-default') - - # Check all folders inside rq directory before and after purge - assert default_rq_info.name - assert non_default_rq_info.name - - default_rq_path = Path(ms.request_queues_directory, default_rq_info.name) - non_default_rq_path = Path(ms.request_queues_directory, non_default_rq_info.name) - - assert default_rq_path.exists() is True - assert non_default_rq_path.exists() is True - - await ms._purge_default_storages() - - assert default_rq_path.exists() is False - assert non_default_rq_path.exists() is True - - -async def test_not_implemented_method(tmp_path: Path) -> None: - ms = MemoryStorageClient.from_config( - Configuration( - write_metadata=True, - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - ) - ) - ddt = ms.dataset('test') - with pytest.raises(NotImplementedError, match='This method is not supported in memory storage.'): - await ddt.stream_items(item_format='json') - - with pytest.raises(NotImplementedError, match='This method is not supported in memory storage.'): - await ddt.stream_items(item_format='json') - - -async def test_default_storage_path_used(monkeypatch: pytest.MonkeyPatch) -> None: - # Reset the configuration in service locator - service_locator._configuration = None - service_locator._configuration_was_retrieved = False - - # Remove the env var for setting the storage directory - monkeypatch.delenv('CRAWLEE_STORAGE_DIR', raising=False) - - # Initialize the service locator with default configuration - msc = MemoryStorageClient.from_config() - assert msc.storage_dir == './storage' - - -async def test_storage_path_from_env_var_overrides_default(monkeypatch: pytest.MonkeyPatch) -> None: - # We expect the env var to override the default value - monkeypatch.setenv('CRAWLEE_STORAGE_DIR', './env_var_storage_dir') - 
service_locator.set_configuration(Configuration()) - ms = MemoryStorageClient.from_config() - assert ms.storage_dir == './env_var_storage_dir' - - -async def test_parametrized_storage_path_overrides_env_var() -> None: - # We expect the parametrized value to be used - ms = MemoryStorageClient.from_config( - Configuration(crawlee_storage_dir='./parametrized_storage_dir'), # type: ignore[call-arg] - ) - assert ms.storage_dir == './parametrized_storage_dir' - - -async def test_batch_requests_operation_response() -> None: - """Test that `BatchRequestsOperationResponse` creation from example responses.""" - process_request = { - 'requestId': 'EAaArVRs5qV39C9', - 'uniqueKey': 'https://example.com', - 'wasAlreadyHandled': False, - 'wasAlreadyPresent': True, - } - unprocess_request_full = {'uniqueKey': 'https://example2.com', 'method': 'GET', 'url': 'https://example2.com'} - unprocess_request_minimal = {'uniqueKey': 'https://example3.com', 'url': 'https://example3.com'} - BatchRequestsOperationResponse.model_validate( - { - 'processedRequests': [process_request], - 'unprocessedRequests': [unprocess_request_full, unprocess_request_minimal], - } - ) diff --git a/tests/unit/storage_clients/_memory/test_memory_storage_e2e.py b/tests/unit/storage_clients/_memory/test_memory_storage_e2e.py deleted file mode 100644 index c79fa66792..0000000000 --- a/tests/unit/storage_clients/_memory/test_memory_storage_e2e.py +++ /dev/null @@ -1,130 +0,0 @@ -from __future__ import annotations - -from datetime import datetime, timezone -from typing import Callable - -import pytest - -from crawlee import Request, service_locator -from crawlee.storages._key_value_store import KeyValueStore -from crawlee.storages._request_queue import RequestQueue - - -@pytest.mark.parametrize('purge_on_start', [True, False]) -async def test_actor_memory_storage_client_key_value_store_e2e( - monkeypatch: pytest.MonkeyPatch, - purge_on_start: bool, # noqa: FBT001 - prepare_test_env: Callable[[], None], -) -> None: - """This test simulates two clean runs using memory storage. - The second run attempts to access data created by the first one. - We run 2 configurations with different `purge_on_start`.""" - # Configure purging env var - monkeypatch.setenv('CRAWLEE_PURGE_ON_START', f'{int(purge_on_start)}') - # Store old storage client so we have the object reference for comparison - old_client = service_locator.get_storage_client() - - old_default_kvs = await KeyValueStore.open() - old_non_default_kvs = await KeyValueStore.open(name='non-default') - # Create data in default and non-default key-value store - await old_default_kvs.set_value('test', 'default value') - await old_non_default_kvs.set_value('test', 'non-default value') - - # We simulate another clean run, we expect the memory storage to read from the local data directory - # Default storages are purged based on purge_on_start parameter. 
- prepare_test_env() - - # Check if we're using a different memory storage instance - assert old_client is not service_locator.get_storage_client() - default_kvs = await KeyValueStore.open() - assert default_kvs is not old_default_kvs - non_default_kvs = await KeyValueStore.open(name='non-default') - assert non_default_kvs is not old_non_default_kvs - default_value = await default_kvs.get_value('test') - - if purge_on_start: - assert default_value is None - else: - assert default_value == 'default value' - - assert await non_default_kvs.get_value('test') == 'non-default value' - - -@pytest.mark.parametrize('purge_on_start', [True, False]) -async def test_actor_memory_storage_client_request_queue_e2e( - monkeypatch: pytest.MonkeyPatch, - purge_on_start: bool, # noqa: FBT001 - prepare_test_env: Callable[[], None], -) -> None: - """This test simulates two clean runs using memory storage. - The second run attempts to access data created by the first one. - We run 2 configurations with different `purge_on_start`.""" - # Configure purging env var - monkeypatch.setenv('CRAWLEE_PURGE_ON_START', f'{int(purge_on_start)}') - - # Add some requests to the default queue - default_queue = await RequestQueue.open() - for i in range(6): - # [0, 3] <- nothing special - # [1, 4] <- forefront=True - # [2, 5] <- handled=True - request_url = f'http://example.com/{i}' - forefront = i % 3 == 1 - was_handled = i % 3 == 2 - await default_queue.add_request( - Request.from_url( - unique_key=str(i), - url=request_url, - handled_at=datetime.now(timezone.utc) if was_handled else None, - payload=b'test', - ), - forefront=forefront, - ) - - # We simulate another clean run, we expect the memory storage to read from the local data directory - # Default storages are purged based on purge_on_start parameter. 
- prepare_test_env() - - # Add some more requests to the default queue - default_queue = await RequestQueue.open() - for i in range(6, 12): - # [6, 9] <- nothing special - # [7, 10] <- forefront=True - # [8, 11] <- handled=True - request_url = f'http://example.com/{i}' - forefront = i % 3 == 1 - was_handled = i % 3 == 2 - await default_queue.add_request( - Request.from_url( - unique_key=str(i), - url=request_url, - handled_at=datetime.now(timezone.utc) if was_handled else None, - payload=b'test', - ), - forefront=forefront, - ) - - queue_info = await default_queue.get_info() - assert queue_info is not None - - # If the queue was purged between the runs, only the requests from the second run should be present, - # in the right order - if purge_on_start: - assert queue_info.total_request_count == 6 - assert queue_info.handled_request_count == 2 - - expected_pending_request_order = [10, 7, 6, 9] - # If the queue was NOT purged between the runs, all the requests should be in the queue in the right order - else: - assert queue_info.total_request_count == 12 - assert queue_info.handled_request_count == 4 - - expected_pending_request_order = [10, 7, 4, 1, 0, 3, 6, 9] - - actual_requests = list[Request]() - while req := await default_queue.fetch_next_request(): - actual_requests.append(req) - - assert [int(req.unique_key) for req in actual_requests] == expected_pending_request_order - assert [req.url for req in actual_requests] == [f'http://example.com/{req.unique_key}' for req in actual_requests] - assert [req.payload for req in actual_requests] == [b'test' for _ in actual_requests] diff --git a/tests/unit/storage_clients/_memory/test_request_queue_client.py b/tests/unit/storage_clients/_memory/test_request_queue_client.py deleted file mode 100644 index feffacbbd8..0000000000 --- a/tests/unit/storage_clients/_memory/test_request_queue_client.py +++ /dev/null @@ -1,249 +0,0 @@ -from __future__ import annotations - -import asyncio -from datetime import datetime, timezone -from pathlib import Path -from typing import TYPE_CHECKING - -import pytest - -from crawlee import Request -from crawlee._request import RequestState - -if TYPE_CHECKING: - from crawlee.storage_clients import MemoryStorageClient - from crawlee.storage_clients._memory import RequestQueueClient - - -@pytest.fixture -async def request_queue_client(memory_storage_client: MemoryStorageClient) -> RequestQueueClient: - request_queues_client = memory_storage_client.request_queues() - rq_info = await request_queues_client.get_or_create(name='test') - return memory_storage_client.request_queue(rq_info.id) - - -async def test_nonexistent(memory_storage_client: MemoryStorageClient) -> None: - request_queue_client = memory_storage_client.request_queue(id='nonexistent-id') - assert await request_queue_client.get() is None - with pytest.raises(ValueError, match='Request queue with id "nonexistent-id" does not exist.'): - await request_queue_client.update(name='test-update') - await request_queue_client.delete() - - -async def test_get(request_queue_client: RequestQueueClient) -> None: - await asyncio.sleep(0.1) - info = await request_queue_client.get() - assert info is not None - assert info.id == request_queue_client.id - assert info.accessed_at != info.created_at - - -async def test_update(request_queue_client: RequestQueueClient) -> None: - new_rq_name = 'test-update' - request = Request.from_url('https://apify.com') - await request_queue_client.add_request(request) - old_rq_info = await request_queue_client.get() - assert old_rq_info is not 
None - assert old_rq_info.name is not None - old_rq_directory = Path( - request_queue_client._memory_storage_client.request_queues_directory, - old_rq_info.name, - ) - new_rq_directory = Path(request_queue_client._memory_storage_client.request_queues_directory, new_rq_name) - assert (old_rq_directory / 'fvwscO2UJLdr10B.json').exists() is True - assert (new_rq_directory / 'fvwscO2UJLdr10B.json').exists() is False - - await asyncio.sleep(0.1) - updated_rq_info = await request_queue_client.update(name=new_rq_name) - assert (old_rq_directory / 'fvwscO2UJLdr10B.json').exists() is False - assert (new_rq_directory / 'fvwscO2UJLdr10B.json').exists() is True - # Only modified_at and accessed_at should be different - assert old_rq_info.created_at == updated_rq_info.created_at - assert old_rq_info.modified_at != updated_rq_info.modified_at - assert old_rq_info.accessed_at != updated_rq_info.accessed_at - - # Should fail with the same name - with pytest.raises(ValueError, match='Request queue with name "test-update" already exists'): - await request_queue_client.update(name=new_rq_name) - - -async def test_delete(request_queue_client: RequestQueueClient) -> None: - await request_queue_client.add_request(Request.from_url('https://apify.com')) - rq_info = await request_queue_client.get() - assert rq_info is not None - - rq_directory = Path(request_queue_client._memory_storage_client.request_queues_directory, str(rq_info.name)) - assert (rq_directory / 'fvwscO2UJLdr10B.json').exists() is True - - await request_queue_client.delete() - assert (rq_directory / 'fvwscO2UJLdr10B.json').exists() is False - - # Does not crash when called again - await request_queue_client.delete() - - -async def test_list_head(request_queue_client: RequestQueueClient) -> None: - await request_queue_client.add_request(Request.from_url('https://apify.com')) - await request_queue_client.add_request(Request.from_url('https://example.com')) - list_head = await request_queue_client.list_head() - assert len(list_head.items) == 2 - - for item in list_head.items: - assert item.id is not None - - -async def test_request_state_serialization(request_queue_client: RequestQueueClient) -> None: - request = Request.from_url('https://crawlee.dev', payload=b'test') - request.state = RequestState.UNPROCESSED - - await request_queue_client.add_request(request) - - result = await request_queue_client.list_head() - assert len(result.items) == 1 - assert result.items[0] == request - - got_request = await request_queue_client.get_request(request.id) - - assert request == got_request - - -async def test_add_record(request_queue_client: RequestQueueClient) -> None: - processed_request_forefront = await request_queue_client.add_request( - Request.from_url('https://apify.com'), - forefront=True, - ) - processed_request_not_forefront = await request_queue_client.add_request( - Request.from_url('https://example.com'), - forefront=False, - ) - - assert processed_request_forefront.id is not None - assert processed_request_not_forefront.id is not None - assert processed_request_forefront.was_already_handled is False - assert processed_request_not_forefront.was_already_handled is False - - rq_info = await request_queue_client.get() - assert rq_info is not None - assert rq_info.pending_request_count == rq_info.total_request_count == 2 - assert rq_info.handled_request_count == 0 - - -async def test_get_record(request_queue_client: RequestQueueClient) -> None: - request_url = 'https://apify.com' - processed_request = await 
request_queue_client.add_request(Request.from_url(request_url)) - - request = await request_queue_client.get_request(processed_request.id) - assert request is not None - assert request.url == request_url - - # Non-existent id - assert (await request_queue_client.get_request('non-existent id')) is None - - -async def test_update_record(request_queue_client: RequestQueueClient) -> None: - processed_request = await request_queue_client.add_request(Request.from_url('https://apify.com')) - request = await request_queue_client.get_request(processed_request.id) - assert request is not None - - rq_info_before_update = await request_queue_client.get() - assert rq_info_before_update is not None - assert rq_info_before_update.pending_request_count == 1 - assert rq_info_before_update.handled_request_count == 0 - - request.handled_at = datetime.now(timezone.utc) - request_update_info = await request_queue_client.update_request(request) - - assert request_update_info.was_already_handled is False - - rq_info_after_update = await request_queue_client.get() - assert rq_info_after_update is not None - assert rq_info_after_update.pending_request_count == 0 - assert rq_info_after_update.handled_request_count == 1 - - -async def test_delete_record(request_queue_client: RequestQueueClient) -> None: - processed_request_pending = await request_queue_client.add_request( - Request.from_url( - url='https://apify.com', - unique_key='pending', - ), - ) - - processed_request_handled = await request_queue_client.add_request( - Request.from_url( - url='https://apify.com', - unique_key='handled', - handled_at=datetime.now(timezone.utc), - ), - ) - - rq_info_before_delete = await request_queue_client.get() - assert rq_info_before_delete is not None - assert rq_info_before_delete.pending_request_count == 1 - - await request_queue_client.delete_request(processed_request_pending.id) - rq_info_after_first_delete = await request_queue_client.get() - assert rq_info_after_first_delete is not None - assert rq_info_after_first_delete.pending_request_count == 0 - assert rq_info_after_first_delete.handled_request_count == 1 - - await request_queue_client.delete_request(processed_request_handled.id) - rq_info_after_second_delete = await request_queue_client.get() - assert rq_info_after_second_delete is not None - assert rq_info_after_second_delete.pending_request_count == 0 - assert rq_info_after_second_delete.handled_request_count == 0 - - # Does not crash when called again - await request_queue_client.delete_request(processed_request_pending.id) - - -async def test_forefront(request_queue_client: RequestQueueClient) -> None: - # this should create a queue with requests in this order: - # Handled: - # 2, 5, 8 - # Not handled: - # 7, 4, 1, 0, 3, 6 - for i in range(9): - request_url = f'http://example.com/{i}' - forefront = i % 3 == 1 - was_handled = i % 3 == 2 - await request_queue_client.add_request( - Request.from_url( - url=request_url, - unique_key=str(i), - handled_at=datetime.now(timezone.utc) if was_handled else None, - ), - forefront=forefront, - ) - - # Check that the queue head (unhandled items) is in the right order - queue_head = await request_queue_client.list_head() - req_unique_keys = [req.unique_key for req in queue_head.items] - assert req_unique_keys == ['7', '4', '1', '0', '3', '6'] - - # Mark request #1 as handled - await request_queue_client.update_request( - Request.from_url( - url='http://example.com/1', - unique_key='1', - handled_at=datetime.now(timezone.utc), - ), - ) - # Move request #3 to forefront - await 
request_queue_client.update_request( - Request.from_url(url='http://example.com/3', unique_key='3'), - forefront=True, - ) - - # Check that the queue head (unhandled items) is in the right order after the updates - queue_head = await request_queue_client.list_head() - req_unique_keys = [req.unique_key for req in queue_head.items] - assert req_unique_keys == ['3', '7', '4', '0', '6'] - - -async def test_add_duplicate_record(request_queue_client: RequestQueueClient) -> None: - processed_request = await request_queue_client.add_request(Request.from_url('https://apify.com')) - processed_request_duplicate = await request_queue_client.add_request(Request.from_url('https://apify.com')) - - assert processed_request.id == processed_request_duplicate.id - assert processed_request_duplicate.was_already_present is True diff --git a/tests/unit/storage_clients/_memory/test_request_queue_collection_client.py b/tests/unit/storage_clients/_memory/test_request_queue_collection_client.py deleted file mode 100644 index fa10889f83..0000000000 --- a/tests/unit/storage_clients/_memory/test_request_queue_collection_client.py +++ /dev/null @@ -1,42 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -if TYPE_CHECKING: - from crawlee.storage_clients import MemoryStorageClient - from crawlee.storage_clients._memory import RequestQueueCollectionClient - - -@pytest.fixture -def request_queues_client(memory_storage_client: MemoryStorageClient) -> RequestQueueCollectionClient: - return memory_storage_client.request_queues() - - -async def test_get_or_create(request_queues_client: RequestQueueCollectionClient) -> None: - rq_name = 'test' - # A new request queue gets created - rq_info = await request_queues_client.get_or_create(name=rq_name) - assert rq_info.name == rq_name - - # Another get_or_create call returns the same request queue - rq_existing = await request_queues_client.get_or_create(name=rq_name) - assert rq_info.id == rq_existing.id - assert rq_info.name == rq_existing.name - assert rq_info.created_at == rq_existing.created_at - - -async def test_list(request_queues_client: RequestQueueCollectionClient) -> None: - assert (await request_queues_client.list()).count == 0 - rq_info = await request_queues_client.get_or_create(name='dataset') - rq_list = await request_queues_client.list() - assert rq_list.count == 1 - assert rq_list.items[0].name == rq_info.name - - # Test sorting behavior - newer_rq_info = await request_queues_client.get_or_create(name='newer-dataset') - rq_list_sorting = await request_queues_client.list() - assert rq_list_sorting.count == 2 - assert rq_list_sorting.items[0].name == rq_info.name - assert rq_list_sorting.items[1].name == newer_rq_info.name diff --git a/tests/unit/storages/test_dataset.py b/tests/unit/storages/test_dataset.py index f299aee08d..61b7fa4ee2 100644 --- a/tests/unit/storages/test_dataset.py +++ b/tests/unit/storages/test_dataset.py @@ -1,156 +1,451 @@ +# TODO: Update crawlee_storage_dir args once the Pydantic bug is fixed +# https://github.com/apify/crawlee-python/issues/146 + from __future__ import annotations -from datetime import datetime, timezone from typing import TYPE_CHECKING import pytest -from crawlee import service_locator -from crawlee.storage_clients.models import StorageMetadata +from crawlee.configuration import Configuration +from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient from crawlee.storages import Dataset, KeyValueStore if TYPE_CHECKING: from collections.abc import 
AsyncGenerator + from pathlib import Path + + from crawlee.storage_clients import StorageClient + + +@pytest.fixture(params=['memory', 'file_system']) +def storage_client(request: pytest.FixtureRequest) -> StorageClient: + """Parameterized fixture to test with different storage clients.""" + if request.param == 'memory': + return MemoryStorageClient() + + return FileSystemStorageClient() @pytest.fixture -async def dataset() -> AsyncGenerator[Dataset, None]: - dataset = await Dataset.open() +def configuration(tmp_path: Path) -> Configuration: + """Provide a configuration with a temporary storage directory.""" + return Configuration(crawlee_storage_dir=str(tmp_path)) # type: ignore[call-arg] + + +@pytest.fixture +async def dataset( + storage_client: StorageClient, + configuration: Configuration, +) -> AsyncGenerator[Dataset, None]: + """Fixture that provides a dataset instance for each test.""" + Dataset._cache.clear() + + dataset = await Dataset.open( + name='test_dataset', + storage_client=storage_client, + configuration=configuration, + ) + yield dataset await dataset.drop() -async def test_open() -> None: - default_dataset = await Dataset.open() - default_dataset_by_id = await Dataset.open(id=default_dataset.id) +async def test_open_creates_new_dataset( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test that open() creates a new dataset with proper metadata.""" + dataset = await Dataset.open( + name='new_dataset', + storage_client=storage_client, + configuration=configuration, + ) + + # Verify dataset properties + assert dataset.id is not None + assert dataset.name == 'new_dataset' + assert dataset.metadata.item_count == 0 + + await dataset.drop() + - assert default_dataset is default_dataset_by_id +async def test_open_existing_dataset( + dataset: Dataset, + storage_client: StorageClient, +) -> None: + """Test that open() loads an existing dataset correctly.""" + # Open the same dataset again + reopened_dataset = await Dataset.open( + name=dataset.name, + storage_client=storage_client, + ) + + # Verify dataset properties + assert dataset.id == reopened_dataset.id + assert dataset.name == reopened_dataset.name + assert dataset.metadata.item_count == reopened_dataset.metadata.item_count + + # Verify they are the same object (from cache) + assert id(dataset) == id(reopened_dataset) + + +async def test_open_with_id_and_name( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test that open() raises an error when both id and name are provided.""" + with pytest.raises(ValueError, match='Only one of "id" or "name" can be specified'): + await Dataset.open( + id='some-id', + name='some-name', + storage_client=storage_client, + configuration=configuration, + ) + + +async def test_push_data_single_item(dataset: Dataset) -> None: + """Test pushing a single item to the dataset.""" + item = {'key': 'value', 'number': 42} + await dataset.push_data(item) + + # Verify item was stored + result = await dataset.get_data() + assert result.count == 1 + assert result.items[0] == item + + +async def test_push_data_multiple_items(dataset: Dataset) -> None: + """Test pushing multiple items to the dataset.""" + items = [ + {'id': 1, 'name': 'Item 1'}, + {'id': 2, 'name': 'Item 2'}, + {'id': 3, 'name': 'Item 3'}, + ] + await dataset.push_data(items) + + # Verify items were stored + result = await dataset.get_data() + assert result.count == 3 + assert result.items == items + + +async def test_get_data_empty_dataset(dataset: Dataset) -> None: + 
"""Test getting data from an empty dataset returns empty results.""" + result = await dataset.get_data() + + assert result.count == 0 + assert result.total == 0 + assert result.items == [] + + +async def test_get_data_with_pagination(dataset: Dataset) -> None: + """Test getting data with offset and limit parameters for pagination.""" + # Add some items + items = [{'id': i} for i in range(1, 11)] # 10 items + await dataset.push_data(items) + + # Test offset + result = await dataset.get_data(offset=3) + assert result.count == 7 + assert result.offset == 3 + assert result.items[0]['id'] == 4 + + # Test limit + result = await dataset.get_data(limit=5) + assert result.count == 5 + assert result.limit == 5 + assert result.items[-1]['id'] == 5 + + # Test both offset and limit + result = await dataset.get_data(offset=2, limit=3) + assert result.count == 3 + assert result.offset == 2 + assert result.limit == 3 + assert result.items[0]['id'] == 3 + assert result.items[-1]['id'] == 5 + + +async def test_get_data_descending_order(dataset: Dataset) -> None: + """Test getting data in descending order reverses the item order.""" + # Add some items + items = [{'id': i} for i in range(1, 6)] # 5 items + await dataset.push_data(items) + + # Get items in descending order + result = await dataset.get_data(desc=True) + + assert result.desc is True + assert result.items[0]['id'] == 5 + assert result.items[-1]['id'] == 1 + + +async def test_get_data_skip_empty(dataset: Dataset) -> None: + """Test getting data with skip_empty option filters out empty items.""" + # Add some items including an empty one + items = [ + {'id': 1, 'name': 'Item 1'}, + {}, # Empty item + {'id': 3, 'name': 'Item 3'}, + ] + await dataset.push_data(items) + + # Get all items + result = await dataset.get_data() + assert result.count == 3 + + # Get non-empty items + result = await dataset.get_data(skip_empty=True) + assert result.count == 2 + assert all(item != {} for item in result.items) - dataset_name = 'dummy-name' - named_dataset = await Dataset.open(name=dataset_name) - assert default_dataset is not named_dataset - with pytest.raises(RuntimeError, match='Dataset with id "nonexistent-id" does not exist!'): - await Dataset.open(id='nonexistent-id') +async def test_iterate_items(dataset: Dataset) -> None: + """Test iterating over dataset items yields each item in the correct order.""" + # Add some items + items = [{'id': i} for i in range(1, 6)] # 5 items + await dataset.push_data(items) - # Test that when you try to open a dataset by ID and you use a name of an existing dataset, - # it doesn't work - with pytest.raises(RuntimeError, match='Dataset with id "dummy-name" does not exist!'): - await Dataset.open(id='dummy-name') + # Iterate over all items + collected_items = [item async for item in dataset.iterate_items()] + assert len(collected_items) == 5 + assert collected_items[0]['id'] == 1 + assert collected_items[-1]['id'] == 5 -async def test_consistency_accross_two_clients() -> None: - dataset = await Dataset.open(name='my-dataset') - await dataset.push_data({'key': 'value'}) - dataset_by_id = await Dataset.open(id=dataset.id) - await dataset_by_id.push_data({'key2': 'value2'}) +async def test_iterate_items_with_options(dataset: Dataset) -> None: + """Test iterating with offset, limit and desc parameters.""" + # Add some items + items = [{'id': i} for i in range(1, 11)] # 10 items + await dataset.push_data(items) + + # Test with offset and limit + collected_items = [item async for item in dataset.iterate_items(offset=3, limit=3)] + + 
assert len(collected_items) == 3 + assert collected_items[0]['id'] == 4 + assert collected_items[-1]['id'] == 6 + + # Test with descending order + collected_items = [] + async for item in dataset.iterate_items(desc=True, limit=3): + collected_items.append(item) + + assert len(collected_items) == 3 + assert collected_items[0]['id'] == 10 + assert collected_items[-1]['id'] == 8 + + +async def test_list_items(dataset: Dataset) -> None: + """Test that list_items returns all dataset items as a list.""" + # Add some items + items = [{'id': i} for i in range(1, 6)] # 5 items + await dataset.push_data(items) + + # Get all items as a list + collected_items = await dataset.list_items() + + assert len(collected_items) == 5 + assert collected_items[0]['id'] == 1 + assert collected_items[-1]['id'] == 5 + + +async def test_list_items_with_options(dataset: Dataset) -> None: + """Test that list_items respects filtering options.""" + # Add some items + items = [ + {'id': 1, 'name': 'Item 1'}, + {'id': 2, 'name': 'Item 2'}, + {'id': 3}, # Item with missing 'name' field + {}, # Empty item + {'id': 5, 'name': 'Item 5'}, + ] + await dataset.push_data(items) + + # Test with offset and limit + collected_items = await dataset.list_items(offset=1, limit=2) + assert len(collected_items) == 2 + assert collected_items[0]['id'] == 2 + assert collected_items[1]['id'] == 3 + + # Test with descending order - skip empty items to avoid KeyError + collected_items = await dataset.list_items(desc=True, skip_empty=True) + + # Filter items that have an 'id' field + items_with_ids = [item for item in collected_items if 'id' in item] + id_values = [item['id'] for item in items_with_ids] + + # Verify the list is sorted in descending order + assert sorted(id_values, reverse=True) == id_values, f'IDs should be in descending order. 
Got {id_values}' + + # Verify key IDs are present and in the right order + if 5 in id_values and 3 in id_values: + assert id_values.index(5) < id_values.index(3), 'ID 5 should come before ID 3 in descending order' + + # Test with skip_empty + collected_items = await dataset.list_items(skip_empty=True) + assert len(collected_items) == 4 # Should skip the empty item + assert all(item != {} for item in collected_items) + + # Test with fields - manually filter since 'fields' parameter is not supported + # Get all items first + collected_items = await dataset.list_items() + assert len(collected_items) == 5 + + # Manually extract only the 'id' field from each item + filtered_items = [{key: item[key] for key in ['id'] if key in item} for item in collected_items] + + # Verify 'name' field is not present in any item + assert all('name' not in item for item in filtered_items) + + # Test clean functionality manually instead of using the clean parameter + # Get all items + collected_items = await dataset.list_items() + + # Manually filter out empty items as 'clean' would do + clean_items = [item for item in collected_items if item != {}] - assert (await dataset.get_data()).items == [{'key': 'value'}, {'key2': 'value2'}] - assert (await dataset_by_id.get_data()).items == [{'key': 'value'}, {'key2': 'value2'}] + assert len(clean_items) == 4 # Should have 4 non-empty items + assert all(item != {} for item in clean_items) + + +async def test_drop( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test dropping a dataset removes it from cache and clears its data.""" + dataset = await Dataset.open( + name='drop_test', + storage_client=storage_client, + configuration=configuration, + ) + # Add some data + await dataset.push_data({'test': 'data'}) + + # Verify dataset exists in cache + assert dataset._cache_key in Dataset._cache + + # Drop the dataset await dataset.drop() - with pytest.raises(RuntimeError, match='Storage with provided ID was not found'): - await dataset_by_id.drop() - - -async def test_same_references() -> None: - dataset1 = await Dataset.open() - dataset2 = await Dataset.open() - assert dataset1 is dataset2 - - dataset_name = 'non-default' - dataset_named1 = await Dataset.open(name=dataset_name) - dataset_named2 = await Dataset.open(name=dataset_name) - assert dataset_named1 is dataset_named2 - - -async def test_drop() -> None: - dataset1 = await Dataset.open() - await dataset1.drop() - dataset2 = await Dataset.open() - assert dataset1 is not dataset2 - - -async def test_export(dataset: Dataset) -> None: - expected_csv = 'id,test\r\n0,test\r\n1,test\r\n2,test\r\n' - expected_json = [{'id': 0, 'test': 'test'}, {'id': 1, 'test': 'test'}, {'id': 2, 'test': 'test'}] - desired_item_count = 3 - await dataset.push_data([{'id': i, 'test': 'test'} for i in range(desired_item_count)]) - await dataset.export_to(key='dataset-csv', content_type='csv') - await dataset.export_to(key='dataset-json', content_type='json') - kvs = await KeyValueStore.open() - dataset_csv = await kvs.get_value(key='dataset-csv') - dataset_json = await kvs.get_value(key='dataset-json') - assert dataset_csv == expected_csv - assert dataset_json == expected_json - - -async def test_push_data(dataset: Dataset) -> None: - desired_item_count = 2000 - await dataset.push_data([{'id': i} for i in range(desired_item_count)]) - dataset_info = await dataset.get_info() - assert dataset_info is not None - assert dataset_info.item_count == desired_item_count - list_page = await 
dataset.get_data(limit=desired_item_count) - assert list_page.items[0]['id'] == 0 - assert list_page.items[-1]['id'] == desired_item_count - 1 - - -async def test_push_data_empty(dataset: Dataset) -> None: - await dataset.push_data([]) - dataset_info = await dataset.get_info() - assert dataset_info is not None - assert dataset_info.item_count == 0 - - -async def test_push_data_singular(dataset: Dataset) -> None: - await dataset.push_data({'id': 1}) - dataset_info = await dataset.get_info() - assert dataset_info is not None - assert dataset_info.item_count == 1 - list_page = await dataset.get_data() - assert list_page.items[0]['id'] == 1 - - -async def test_get_data(dataset: Dataset) -> None: # We don't test everything, that's done in memory storage tests - desired_item_count = 3 - await dataset.push_data([{'id': i} for i in range(desired_item_count)]) - list_page = await dataset.get_data() - assert list_page.count == desired_item_count - assert list_page.desc is False - assert list_page.offset == 0 - assert list_page.items[0]['id'] == 0 - assert list_page.items[-1]['id'] == desired_item_count - 1 + # Verify dataset was removed from cache + assert dataset._cache_key not in Dataset._cache -async def test_iterate_items(dataset: Dataset) -> None: - desired_item_count = 3 - idx = 0 - await dataset.push_data([{'id': i} for i in range(desired_item_count)]) + # Verify dataset is empty (by creating a new one with the same name) + new_dataset = await Dataset.open( + name='drop_test', + storage_client=storage_client, + configuration=configuration, + ) + + result = await new_dataset.get_data() + assert result.count == 0 + await new_dataset.drop() + + +async def test_export_to_json( + dataset: Dataset, + storage_client: StorageClient, +) -> None: + """Test exporting dataset to JSON format.""" + # Create a key-value store for export + kvs = await KeyValueStore.open( + name='export_kvs', + storage_client=storage_client, + ) + + # Add some items to the dataset + items = [ + {'id': 1, 'name': 'Item 1'}, + {'id': 2, 'name': 'Item 2'}, + {'id': 3, 'name': 'Item 3'}, + ] + await dataset.push_data(items) + + # Export to JSON + await dataset.export_to( + key='dataset_export.json', + content_type='json', + to_kvs_name='export_kvs', + to_kvs_storage_client=storage_client, + ) - async for item in dataset.iterate_items(): - assert item['id'] == idx - idx += 1 + # Retrieve the exported file + record = await kvs.get_value(key='dataset_export.json') + assert record is not None - assert idx == desired_item_count + # Verify content has all the items + assert '"id": 1' in record + assert '"id": 2' in record + assert '"id": 3' in record + await kvs.drop() -async def test_from_storage_object() -> None: - storage_client = service_locator.get_storage_client() - storage_object = StorageMetadata( - id='dummy-id', - name='dummy-name', - accessed_at=datetime.now(timezone.utc), - created_at=datetime.now(timezone.utc), - modified_at=datetime.now(timezone.utc), - extra_attribute='extra', +async def test_export_to_csv( + dataset: Dataset, + storage_client: StorageClient, +) -> None: + """Test exporting dataset to CSV format.""" + # Create a key-value store for export + kvs = await KeyValueStore.open( + name='export_kvs', + storage_client=storage_client, ) - dataset = Dataset.from_storage_object(storage_client, storage_object) + # Add some items to the dataset + items = [ + {'id': 1, 'name': 'Item 1'}, + {'id': 2, 'name': 'Item 2'}, + {'id': 3, 'name': 'Item 3'}, + ] + await dataset.push_data(items) + + # Export to CSV + await 
dataset.export_to( + key='dataset_export.csv', + content_type='csv', + to_kvs_name='export_kvs', + to_kvs_storage_client=storage_client, + ) + + # Retrieve the exported file + record = await kvs.get_value(key='dataset_export.csv') + assert record is not None + + # Verify content has all the items + assert 'id,name' in record + assert '1,Item 1' in record + assert '2,Item 2' in record + assert '3,Item 3' in record + + await kvs.drop() + + +async def test_export_to_invalid_content_type(dataset: Dataset) -> None: + """Test exporting dataset with invalid content type raises error.""" + with pytest.raises(ValueError, match='Unsupported content type'): + await dataset.export_to( + key='invalid_export', + content_type='invalid', # type: ignore[call-overload] # Intentionally invalid content type + ) + + +async def test_large_dataset(dataset: Dataset) -> None: + """Test handling a large dataset with many items.""" + items = [{'id': i, 'value': f'value-{i}'} for i in range(100)] + await dataset.push_data(items) + + # Test that all items are retrieved + result = await dataset.get_data(limit=None) + assert result.count == 100 + assert result.total == 100 - assert dataset.id == storage_object.id - assert dataset.name == storage_object.name - assert dataset.storage_object == storage_object - assert storage_object.model_extra.get('extra_attribute') == 'extra' # type: ignore[union-attr] + # Test pagination with large datasets + result = await dataset.get_data(offset=50, limit=25) + assert result.count == 25 + assert result.offset == 50 + assert result.items[0]['id'] == 50 + assert result.items[-1]['id'] == 74 diff --git a/tests/unit/storages/test_key_value_store.py b/tests/unit/storages/test_key_value_store.py index ea3f4e5f7d..b08f5e0924 100644 --- a/tests/unit/storages/test_key_value_store.py +++ b/tests/unit/storages/test_key_value_store.py @@ -1,229 +1,366 @@ +# TODO: Update crawlee_storage_dir args once the Pydantic bug is fixed +# https://github.com/apify/crawlee-python/issues/146 + from __future__ import annotations -import asyncio -from datetime import datetime, timedelta, timezone -from itertools import chain, repeat -from typing import TYPE_CHECKING, cast -from unittest.mock import patch -from urllib.parse import urlparse +import json +from typing import TYPE_CHECKING import pytest -from crawlee import service_locator -from crawlee.events import EventManager -from crawlee.storage_clients.models import StorageMetadata +from crawlee.configuration import Configuration +from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient from crawlee.storages import KeyValueStore if TYPE_CHECKING: from collections.abc import AsyncGenerator + from pathlib import Path + + from crawlee.storage_clients import StorageClient + - from crawlee._types import JsonSerializable +@pytest.fixture(params=['memory', 'file_system']) +def storage_client(request: pytest.FixtureRequest) -> StorageClient: + """Parameterized fixture to test with different storage clients.""" + if request.param == 'memory': + return MemoryStorageClient() + + return FileSystemStorageClient() @pytest.fixture -async def mock_event_manager() -> AsyncGenerator[EventManager, None]: - async with EventManager(persist_state_interval=timedelta(milliseconds=50)) as event_manager: - with patch('crawlee.service_locator.get_event_manager', return_value=event_manager): - yield event_manager +def configuration(tmp_path: Path) -> Configuration: + """Provide a configuration with a temporary storage directory.""" + return 
Configuration(crawlee_storage_dir=str(tmp_path)) # type: ignore[call-arg] -async def test_open() -> None: - default_key_value_store = await KeyValueStore.open() - default_key_value_store_by_id = await KeyValueStore.open(id=default_key_value_store.id) +@pytest.fixture +async def kvs( + storage_client: StorageClient, + configuration: Configuration, +) -> AsyncGenerator[KeyValueStore, None]: + """Fixture that provides a key-value store instance for each test.""" + KeyValueStore._cache.clear() + + kvs = await KeyValueStore.open( + name='test_kvs', + storage_client=storage_client, + configuration=configuration, + ) + + yield kvs + await kvs.drop() - assert default_key_value_store is default_key_value_store_by_id - key_value_store_name = 'dummy-name' - named_key_value_store = await KeyValueStore.open(name=key_value_store_name) - assert default_key_value_store is not named_key_value_store +async def test_open_creates_new_kvs( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test that open() creates a new key-value store with proper metadata.""" + kvs = await KeyValueStore.open( + name='new_kvs', + storage_client=storage_client, + configuration=configuration, + ) - with pytest.raises(RuntimeError, match='KeyValueStore with id "nonexistent-id" does not exist!'): - await KeyValueStore.open(id='nonexistent-id') + # Verify key-value store properties + assert kvs.id is not None + assert kvs.name == 'new_kvs' - # Test that when you try to open a key-value store by ID and you use a name of an existing key-value store, - # it doesn't work - with pytest.raises(RuntimeError, match='KeyValueStore with id "dummy-name" does not exist!'): - await KeyValueStore.open(id='dummy-name') + await kvs.drop() -async def test_open_save_storage_object() -> None: - default_key_value_store = await KeyValueStore.open() +async def test_open_existing_kvs( + kvs: KeyValueStore, + storage_client: StorageClient, +) -> None: + """Test that open() loads an existing key-value store correctly.""" + # Open the same key-value store again + reopened_kvs = await KeyValueStore.open( + name=kvs.name, + storage_client=storage_client, + ) - assert default_key_value_store.storage_object is not None - assert default_key_value_store.storage_object.id == default_key_value_store.id + # Verify key-value store properties + assert kvs.id == reopened_kvs.id + assert kvs.name == reopened_kvs.name + # Verify they are the same object (from cache) + assert id(kvs) == id(reopened_kvs) -async def test_consistency_accross_two_clients() -> None: - kvs = await KeyValueStore.open(name='my-kvs') - await kvs.set_value('key', 'value') - kvs_by_id = await KeyValueStore.open(id=kvs.id) - await kvs_by_id.set_value('key2', 'value2') +async def test_open_with_id_and_name( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test that open() raises an error when both id and name are provided.""" + with pytest.raises(ValueError, match='Only one of "id" or "name" can be specified'): + await KeyValueStore.open( + id='some-id', + name='some-name', + storage_client=storage_client, + configuration=configuration, + ) - assert (await kvs.get_value('key')) == 'value' - assert (await kvs.get_value('key2')) == 'value2' - assert (await kvs_by_id.get_value('key')) == 'value' - assert (await kvs_by_id.get_value('key2')) == 'value2' +async def test_set_get_value(kvs: KeyValueStore) -> None: + """Test setting and getting a value from the key-value store.""" + # Set a value + test_key = 'test-key' + test_value = {'data': 
'value', 'number': 42} + await kvs.set_value(test_key, test_value) - await kvs.drop() - with pytest.raises(RuntimeError, match='Storage with provided ID was not found'): - await kvs_by_id.drop() + # Get the value + result = await kvs.get_value(test_key) + assert result == test_value -async def test_same_references() -> None: - kvs1 = await KeyValueStore.open() - kvs2 = await KeyValueStore.open() - assert kvs1 is kvs2 +async def test_get_value_nonexistent(kvs: KeyValueStore) -> None: + """Test getting a nonexistent value returns None.""" + result = await kvs.get_value('nonexistent-key') + assert result is None - kvs_name = 'non-default' - kvs_named1 = await KeyValueStore.open(name=kvs_name) - kvs_named2 = await KeyValueStore.open(name=kvs_name) - assert kvs_named1 is kvs_named2 +async def test_get_value_with_default(kvs: KeyValueStore) -> None: + """Test getting a nonexistent value with a default value.""" + default_value = {'default': True} + result = await kvs.get_value('nonexistent-key', default_value=default_value) + assert result == default_value -async def test_drop() -> None: - kvs1 = await KeyValueStore.open() - await kvs1.drop() - kvs2 = await KeyValueStore.open() - assert kvs1 is not kvs2 +async def test_set_value_with_content_type(kvs: KeyValueStore) -> None: + """Test setting a value with a specific content type.""" + test_key = 'test-json' + test_value = {'data': 'value', 'items': [1, 2, 3]} + await kvs.set_value(test_key, test_value, content_type='application/json') -async def test_get_set_value(key_value_store: KeyValueStore) -> None: - await key_value_store.set_value('test-str', 'string') - await key_value_store.set_value('test-int', 123) - await key_value_store.set_value('test-dict', {'abc': '123'}) - str_value = await key_value_store.get_value('test-str') - int_value = await key_value_store.get_value('test-int') - dict_value = await key_value_store.get_value('test-dict') - non_existent_value = await key_value_store.get_value('test-non-existent') - assert str_value == 'string' - assert int_value == 123 - assert dict_value['abc'] == '123' - assert non_existent_value is None + # Verify the value is retrievable + result = await kvs.get_value(test_key) + assert result == test_value -async def test_for_each_key(key_value_store: KeyValueStore) -> None: - keys = [item.key async for item in key_value_store.iterate_keys()] - assert len(keys) == 0 +async def test_delete_value(kvs: KeyValueStore) -> None: + """Test deleting a value from the key-value store.""" + # Set a value first + test_key = 'delete-me' + test_value = 'value to delete' + await kvs.set_value(test_key, test_value) - for i in range(2001): - await key_value_store.set_value(str(i).zfill(4), i) - index = 0 - async for item in key_value_store.iterate_keys(): - assert item.key == str(index).zfill(4) - index += 1 - assert index == 2001 + # Verify value exists + assert await kvs.get_value(test_key) == test_value + # Delete the value + await kvs.delete_value(test_key) -async def test_static_get_set_value(key_value_store: KeyValueStore) -> None: - await key_value_store.set_value('test-static', 'static') - value = await key_value_store.get_value('test-static') - assert value == 'static' + # Verify value is gone + assert await kvs.get_value(test_key) is None -async def test_get_public_url_raises_for_non_existing_key(key_value_store: KeyValueStore) -> None: - with pytest.raises(ValueError, match='was not found'): - await key_value_store.get_public_url('i-do-not-exist') +async def test_list_keys_empty_kvs(kvs: KeyValueStore) -> 
None: + """Test listing keys from an empty key-value store.""" + keys = await kvs.list_keys() + assert len(keys) == 0 -async def test_get_public_url(key_value_store: KeyValueStore) -> None: - await key_value_store.set_value('test-static', 'static') - public_url = await key_value_store.get_public_url('test-static') +async def test_list_keys(kvs: KeyValueStore) -> None: + """Test listing keys from a key-value store with items.""" + # Add some items + await kvs.set_value('key1', 'value1') + await kvs.set_value('key2', 'value2') + await kvs.set_value('key3', 'value3') + + # List keys + keys = await kvs.list_keys() + + # Verify keys + assert len(keys) == 3 + key_names = [k.key for k in keys] + assert 'key1' in key_names + assert 'key2' in key_names + assert 'key3' in key_names + + +async def test_list_keys_with_limit(kvs: KeyValueStore) -> None: + """Test listing keys with a limit parameter.""" + # Add some items + for i in range(10): + await kvs.set_value(f'key{i}', f'value{i}') + + # List with limit + keys = await kvs.list_keys(limit=5) + assert len(keys) == 5 + + +async def test_list_keys_with_exclusive_start_key(kvs: KeyValueStore) -> None: + """Test listing keys with an exclusive start key.""" + # Add some items in a known order + await kvs.set_value('key1', 'value1') + await kvs.set_value('key2', 'value2') + await kvs.set_value('key3', 'value3') + await kvs.set_value('key4', 'value4') + await kvs.set_value('key5', 'value5') + + # Get all keys first to determine their order + all_keys = await kvs.list_keys() + all_key_names = [k.key for k in all_keys] + + if len(all_key_names) >= 3: + # Start from the second key + start_key = all_key_names[1] + keys = await kvs.list_keys(exclusive_start_key=start_key) + + # We should get all keys after the start key + expected_count = len(all_key_names) - all_key_names.index(start_key) - 1 + assert len(keys) == expected_count + + # First key should be the one after start_key + first_returned_key = keys[0].key + assert first_returned_key != start_key + assert all_key_names.index(first_returned_key) > all_key_names.index(start_key) + + +async def test_iterate_keys(kvs: KeyValueStore) -> None: + """Test iterating over keys in the key-value store.""" + # Add some items + await kvs.set_value('key1', 'value1') + await kvs.set_value('key2', 'value2') + await kvs.set_value('key3', 'value3') + + collected_keys = [key async for key in kvs.iterate_keys()] + + # Verify iteration result + assert len(collected_keys) == 3 + key_names = [k.key for k in collected_keys] + assert 'key1' in key_names + assert 'key2' in key_names + assert 'key3' in key_names + + +async def test_iterate_keys_with_limit(kvs: KeyValueStore) -> None: + """Test iterating over keys with a limit parameter.""" + # Add some items + for i in range(10): + await kvs.set_value(f'key{i}', f'value{i}') + + collected_keys = [key async for key in kvs.iterate_keys(limit=5)] + + # Verify iteration result + assert len(collected_keys) == 5 + + +async def test_drop( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test dropping a key-value store removes it from cache and clears its data.""" + kvs = await KeyValueStore.open( + name='drop_test', + storage_client=storage_client, + configuration=configuration, + ) - url = urlparse(public_url) - path = url.netloc if url.netloc else url.path + # Add some data + await kvs.set_value('test', 'data') - with open(path) as f: # noqa: ASYNC230 - content = await asyncio.to_thread(f.read) - assert content == 'static' + # Verify key-value store 
exists in cache + assert kvs._cache_key in KeyValueStore._cache + # Drop the key-value store + await kvs.drop() -async def test_get_auto_saved_value_default_value(key_value_store: KeyValueStore) -> None: - default_value: dict[str, JsonSerializable] = {'hello': 'world'} - value = await key_value_store.get_auto_saved_value('state', default_value) - assert value == default_value + # Verify key-value store was removed from cache + assert kvs._cache_key not in KeyValueStore._cache + # Verify key-value store is empty (by creating a new one with the same name) + new_kvs = await KeyValueStore.open( + name='drop_test', + storage_client=storage_client, + configuration=configuration, + ) -async def test_get_auto_saved_value_cache_value(key_value_store: KeyValueStore) -> None: - default_value: dict[str, JsonSerializable] = {'hello': 'world'} - key_name = 'state' + # Attempt to get a previously stored value + result = await new_kvs.get_value('test') + assert result is None + await new_kvs.drop() - value = await key_value_store.get_auto_saved_value(key_name, default_value) - value['hello'] = 'new_world' - value_one = await key_value_store.get_auto_saved_value(key_name) - assert value_one == {'hello': 'new_world'} - value_one['hello'] = ['new_world'] - value_two = await key_value_store.get_auto_saved_value(key_name) - assert value_two == {'hello': ['new_world']} +async def test_complex_data_types(kvs: KeyValueStore) -> None: + """Test storing and retrieving complex data types.""" + # Test nested dictionaries + nested_dict = { + 'level1': { + 'level2': { + 'level3': 'deep value', + 'numbers': [1, 2, 3], + }, + }, + 'array': [{'a': 1}, {'b': 2}], + } + await kvs.set_value('nested', nested_dict) + result = await kvs.get_value('nested') + assert result == nested_dict + # Test lists + test_list = [1, 'string', True, None, {'key': 'value'}] + await kvs.set_value('list', test_list) + result = await kvs.get_value('list') + assert result == test_list -async def test_get_auto_saved_value_auto_save(key_value_store: KeyValueStore, mock_event_manager: EventManager) -> None: # noqa: ARG001 - # This is not a realtime system and timing constrains can be hard to enforce. - # For the test to avoid flakiness it needs some time tolerance. 
- autosave_deadline_time = 1 - autosave_check_period = 0.01 - async def autosaved_within_deadline(key: str, expected_value: dict[str, str]) -> bool: - """Check if the `key_value_store` of `key` has expected value within `autosave_deadline_time` seconds.""" - deadline = datetime.now(tz=timezone.utc) + timedelta(seconds=autosave_deadline_time) - while datetime.now(tz=timezone.utc) < deadline: - await asyncio.sleep(autosave_check_period) - if await key_value_store.get_value(key) == expected_value: - return True - return False +async def test_string_data(kvs: KeyValueStore) -> None: + """Test storing and retrieving string data.""" + # Plain string + await kvs.set_value('string', 'simple string') + result = await kvs.get_value('string') + assert result == 'simple string' - default_value: dict[str, JsonSerializable] = {'hello': 'world'} - key_name = 'state' - value = await key_value_store.get_auto_saved_value(key_name, default_value) - assert await autosaved_within_deadline(key=key_name, expected_value={'hello': 'world'}) + # JSON string + json_string = json.dumps({'key': 'value'}) + await kvs.set_value('json_string', json_string) + result = await kvs.get_value('json_string') + assert result == json_string - value['hello'] = 'new_world' - assert await autosaved_within_deadline(key=key_name, expected_value={'hello': 'new_world'}) +async def test_key_with_special_characters(kvs: KeyValueStore) -> None: + """Test storing and retrieving values with keys containing special characters.""" + # Key with spaces, slashes, and special characters + special_key = 'key with spaces/and/slashes!@#$%^&*()' + test_value = 'Special key value' -async def test_get_auto_saved_value_auto_save_race_conditions(key_value_store: KeyValueStore) -> None: - """Two parallel functions increment global variable obtained by `get_auto_saved_value`. + # Store the value with the special key + await kvs.set_value(key=special_key, value=test_value) - Result should be incremented by 2. - Method `get_auto_saved_value` must be implemented in a way that prevents race conditions in such scenario. - Test creates situation where first `get_auto_saved_value` call to kvs gets delayed. 
Such situation can happen - and unless handled, it can cause race condition in getting the state value.""" - await key_value_store.set_value('state', {'counter': 0}) + # Retrieve the value and verify it matches + result = await kvs.get_value(key=special_key) + assert result is not None + assert result == test_value - sleep_time_iterator = chain(iter([0.5]), repeat(0)) + # Make sure the key is properly listed + keys = await kvs.list_keys() + key_names = [k.key for k in keys] + assert special_key in key_names - async def delayed_get_value(key: str, default_value: None = None) -> None: - await asyncio.sleep(next(sleep_time_iterator)) - return await KeyValueStore.get_value(key_value_store, key=key, default_value=default_value) + # Test key deletion + await kvs.delete_value(key=special_key) + assert await kvs.get_value(key=special_key) is None - async def increment_counter() -> None: - state = cast('dict[str, int]', await key_value_store.get_auto_saved_value('state')) - state['counter'] += 1 - with patch.object(key_value_store, 'get_value', delayed_get_value): - tasks = [asyncio.create_task(increment_counter()), asyncio.create_task(increment_counter())] - await asyncio.gather(*tasks) +async def test_data_persistence_on_reopen(configuration: Configuration) -> None: + """Test that data persists when reopening a KeyValueStore.""" + kvs1 = await KeyValueStore.open(configuration=configuration) - assert (await key_value_store.get_auto_saved_value('state'))['counter'] == 2 + await kvs1.set_value('key_123', 'value_123') + result1 = await kvs1.get_value('key_123') + assert result1 == 'value_123' -async def test_from_storage_object() -> None: - storage_client = service_locator.get_storage_client() + kvs2 = await KeyValueStore.open(configuration=configuration) - storage_object = StorageMetadata( - id='dummy-id', - name='dummy-name', - accessed_at=datetime.now(timezone.utc), - created_at=datetime.now(timezone.utc), - modified_at=datetime.now(timezone.utc), - extra_attribute='extra', - ) + result2 = await kvs2.get_value('key_123') + assert result2 == 'value_123' + assert await kvs1.list_keys() == await kvs2.list_keys() - key_value_store = KeyValueStore.from_storage_object(storage_client, storage_object) + await kvs2.set_value('key_456', 'value_456') - assert key_value_store.id == storage_object.id - assert key_value_store.name == storage_object.name - assert key_value_store.storage_object == storage_object - assert storage_object.model_extra.get('extra_attribute') == 'extra' # type: ignore[union-attr] + result1 = await kvs1.get_value('key_456') + assert result1 == 'value_456' diff --git a/tests/unit/storages/test_request_manager_tandem.py b/tests/unit/storages/test_request_manager_tandem.py index d08ab57dc1..060484a136 100644 --- a/tests/unit/storages/test_request_manager_tandem.py +++ b/tests/unit/storages/test_request_manager_tandem.py @@ -54,7 +54,7 @@ async def test_basic_functionality(test_input: TestInput) -> None: request_queue = await RequestQueue.open() if test_input.request_manager_items: - await request_queue.add_requests_batched(test_input.request_manager_items) + await request_queue.add_requests(test_input.request_manager_items) mock_request_loader = create_autospec(RequestLoader, instance=True, spec_set=True) mock_request_loader.fetch_next_request.side_effect = lambda: test_input.request_loader_items.pop(0) diff --git a/tests/unit/storages/test_request_queue.py b/tests/unit/storages/test_request_queue.py index cddba8ef99..81c588f95e 100644 --- a/tests/unit/storages/test_request_queue.py +++ 
b/tests/unit/storages/test_request_queue.py @@ -1,367 +1,492 @@ +# TODO: Update crawlee_storage_dir args once the Pydantic bug is fixed +# https://github.com/apify/crawlee-python/issues/146 + from __future__ import annotations import asyncio -from datetime import datetime, timedelta, timezone -from itertools import count from typing import TYPE_CHECKING -from unittest.mock import AsyncMock, MagicMock import pytest -from pydantic import ValidationError - -from crawlee import Request, service_locator -from crawlee._request import RequestState -from crawlee.storage_clients import MemoryStorageClient, StorageClient -from crawlee.storage_clients._memory import RequestQueueClient -from crawlee.storage_clients.models import ( - BatchRequestsOperationResponse, - StorageMetadata, - UnprocessedRequest, -) + +from crawlee import Request +from crawlee.configuration import Configuration +from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient, StorageClient from crawlee.storages import RequestQueue if TYPE_CHECKING: - from collections.abc import AsyncGenerator, Sequence + from collections.abc import AsyncGenerator + from pathlib import Path + + +@pytest.fixture(params=['memory', 'file_system']) +def storage_client(request: pytest.FixtureRequest) -> StorageClient: + """Parameterized fixture to test with different storage clients.""" + if request.param == 'memory': + return MemoryStorageClient() + + return FileSystemStorageClient() + + +@pytest.fixture +def configuration(tmp_path: Path) -> Configuration: + """Provide a configuration with a temporary storage directory.""" + return Configuration(crawlee_storage_dir=str(tmp_path)) # type: ignore[call-arg] @pytest.fixture -async def request_queue() -> AsyncGenerator[RequestQueue, None]: - rq = await RequestQueue.open() +async def rq( + storage_client: StorageClient, + configuration: Configuration, +) -> AsyncGenerator[RequestQueue, None]: + """Fixture that provides a request queue instance for each test.""" + RequestQueue._cache.clear() + + rq = await RequestQueue.open( + name='test_request_queue', + storage_client=storage_client, + configuration=configuration, + ) + yield rq await rq.drop() -async def test_open() -> None: - default_request_queue = await RequestQueue.open() - default_request_queue_by_id = await RequestQueue.open(id=default_request_queue.id) +async def test_open_creates_new_rq( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test that open() creates a new request queue with proper metadata.""" + rq = await RequestQueue.open( + name='new_request_queue', + storage_client=storage_client, + configuration=configuration, + ) + + # Verify request queue properties + assert rq.id is not None + assert rq.name == 'new_request_queue' + assert rq.metadata.pending_request_count == 0 + assert rq.metadata.handled_request_count == 0 + assert rq.metadata.total_request_count == 0 + + await rq.drop() + + +async def test_open_existing_rq( + rq: RequestQueue, + storage_client: StorageClient, +) -> None: + """Test that open() loads an existing request queue correctly.""" + # Open the same request queue again + reopened_rq = await RequestQueue.open( + name=rq.name, + storage_client=storage_client, + ) - assert default_request_queue is default_request_queue_by_id + # Verify request queue properties + assert rq.id == reopened_rq.id + assert rq.name == reopened_rq.name - request_queue_name = 'dummy-name' - named_request_queue = await RequestQueue.open(name=request_queue_name) - assert default_request_queue is 
not named_request_queue + # Verify they are the same object (from cache) + assert id(rq) == id(reopened_rq) - with pytest.raises(RuntimeError, match='RequestQueue with id "nonexistent-id" does not exist!'): - await RequestQueue.open(id='nonexistent-id') - # Test that when you try to open a request queue by ID and you use a name of an existing request queue, - # it doesn't work - with pytest.raises(RuntimeError, match='RequestQueue with id "dummy-name" does not exist!'): - await RequestQueue.open(id='dummy-name') +async def test_open_with_id_and_name( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test that open() raises an error when both id and name are provided.""" + with pytest.raises(ValueError, match='Only one of "id" or "name" can be specified'): + await RequestQueue.open( + id='some-id', + name='some-name', + storage_client=storage_client, + configuration=configuration, + ) -async def test_consistency_accross_two_clients() -> None: - request_apify = Request.from_url('https://apify.com') - request_crawlee = Request.from_url('https://crawlee.dev') +async def test_add_request_string_url(rq: RequestQueue) -> None: + """Test adding a request with a string URL.""" + # Add a request with a string URL + url = 'https://example.com' + result = await rq.add_request(url) - rq = await RequestQueue.open(name='my-rq') - await rq.add_request(request_apify) + # Verify request was added + assert result.id is not None + assert result.unique_key is not None + assert result.was_already_present is False + assert result.was_already_handled is False - rq_by_id = await RequestQueue.open(id=rq.id) - await rq_by_id.add_request(request_crawlee) + # Verify the queue stats were updated + assert rq.metadata.total_request_count == 1 + assert rq.metadata.pending_request_count == 1 - assert await rq.get_total_count() == 2 - assert await rq_by_id.get_total_count() == 2 - assert await rq.fetch_next_request() == request_apify - assert await rq_by_id.fetch_next_request() == request_crawlee +async def test_add_request_object(rq: RequestQueue) -> None: + """Test adding a request object.""" + # Create and add a request object + request = Request.from_url(url='https://example.com', user_data={'key': 'value'}) + result = await rq.add_request(request) - await rq.drop() - with pytest.raises(RuntimeError, match='Storage with provided ID was not found'): - await rq_by_id.drop() + # Verify request was added + assert result.id is not None + assert result.unique_key is not None + assert result.was_already_present is False + assert result.was_already_handled is False + # Verify the queue stats were updated + assert rq.metadata.total_request_count == 1 + assert rq.metadata.pending_request_count == 1 -async def test_same_references() -> None: - rq1 = await RequestQueue.open() - rq2 = await RequestQueue.open() - assert rq1 is rq2 - rq_name = 'non-default' - rq_named1 = await RequestQueue.open(name=rq_name) - rq_named2 = await RequestQueue.open(name=rq_name) - assert rq_named1 is rq_named2 +async def test_add_duplicate_request(rq: RequestQueue) -> None: + """Test adding a duplicate request to the queue.""" + # Add a request + url = 'https://example.com' + first_result = await rq.add_request(url) + # Add the same request again + second_result = await rq.add_request(url) -async def test_drop() -> None: - rq1 = await RequestQueue.open() - await rq1.drop() - rq2 = await RequestQueue.open() - assert rq1 is not rq2 + # Verify the second request was detected as duplicate + assert 
second_result.was_already_present is True + assert second_result.unique_key == first_result.unique_key + # Verify the queue stats weren't incremented twice + assert rq.metadata.total_request_count == 1 + assert rq.metadata.pending_request_count == 1 -async def test_get_request(request_queue: RequestQueue) -> None: - request = Request.from_url('https://example.com') - processed_request = await request_queue.add_request(request) - assert request.id == processed_request.id - request_2 = await request_queue.get_request(request.id) - assert request_2 is not None - assert request == request_2 +async def test_add_requests_batch(rq: RequestQueue) -> None: + """Test adding multiple requests in a batch.""" + # Create a batch of requests + urls = [ + 'https://example.com/page1', + 'https://example.com/page2', + 'https://example.com/page3', + ] -async def test_add_fetch_handle_request(request_queue: RequestQueue) -> None: - request = Request.from_url('https://example.com') - assert await request_queue.is_empty() is True - add_request_info = await request_queue.add_request(request) + # Add the requests + await rq.add_requests(urls) - assert add_request_info.was_already_present is False - assert add_request_info.was_already_handled is False - assert await request_queue.is_empty() is False + # Wait for all background tasks to complete + await asyncio.sleep(0.1) - # Fetch the request - next_request = await request_queue.fetch_next_request() - assert next_request is not None + # Verify the queue stats + assert rq.metadata.total_request_count == 3 + assert rq.metadata.pending_request_count == 3 - # Mark it as handled - next_request.handled_at = datetime.now(timezone.utc) - processed_request = await request_queue.mark_request_as_handled(next_request) - assert processed_request is not None - assert processed_request.id == request.id - assert processed_request.unique_key == request.unique_key - assert await request_queue.is_finished() is True +async def test_add_requests_batch_with_forefront(rq: RequestQueue) -> None: + """Test adding multiple requests in a batch with forefront option.""" + # Add some initial requests + await rq.add_request('https://example.com/page1') + await rq.add_request('https://example.com/page2') + # Add a batch of priority requests at the forefront -async def test_reclaim_request(request_queue: RequestQueue) -> None: - request = Request.from_url('https://example.com') - await request_queue.add_request(request) + await rq.add_requests( + [ + 'https://example.com/priority1', + 'https://example.com/priority2', + 'https://example.com/priority3', + ], + forefront=True, + ) - # Fetch the request - next_request = await request_queue.fetch_next_request() - assert next_request is not None - assert next_request.unique_key == request.url - - # Reclaim - await request_queue.reclaim_request(next_request) - # Try to fetch again after a few secs - await asyncio.sleep(4) # 3 seconds is the consistency delay in request queue - next_again = await request_queue.fetch_next_request() - - assert next_again is not None - assert next_again.id == request.id - assert next_again.unique_key == request.unique_key - - -@pytest.mark.parametrize( - 'requests', - [ - [Request.from_url('https://apify.com')], - ['https://crawlee.dev'], - [Request.from_url(f'https://example.com/{i}') for i in range(10)], - [f'https://example.com/{i}' for i in range(15)], - ], - ids=['single-request', 'single-url', 'multiple-requests', 'multiple-urls'], -) -async def test_add_batched_requests( - request_queue: RequestQueue, - requests: 
Sequence[str | Request], -) -> None: - request_count = len(requests) + # Wait for all background tasks to complete + await asyncio.sleep(0.1) - # Add the requests to the RQ in batches - await request_queue.add_requests_batched(requests, wait_for_all_requests_to_be_added=True) + # Fetch requests - they should come out in priority order first + next_request1 = await rq.fetch_next_request() + assert next_request1 is not None + assert next_request1.url.startswith('https://example.com/priority') - # Ensure the batch was processed correctly - assert await request_queue.get_total_count() == request_count + next_request2 = await rq.fetch_next_request() + assert next_request2 is not None + assert next_request2.url.startswith('https://example.com/priority') - # Fetch and validate each request in the queue - for original_request in requests: - next_request = await request_queue.fetch_next_request() - assert next_request is not None + next_request3 = await rq.fetch_next_request() + assert next_request3 is not None + assert next_request3.url.startswith('https://example.com/priority') - expected_url = original_request if isinstance(original_request, str) else original_request.url - assert next_request.url == expected_url + # Now we should get the original requests + next_request4 = await rq.fetch_next_request() + assert next_request4 is not None + assert next_request4.url == 'https://example.com/page1' - # Confirm the queue is empty after processing all requests - assert await request_queue.is_empty() is True + next_request5 = await rq.fetch_next_request() + assert next_request5 is not None + assert next_request5.url == 'https://example.com/page2' + # Queue should be empty now + next_request6 = await rq.fetch_next_request() + assert next_request6 is None -async def test_invalid_user_data_serialization() -> None: - with pytest.raises(ValidationError): - Request.from_url( - 'https://crawlee.dev', - user_data={ - 'foo': datetime(year=2020, month=7, day=4, tzinfo=timezone.utc), - 'bar': {datetime(year=2020, month=4, day=7, tzinfo=timezone.utc)}, - }, - ) +async def test_add_requests_mixed_forefront(rq: RequestQueue) -> None: + """Test the ordering when adding requests with mixed forefront values.""" + # Add normal requests + await rq.add_request('https://example.com/normal1') + await rq.add_request('https://example.com/normal2') -async def test_user_data_serialization(request_queue: RequestQueue) -> None: - request = Request.from_url( - 'https://crawlee.dev', - user_data={ - 'hello': 'world', - 'foo': 42, - }, + # Add a batch with forefront=True + await rq.add_requests( + ['https://example.com/priority1', 'https://example.com/priority2'], + forefront=True, ) - await request_queue.add_request(request) + # Add another normal request + await rq.add_request('https://example.com/normal3') - dequeued_request = await request_queue.fetch_next_request() - assert dequeued_request is not None + # Add another priority request + await rq.add_request('https://example.com/priority3', forefront=True) - assert dequeued_request.user_data['hello'] == 'world' - assert dequeued_request.user_data['foo'] == 42 + # Wait for background tasks + await asyncio.sleep(0.1) + # The expected order should be: + # 1. priority3 (most recent forefront) + # 2. priority1 (from batch, forefront) + # 3. priority2 (from batch, forefront) + # 4. normal1 (oldest normal) + # 5. normal2 + # 6. 
normal3 (newest normal)

-async def test_complex_user_data_serialization(request_queue: RequestQueue) -> None:
-    request = Request.from_url('https://crawlee.dev')
-    request.user_data['hello'] = 'world'
-    request.user_data['foo'] = 42
-    request.crawlee_data.max_retries = 1
-    request.crawlee_data.state = RequestState.ERROR_HANDLER
+    requests = []
+    while True:
+        req = await rq.fetch_next_request()
+        if req is None:
+            break
+        requests.append(req)
+        await rq.mark_request_as_handled(req)

-    await request_queue.add_request(request)
+    assert len(requests) == 6
+    assert requests[0].url == 'https://example.com/priority3'

-    dequeued_request = await request_queue.fetch_next_request()
-    assert dequeued_request is not None
+    # The next two should be from the forefront batch (exact order within batch may vary)
+    batch_urls = {requests[1].url, requests[2].url}
+    assert 'https://example.com/priority1' in batch_urls
+    assert 'https://example.com/priority2' in batch_urls

-    data = dequeued_request.model_dump(by_alias=True)
-    assert data['userData']['hello'] == 'world'
-    assert data['userData']['foo'] == 42
-    assert data['userData']['__crawlee'] == {
-        'maxRetries': 1,
-        'state': RequestState.ERROR_HANDLER,
-    }
+    # Then the normal requests in order
+    assert requests[3].url == 'https://example.com/normal1'
+    assert requests[4].url == 'https://example.com/normal2'
+    assert requests[5].url == 'https://example.com/normal3'


-async def test_deduplication_of_requests_with_custom_unique_key() -> None:
-    with pytest.raises(ValueError, match='`always_enqueue` cannot be used with a custom `unique_key`'):
-        Request.from_url('https://apify.com', unique_key='apify', always_enqueue=True)
+async def test_add_requests_with_forefront(rq: RequestQueue) -> None:
+    """Test adding requests to the front of the queue."""
+    # Add some initial requests
+    await rq.add_request('https://example.com/page1')
+    await rq.add_request('https://example.com/page2')

+    # Add a priority request at the forefront
+    await rq.add_request('https://example.com/priority', forefront=True)

-async def test_deduplication_of_requests_with_invalid_custom_unique_key() -> None:
-    request_1 = Request.from_url('https://apify.com', always_enqueue=True)
-    request_2 = Request.from_url('https://apify.com', always_enqueue=True)
+    # Fetch the next request - should be the priority one
+    next_request = await rq.fetch_next_request()
+    assert next_request is not None
+    assert next_request.url == 'https://example.com/priority'

-    rq = await RequestQueue.open(name='my-rq')
-    await rq.add_request(request_1)
-    await rq.add_request(request_2)

-    assert await rq.get_total_count() == 2
+async def test_fetch_next_request_and_mark_handled(rq: RequestQueue) -> None:
+    """Test fetching and marking requests as handled."""
+    # Add some requests
+    await rq.add_request('https://example.com/page1')
+    await rq.add_request('https://example.com/page2')

-    assert await rq.fetch_next_request() == request_1
-    assert await rq.fetch_next_request() == request_2
+    # Fetch first request
+    request1 = await rq.fetch_next_request()
+    assert request1 is not None
+    assert request1.url == 'https://example.com/page1'

+    # Mark the request as handled
+    result = await rq.mark_request_as_handled(request1)
+    assert result is not None
+    assert result.was_already_handled is True

-async def test_deduplication_of_requests_with_valid_custom_unique_key() -> None:
-    request_1 = Request.from_url('https://apify.com')
-    request_2 = Request.from_url('https://apify.com')
+    # Fetch next request
+    request2 = await rq.fetch_next_request()
+    assert request2 is not None
+    assert request2.url == 'https://example.com/page2'

-    rq = await RequestQueue.open(name='my-rq')
-    await rq.add_request(request_1)
-    await rq.add_request(request_2)
+    # Mark the second request as handled
+    await rq.mark_request_as_handled(request2)

-    assert await rq.get_total_count() == 1
+    # Verify counts
+    assert rq.metadata.total_request_count == 2
+    assert rq.metadata.handled_request_count == 2
+    assert rq.metadata.pending_request_count == 0

-    assert await rq.fetch_next_request() == request_1
+    # Verify queue is empty
+    empty_request = await rq.fetch_next_request()
+    assert empty_request is None


-async def test_cache_requests(request_queue: RequestQueue) -> None:
-    request_1 = Request.from_url('https://apify.com')
-    request_2 = Request.from_url('https://crawlee.dev')
+async def test_get_request_by_id(rq: RequestQueue) -> None:
+    """Test retrieving a request by its ID."""
+    # Add a request
+    added_result = await rq.add_request('https://example.com')
+    request_id = added_result.id

-    await request_queue.add_request(request_1)
-    await request_queue.add_request(request_2)
+    # Retrieve the request by ID
+    retrieved_request = await rq.get_request(request_id)
+    assert retrieved_request is not None
+    assert retrieved_request.id == request_id
+    assert retrieved_request.url == 'https://example.com'

-    assert request_queue._requests_cache.currsize == 2
-    fetched_request = await request_queue.fetch_next_request()

+async def test_get_non_existent_request(rq: RequestQueue) -> None:
+    """Test retrieving a request that doesn't exist."""
+    non_existent_request = await rq.get_request('non-existent-id')
+    assert non_existent_request is None

-    assert fetched_request is not None
-    assert fetched_request.id == request_1.id
-    # After calling fetch_next_request request_1 moved to the end of the cache store.
-    cached_items = [request_queue._requests_cache.popitem()[0] for _ in range(2)]
-    assert cached_items == [request_2.id, request_1.id]

+async def test_reclaim_request(rq: RequestQueue) -> None:
+    """Test reclaiming a request that failed processing."""
+    # Add a request
+    await rq.add_request('https://example.com')

+    # Fetch the request
+    request = await rq.fetch_next_request()
+    assert request is not None

-async def test_from_storage_object() -> None:
-    storage_client = service_locator.get_storage_client()
+    # Reclaim the request
+    result = await rq.reclaim_request(request)
+    assert result is not None
+    assert result.was_already_handled is False

-    storage_object = StorageMetadata(
-        id='dummy-id',
-        name='dummy-name',
-        accessed_at=datetime.now(timezone.utc),
-        created_at=datetime.now(timezone.utc),
-        modified_at=datetime.now(timezone.utc),
-        extra_attribute='extra',
-    )
+    # Verify we can fetch it again
+    reclaimed_request = await rq.fetch_next_request()
+    assert reclaimed_request is not None
+    assert reclaimed_request.id == request.id
+    assert reclaimed_request.url == 'https://example.com'

-    request_queue = RequestQueue.from_storage_object(storage_client, storage_object)
-
-    assert request_queue.id == storage_object.id
-    assert request_queue.name == storage_object.name
-    assert request_queue.storage_object == storage_object
-    assert storage_object.model_extra.get('extra_attribute') == 'extra'  # type: ignore[union-attr]
-
-
-async def test_add_batched_requests_with_retry(request_queue: RequestQueue) -> None:
-    """Test that unprocessed requests are retried.
-
-    Unprocessed requests should not count in `get_total_count`
-    Test creates situation where in `batch_add_requests` call in first batch 3 requests are unprocessed.
-    On each following `batch_add_requests` call the last request in batch remains unprocessed.
-    In this test `batch_add_requests` is called once with batch of 10 requests. With retries only 1 request should
-    remain unprocessed."""
-
-    batch_add_requests_call_counter = count(start=1)
-    service_locator.get_storage_client()
-    initial_request_count = 10
-    expected_added_requests = 9
-    requests = [f'https://example.com/{i}' for i in range(initial_request_count)]
-
-    class MockedRequestQueueClient(RequestQueueClient):
-        """Patched memory storage client that simulates unprocessed requests."""
-
-        async def _batch_add_requests_without_last_n(
-            self, batch: Sequence[Request], n: int = 0
-        ) -> BatchRequestsOperationResponse:
-            response = await super().batch_add_requests(batch[:-n])
-            response.unprocessed_requests = [
-                UnprocessedRequest(url=r.url, unique_key=r.unique_key, method=r.method) for r in batch[-n:]
-            ]
-            return response
-
-        async def batch_add_requests(
-            self,
-            requests: Sequence[Request],
-            *,
-            forefront: bool = False,  # noqa: ARG002
-        ) -> BatchRequestsOperationResponse:
-            """Mocked client behavior that simulates unprocessed requests.
-
-            It processes all except last three at first run, then all except last none.
-            Overall if tried with the same batch it will process all except the last one.
-            """
-            call_count = next(batch_add_requests_call_counter)
-            if call_count == 1:
-                # Process all but last three
-                return await self._batch_add_requests_without_last_n(requests, n=3)
-            # Process all but last
-            return await self._batch_add_requests_without_last_n(requests, n=1)
-
-    mocked_storage_client = AsyncMock(spec=StorageClient)
-    mocked_storage_client.request_queue = MagicMock(
-        return_value=MockedRequestQueueClient(id='default', memory_storage_client=MemoryStorageClient.from_config())
+
+async def test_reclaim_request_with_forefront(rq: RequestQueue) -> None:
+    """Test reclaiming a request to the front of the queue."""
+    # Add requests
+    await rq.add_request('https://example.com/first')
+    await rq.add_request('https://example.com/second')
+
+    # Fetch the first request
+    first_request = await rq.fetch_next_request()
+    assert first_request is not None
+    assert first_request.url == 'https://example.com/first'
+
+    # Reclaim it to the forefront
+    await rq.reclaim_request(first_request, forefront=True)
+
+    # The reclaimed request should be returned first (before the second request)
+    next_request = await rq.fetch_next_request()
+    assert next_request is not None
+    assert next_request.url == 'https://example.com/first'
+
+
+async def test_is_empty(rq: RequestQueue) -> None:
+    """Test checking if a request queue is empty."""
+    # Initially the queue should be empty
+    assert await rq.is_empty() is True
+
+    # Add a request
+    await rq.add_request('https://example.com')
+    assert await rq.is_empty() is False
+
+    # Fetch and handle the request
+    request = await rq.fetch_next_request()
+
+    assert request is not None
+    await rq.mark_request_as_handled(request)
+
+    # Queue should be empty again
+    assert await rq.is_empty() is True
+
+
+async def test_is_finished(rq: RequestQueue) -> None:
+    """Test checking if a request queue is finished."""
+    # Initially the queue should be finished (empty and no background tasks)
+    assert await rq.is_finished() is True
+
+    # Add a request
+    await rq.add_request('https://example.com')
+    assert await rq.is_finished() is False
+
+    # Add requests in the background
+    await rq.add_requests(
+        ['https://example.com/1', 'https://example.com/2'],
+        wait_for_all_requests_to_be_added=False,
     )
-    request_queue = RequestQueue(id='default', name='some_name', storage_client=mocked_storage_client)
+    # Queue shouldn't be finished while background tasks are running
+    assert await rq.is_finished() is False
+
+    # Wait for background tasks to finish
+    await asyncio.sleep(0.2)

-    # Add the requests to the RQ in batches
-    await request_queue.add_requests_batched(
-        requests, wait_for_all_requests_to_be_added=True, wait_time_between_batches=timedelta(0)
+    # Process all requests
+    while True:
+        request = await rq.fetch_next_request()
+        if request is None:
+            break
+        await rq.mark_request_as_handled(request)
+
+    # Now queue should be finished
+    assert await rq.is_finished() is True
+
+
+async def test_mark_non_existent_request_as_handled(rq: RequestQueue) -> None:
+    """Test marking a non-existent request as handled."""
+    # Create a request that hasn't been added to the queue
+    request = Request.from_url(url='https://example.com', id='non-existent-id')
+
+    # Attempt to mark it as handled
+    result = await rq.mark_request_as_handled(request)
+    assert result is None
+
+
+async def test_reclaim_non_existent_request(rq: RequestQueue) -> None:
+    """Test reclaiming a non-existent request."""
+    # Create a request that hasn't been added to the queue
+    request = Request.from_url(url='https://example.com', id='non-existent-id')
+
+    # Attempt to reclaim it
+    result = await rq.reclaim_request(request)
+    assert result is None
+
+
+async def test_drop(
+    storage_client: StorageClient,
+    configuration: Configuration,
+) -> None:
+    """Test dropping a request queue removes it from cache and clears its data."""
+    rq = await RequestQueue.open(
+        name='drop_test',
+        storage_client=storage_client,
+        configuration=configuration,
     )
-    # Ensure the batch was processed correctly
-    assert await request_queue.get_total_count() == expected_added_requests
+    # Add a request
+    await rq.add_request('https://example.com')

-    # Fetch and validate each request in the queue
-    for original_request in requests[:expected_added_requests]:
-        next_request = await request_queue.fetch_next_request()
-        assert next_request is not None
+    # Verify request queue exists in cache
+    assert rq._cache_key in RequestQueue._cache

-        expected_url = original_request if isinstance(original_request, str) else original_request.url
-        assert next_request.url == expected_url
+    # Drop the request queue
+    await rq.drop()
+
+    # Verify request queue was removed from cache
+    assert rq._cache_key not in RequestQueue._cache
+
+    # Verify request queue is empty (by creating a new one with the same name)
+    new_rq = await RequestQueue.open(
+        name='drop_test',
+        storage_client=storage_client,
+        configuration=configuration,
+    )

-    # Confirm the queue is empty after processing all requests
-    assert await request_queue.is_empty() is True
+    # Verify the queue is empty
+    assert await new_rq.is_empty() is True
+    assert new_rq.metadata.total_request_count == 0
+    assert new_rq.metadata.pending_request_count == 0
+    await new_rq.drop()
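The rewritten tests above also read as a compact reference for the reworked RequestQueue API: `add_requests()` replaces `add_requests_batched()`, and the fetch / mark-handled / reclaim cycle is the consumption loop. The following standalone sketch is illustrative only (not part of this patch; the queue name and the processing body are placeholders) and uses just the calls exercised by these tests:

import asyncio

from crawlee.storages import RequestQueue


async def main() -> None:
    # Open (or create) a named queue and enqueue some work.
    rq = await RequestQueue.open(name='example-queue')
    await rq.add_requests(['https://example.com/1', 'https://example.com/2'])
    await rq.add_request('https://example.com/priority', forefront=True)

    # Drain the queue: fetch each request, process it, then mark it handled.
    while (request := await rq.fetch_next_request()) is not None:
        try:
            ...  # process `request` here
            await rq.mark_request_as_handled(request)
        except Exception:
            # Put the request back at the front of the queue so it is retried.
            await rq.reclaim_request(request, forefront=True)

    assert await rq.is_finished()
    await rq.drop()


if __name__ == '__main__':
    asyncio.run(main())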
diff --git a/tests/unit/test_configuration.py b/tests/unit/test_configuration.py
index 73e17d50d9..f89401e5be 100644
--- a/tests/unit/test_configuration.py
+++ b/tests/unit/test_configuration.py
@@ -9,6 +9,7 @@
 from crawlee.configuration import Configuration
 from crawlee.crawlers import HttpCrawler, HttpCrawlingContext
 from crawlee.storage_clients import MemoryStorageClient
+from crawlee.storage_clients._file_system._storage_client import FileSystemStorageClient

 if TYPE_CHECKING:
     from pathlib import Path

@@ -35,14 +36,15 @@ def test_global_configuration_works_reversed() -> None:


 async def test_storage_not_persisted_when_disabled(tmp_path: Path, server_url: URL) -> None:
-    config = Configuration(
-        persist_storage=False,
-        write_metadata=False,
+    configuration = Configuration(
         crawlee_storage_dir=str(tmp_path),  # type: ignore[call-arg]
     )
-    storage_client = MemoryStorageClient.from_config(config)
+    storage_client = MemoryStorageClient()

-    crawler = HttpCrawler(storage_client=storage_client)
+    crawler = HttpCrawler(
+        configuration=configuration,
+        storage_client=storage_client,
+    )

     @crawler.router.default_handler
     async def default_handler(context: HttpCrawlingContext) -> None:
@@ -56,14 +58,16 @@ async def default_handler(context: HttpCrawlingContext) -> None:


 async def test_storage_persisted_when_enabled(tmp_path: Path, server_url: URL) -> None:
-    config = Configuration(
-        persist_storage=True,
-        write_metadata=True,
+    configuration = Configuration(
         crawlee_storage_dir=str(tmp_path),  # type: ignore[call-arg]
     )
-    storage_client = MemoryStorageClient.from_config(config)

-    crawler = HttpCrawler(storage_client=storage_client)
+    storage_client = FileSystemStorageClient()
+
+    crawler = HttpCrawler(
+        configuration=configuration,
+        storage_client=storage_client,
+    )

     @crawler.router.default_handler
     async def default_handler(context: HttpCrawlingContext) -> None:
diff --git a/tests/unit/test_service_locator.py b/tests/unit/test_service_locator.py
index 50da5ddb86..a4ed0620dd 100644
--- a/tests/unit/test_service_locator.py
+++ b/tests/unit/test_service_locator.py
@@ -6,7 +6,7 @@
 from crawlee.configuration import Configuration
 from crawlee.errors import ServiceConflictError
 from crawlee.events import LocalEventManager
-from crawlee.storage_clients import MemoryStorageClient
+from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient


 def test_default_configuration() -> None:
@@ -72,21 +72,21 @@ def test_event_manager_conflict() -> None:

 def test_default_storage_client() -> None:
     default_storage_client = service_locator.get_storage_client()
-    assert isinstance(default_storage_client, MemoryStorageClient)
+    assert isinstance(default_storage_client, FileSystemStorageClient)


 def test_custom_storage_client() -> None:
-    custom_storage_client = MemoryStorageClient.from_config()
+    custom_storage_client = MemoryStorageClient()
     service_locator.set_storage_client(custom_storage_client)
     storage_client = service_locator.get_storage_client()
     assert storage_client is custom_storage_client


 def test_storage_client_overwrite() -> None:
-    custom_storage_client = MemoryStorageClient.from_config()
+    custom_storage_client = MemoryStorageClient()
     service_locator.set_storage_client(custom_storage_client)

-    another_custom_storage_client = MemoryStorageClient.from_config()
+    another_custom_storage_client = MemoryStorageClient()
     service_locator.set_storage_client(another_custom_storage_client)

     assert custom_storage_client != another_custom_storage_client
@@ -95,7 +95,7 @@ def test_storage_client_overwrite() -> None:

 def test_storage_client_conflict() -> None:
     service_locator.get_storage_client()
-    custom_storage_client = MemoryStorageClient.from_config()
+    custom_storage_client = MemoryStorageClient()

     with pytest.raises(ServiceConflictError, match='StorageClient is already in use.'):
         service_locator.set_storage_client(custom_storage_client)
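The configuration and service-locator test updates capture a behavioural change worth noting: the default storage client is now `FileSystemStorageClient`, and in-memory storage is opted into by constructing `MemoryStorageClient()` directly, since the old `persist_storage` / `write_metadata` flags and `MemoryStorageClient.from_config()` are gone. A minimal sketch of registering the in-memory client globally, based only on the calls these tests exercise:

from crawlee import service_locator
from crawlee.storage_clients import MemoryStorageClient

# Register an in-memory storage client before any storage or crawler is opened.
# Resolving the client first (service_locator.get_storage_client()) would lock in
# the default FileSystemStorageClient, and a later set_storage_client() call
# would then raise ServiceConflictError, as test_storage_client_conflict shows.
service_locator.set_storage_client(MemoryStorageClient())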
diff --git a/website/generate_module_shortcuts.py b/website/generate_module_shortcuts.py
index 5a18e8d3f3..61acc68ade 100755
--- a/website/generate_module_shortcuts.py
+++ b/website/generate_module_shortcuts.py
@@ -5,6 +5,7 @@
 import importlib
 import inspect
 import json
+from pathlib import Path
 from typing import TYPE_CHECKING

 if TYPE_CHECKING:
@@ -55,5 +56,5 @@ def resolve_shortcuts(shortcuts: dict) -> None:

 resolve_shortcuts(shortcuts)

-with open('module_shortcuts.json', 'w', encoding='utf-8') as shortcuts_file:
+with Path('module_shortcuts.json').open('w', encoding='utf-8') as shortcuts_file:
     json.dump(shortcuts, shortcuts_file, indent=4, sort_keys=True)
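The final hunk swaps the builtin `open()` for `pathlib.Path.open()` in the website helper script, presumably to satisfy a pathlib-oriented lint rule. The same pattern applied to reading the generated file back, as a hypothetical round-trip check that is not part of this patch:

import json
from pathlib import Path

# Hypothetical check that the generated shortcuts file parses back correctly.
with Path('module_shortcuts.json').open(encoding='utf-8') as shortcuts_file:
    shortcuts = json.load(shortcuts_file)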