diff --git a/src/litdata/constants.py b/src/litdata/constants.py index 57bf6c25..031343c2 100644 --- a/src/litdata/constants.py +++ b/src/litdata/constants.py @@ -25,6 +25,7 @@ _DEFAULT_CACHE_DIR = os.path.join(Path.home(), ".lightning", "chunks") _DEFAULT_LIGHTNING_CACHE_DIR = os.path.join("/cache", "chunks") _SUPPORTED_PROVIDERS = ("s3", "gs") # cloud providers supported by litdata for uploading (optimize, map, merge, etc) +_SUPPORTED_DOWNLOADERS = ("s3", "gs", "azure", "hf") # cloud providers supported by litdata for streaming datasets # This is required for full pytree serialization / deserialization support _TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0") diff --git a/src/litdata/streaming/config.py b/src/litdata/streaming/config.py index 3740c9b0..1907ab12 100644 --- a/src/litdata/streaming/config.py +++ b/src/litdata/streaming/config.py @@ -16,7 +16,7 @@ from collections import defaultdict from typing import Any, Dict, List, Optional, Tuple -from litdata.constants import _INDEX_FILENAME +from litdata.constants import _INDEX_FILENAME, _SUPPORTED_DOWNLOADERS from litdata.debugger import ChromeTraceColors, _get_log_msg from litdata.streaming.compression import _COMPRESSORS, Compressor from litdata.streaming.downloader import get_downloader @@ -25,6 +25,7 @@ from litdata.streaming.serializers import Serializer from litdata.utilities._pytree import tree_unflatten, treespec_loads from litdata.utilities.dataset_utilities import load_index_file +from litdata.utilities.file_utils import increment_file_count logger = logging.getLogger("litdata.streaming.config") @@ -121,7 +122,7 @@ def skip_chunk_indexes_deletion(self) -> Optional[List[int]]: def skip_chunk_indexes_deletion(self, skip_chunk_indexes_deletion: List[int]) -> None: self._skip_chunk_indexes_deletion = skip_chunk_indexes_deletion - def download_chunk_from_index(self, chunk_index: int, skip_lock: bool = False) -> None: + def download_chunk_from_index(self, chunk_index: int) -> None: assert self._chunks is not None chunk_filename = self._chunks[chunk_index]["filename"] @@ -129,22 +130,27 @@ def download_chunk_from_index(self, chunk_index: int, skip_lock: bool = False) - if os.path.exists(local_chunkpath): self.try_decompress(local_chunkpath) - if self._downloader is not None and not skip_lock: + if self._downloader is not None and self._remote_dir.startswith(_SUPPORTED_DOWNLOADERS): # We don't want to redownload the base, but we should mark # it as having been requested by something - self._downloader._increment_local_lock(local_chunkpath.replace(f".{self._compressor_name}", "")) - pass + count = increment_file_count(local_chunkpath.replace(f".{self._compressor_name}", "")) + if count == 1: + # weird, shouldn't happen + # but if it does, we should start downloading the file + self._downloader.download_chunk_from_index(chunk_index) + self.try_decompress(local_chunkpath) return - if self._downloader is None: + if (self._downloader is None) or (not self._remote_dir.startswith(_SUPPORTED_DOWNLOADERS)): return - if not skip_lock: - self._downloader._increment_local_lock(local_chunkpath.replace(f".{self._compressor_name}", "")) + curr_count = increment_file_count(local_chunkpath.replace(f".{self._compressor_name}", "")) - self._downloader.download_chunk_from_index(chunk_index) - - self.try_decompress(local_chunkpath) + if curr_count == 1: + # this is the first time we are downloading this file + # so we should actually download it + self._downloader.download_chunk_from_index(chunk_index) + self.try_decompress(local_chunkpath) 
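For context (this snippet is not part of the patch, it only illustrates standard Python behaviour the new guard relies on): `self._remote_dir.startswith(_SUPPORTED_DOWNLOADERS)` works because `str.startswith` accepts a tuple of prefixes, and the check is a plain prefix match on the URI scheme.

    _SUPPORTED_DOWNLOADERS = ("s3", "gs", "azure", "hf")

    # one call covers every supported scheme
    assert "s3://my-bucket/chunks".startswith(_SUPPORTED_DOWNLOADERS)
    assert "hf://datasets/user/repo".startswith(_SUPPORTED_DOWNLOADERS)
    # local paths (and unsupported schemes) fall through to the early-return branch
    assert not "/local/cache/chunks".startswith(_SUPPORTED_DOWNLOADERS)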
def try_decompress(self, local_chunkpath: str) -> None: if self._compressor is None: diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py index c510f461..b19852c4 100644 --- a/src/litdata/streaming/dataloader.py +++ b/src/litdata/streaming/dataloader.py @@ -135,7 +135,7 @@ def __call__(self, items: List[Any]) -> Any: class _SingleProcessDataLoaderIterPatch(_SingleProcessDataLoaderIter): - """This is รง to inform the cache is done chunking.""" + """This is to inform the cache is done chunking.""" def _next_data(self) -> Any: try: @@ -253,6 +253,7 @@ def _next_data(self) -> Any: raise e +#! TODO: This class is not being used anywhere. class CacheDataLoader(DataLoader): __doc__ = DataLoader.__doc__ diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index d69d8e3e..16552695 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -289,9 +289,13 @@ def __iter__(self) -> "StreamingDataset": ) worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank + if worker_rank == 0: + logger.debug(f"workers_chunks: {workers_chunks}\nworkers_intervals: {workers_intervals}") self.worker_chunks = workers_chunks[worker_rank] self.worker_intervals = workers_intervals[worker_rank] + logger.debug("-" * 50 + "\n" + f"{worker_rank=}; {self.worker_chunks=}; {self.worker_intervals=}\n" + "-" * 50) + # The max number of samples to return from `__next__` (in worker) self.stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals) @@ -302,28 +306,30 @@ def __iter__(self) -> "StreamingDataset": # Find the chunks shared across all workers of the current node. # For each shared chunk, find the rank and worker to use the chunk last and prevent # premature deletion for the other workers. - node_size = self.distributed_env.world_size // self.distributed_env.num_nodes - first_rank_this_node = (self.distributed_env.global_rank // node_size) * node_size - num_workers_per_node = node_size * self.num_workers - worker_start = first_rank_this_node * num_workers_per_node - worker_end = worker_start + num_workers_per_node - local_rank = self.distributed_env.global_rank % node_size - - chunks_indexes_skip_deletion = _find_chunks_per_workers_on_which_to_skip_deletion( - self.num_workers, - self.batch_size, - workers_chunks[worker_start:worker_end], - workers_intervals[worker_start:worker_end], - ) - worker_node_rank_to_chunk_indexes = _map_node_worker_rank_to_chunk_indexes_to_not_delete( - chunks_indexes_skip_deletion - ) + if False: + #! 
TODO: fix skip_chunk_deletion (iops is much slower than in memory access) + node_size = self.distributed_env.world_size // self.distributed_env.num_nodes + first_rank_this_node = (self.distributed_env.global_rank // node_size) * node_size + num_workers_per_node = node_size * self.num_workers + worker_start = first_rank_this_node * num_workers_per_node + worker_end = worker_start + num_workers_per_node + local_rank = self.distributed_env.global_rank % node_size + + chunks_indexes_skip_deletion = _find_chunks_per_workers_on_which_to_skip_deletion( + self.num_workers, + self.batch_size, + workers_chunks[worker_start:worker_end], + workers_intervals[worker_start:worker_end], + ) + worker_node_rank_to_chunk_indexes = _map_node_worker_rank_to_chunk_indexes_to_not_delete( + chunks_indexes_skip_deletion + ) - worker_rank_local_node = local_rank * self.num_workers + self.worker_env.rank - if worker_rank_local_node in worker_node_rank_to_chunk_indexes: - self.cache._reader.config.skip_chunk_indexes_deletion = worker_node_rank_to_chunk_indexes[ - worker_rank_local_node - ] + worker_rank_local_node = local_rank * self.num_workers + self.worker_env.rank + if worker_rank_local_node in worker_node_rank_to_chunk_indexes: + self.cache._reader.config.skip_chunk_indexes_deletion = worker_node_rank_to_chunk_indexes[ + worker_rank_local_node + ] self.num_chunks = len(self.worker_chunks) self.upcoming_indexes = [] @@ -334,6 +340,9 @@ def __iter__(self) -> "StreamingDataset": self.has_triggered_download = False self.last_time = time() + # start the downloader thread + self.cache._reader.prepare_downloader_thread(self.worker_chunks) + return self def _resume(self, workers_chunks: List[List[int]], workers_intervals: List[Any]) -> None: diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index 3f104994..1b43651a 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -46,20 +46,6 @@ def __init__( self._chunks = chunks self._storage_options = storage_options or {} - def _increment_local_lock(self, chunkpath: str) -> None: - logger.debug(_get_log_msg({"name": f"increment_local_lock_for_{chunkpath}", "ph": "B"})) - countpath = chunkpath + ".cnt" - with suppress(Timeout), FileLock(countpath + ".lock", timeout=1): - try: - with open(countpath) as count_f: - curr_count = int(count_f.read().strip()) - except Exception: - curr_count = 0 - curr_count += 1 - with open(countpath, "w+") as count_f: - count_f.write(str(curr_count)) - logger.debug(_get_log_msg({"name": f"increment_local_lock_for_{chunkpath}", "ph": "E"})) - def download_chunk_from_index(self, chunk_index: int) -> None: logger.debug(_get_log_msg({"name": f"download_chunk_from_index_{chunk_index}", "ph": "B"})) diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index ea7dcb83..294a5c58 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -36,6 +36,7 @@ from litdata.streaming.serializers import Serializer from litdata.utilities._pytree import PyTree, tree_unflatten from litdata.utilities.encryption import Encryption, EncryptionLevel +from litdata.utilities.file_utils import decrement_file_count Interval = namedtuple("Interval", ["chunk_start", "roi_start_idx", "roi_end_idx", "chunk_end"]) @@ -111,6 +112,15 @@ def load_item_from_chunk( def delete(self, chunk_index: int, chunk_filepath: str) -> None: """Delete a chunk from the local filesystem.""" + def safe_delete(self, chunk_index: int, chunk_filepath: str, 
delete_original_file: bool = False) -> None: + """Decrement the file count and delete the chunk file if the count reaches 0.""" + if os.path.exists(chunk_filepath + ".cnt"): + curr_count = decrement_file_count(chunk_filepath) + if curr_count == 0 and delete_original_file: + self.delete(chunk_index, chunk_filepath) + else: + self.delete(chunk_index, chunk_filepath) + @abstractmethod def encode_data(self, data: List[bytes], sizes: List[int], flattened: List[Any]) -> Any: pass @@ -260,7 +270,8 @@ def _load_data(self, fp: Union[FileIO, BytesIO], offset: int) -> bytes: pair = fp.read(8) begin, end = np.frombuffer(pair, np.uint32) - fp.seek(begin) # move the file pointer to the offset_start where the item starts + # move the file pointer to the offset_start where the item starts + fp.seek(begin) return fp.read(end - begin) # read the item def mds_deserialize(self, raw_item_data: bytes, chunk_index: int) -> "PyTree": @@ -737,7 +748,7 @@ def _get_item_with_low_memory(self, chunk_index: int, chunk_filepath: str, row_i # Return the specific row from the dataframe # Note: The `named=True` argument is used to return the row as a dictionary - return row_group_df.row(row_index_within_group, named=True) # type: ignore + return row_group_df.row(row_index_within_group, named=True) def _get_item(self, chunk_index: int, chunk_filepath: str, index: int) -> Any: """Retrieve a dataframe row from a parquet chunk by loading the entire chunk into memory. diff --git a/src/litdata/streaming/reader.py b/src/litdata/streaming/reader.py index 99c5f704..95285358 100644 --- a/src/litdata/streaming/reader.py +++ b/src/litdata/streaming/reader.py @@ -15,13 +15,10 @@ import logging import os import warnings -from contextlib import suppress from queue import Empty, Queue from threading import Event, Thread from typing import Any, Dict, List, Optional, Tuple, Union -from filelock import FileLock, Timeout - from litdata.constants import _DEBUG from litdata.debugger import _get_log_msg from litdata.streaming.config import ChunksConfig, Interval @@ -52,6 +49,7 @@ def __init__( self, config: ChunksConfig, item_loader: BaseItemLoader, + chunks_order: List[int], distributed_env: _DistributedEnv, max_cache_size: Optional[int] = None, max_pre_download: int = 2, @@ -60,15 +58,19 @@ def __init__( super().__init__(daemon=True) self._config = config self._item_loader = item_loader + self._chunks_order = chunks_order # order in which chunks are to be downloaded + self.current_downloading_chunk_index = -1 + self.current_reading_chunk_index = -1 self._max_pre_download = max_pre_download self._pre_download_counter = 0 self._distributed_env = distributed_env - self._chunks_index_to_be_deleted: List[int] = [] + # self._chunks_index_to_be_deleted: List[int] = [] self._max_cache_size = max_cache_size self._parent_cache_dir = os.path.dirname(self._config._cache_dir) self._to_download_queue: Queue = Queue() self._to_delete_queue: Queue = Queue() + self._delete_queue_received_none: bool = False self._force_stop_event = Event() # TODO: Find a real fix to this problem @@ -79,122 +81,176 @@ def __init__( # Check whether a dataset slice fits on the node num_bytes_per_nodes = self._config.num_bytes // self._distributed_env.num_nodes self._delete_chunks_when_processed = num_bytes_per_nodes > max_cache_size if max_cache_size else False + # print(f"reader: {self._delete_chunks_when_processed=}") + # if self._delete_chunks_when_processed: + # print(f"clearing cache dir {self._parent_cache_dir} because the dataset is too large to fit in memory") + # # means 
we can't keep all chunks in the cache directory, so we should clear it to minimize the size + # # clear the cache directory except the index.json file + # for root, _, files in os.walk(self._parent_cache_dir): + # for file in files: + # if file != _INDEX_FILENAME: + # with contextlib.suppress(FileNotFoundError): + # os.remove(os.path.join(root, file)) self._has_exited = False - def download(self, chunk_indexes: List[int]) -> None: - """Receive the list of the chunk indices to download for the current epoch.""" - for chunk_index in chunk_indexes: - self._to_download_queue.put(chunk_index) + # def download(self, chunk_indexes: List[int]) -> None: + # """Receive the list of the chunk indices to download for the current epoch.""" + # print(f"thread: got indexes to download -> {chunk_indexes=};") + # for chunk_index in chunk_indexes: + # self._to_download_queue.put(chunk_index) def delete(self, chunk_indexes: List[int]) -> None: """Receive the list of the chunk indices to delete for the current epoch.""" for chunk_index in chunk_indexes: self._to_delete_queue.put(chunk_index) - def _remaining_locks(self, chunkpath: str) -> int: - countpath = chunkpath + ".cnt" - if not os.path.exists(countpath): - return 0 - with open(countpath) as count_f: - try: - return int(count_f.read().strip()) - except Exception: - return 1 - - def _decrement_local_lock(self, chunk_index: int) -> int: - """Remove a count from the local lock, return the remaining count.""" - chunk_filepath, _, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)] - - countpath = chunk_filepath + ".cnt" - with suppress(Timeout), FileLock(countpath + ".lock", timeout=3): - if not os.path.exists(countpath): - return 0 - with open(countpath) as count_f: - logger.debug(_get_log_msg({"name": f"decrement_local_lock_for_ {chunk_filepath}", "ph": "B"})) - try: - curr_count = int(count_f.read().strip()) - except Exception: - curr_count = 1 - curr_count -= 1 - if curr_count <= 0: - with contextlib.suppress(FileNotFoundError, PermissionError): - os.remove(countpath) - - with contextlib.suppress(FileNotFoundError, PermissionError): - os.remove(countpath + ".lock") - else: - with open(countpath, "w+") as count_f: - count_f.write(str(curr_count)) - logger.debug(_get_log_msg({"name": f"decrement_local_lock_for_ {chunk_filepath}", "ph": "E"})) - return curr_count - return 0 + # def _remaining_locks(self, chunkpath: str) -> int: + # countpath = chunkpath + ".cnt" + # if not os.path.exists(countpath): + # return 0 + # with open(countpath) as count_f: + # try: + # return int(count_f.read().strip()) + # except Exception: + # return 1 + + # def _decrement_local_lock(self, chunk_index: int) -> int: + # """Remove a count from the local lock, return the remaining count.""" + # chunk_filepath, _, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)] + + # countpath = chunk_filepath + ".cnt" + # with suppress(Timeout), FileLock(countpath + ".lock", timeout=3): + # if not os.path.exists(countpath): + # return 0 + # with open(countpath) as count_f: + # try: + # curr_count = int(count_f.read().strip()) + # except Exception: + # curr_count = 1 + # curr_count -= 1 + # if curr_count <= 0: + # with contextlib.suppress(FileNotFoundError, PermissionError): + # os.remove(countpath) + + # with contextlib.suppress(FileNotFoundError, PermissionError): + # os.remove(countpath + ".lock") + # else: + # with open(countpath, "w+") as count_f: + # count_f.write(str(curr_count)) + # return curr_count + # return 0 def _apply_delete(self, chunk_index: int) -> 
None: """Inform the item loader of the chunk to delete.""" # TODO: Fix the can_delete method - can_delete_chunk = self._config.can_delete(chunk_index) + # can_delete_chunk = self._config.can_delete(chunk_index) + # print(f"apply delete called -> {chunk_index} {can_delete_chunk=}; by {self._rank or 0}") chunk_filepath, _, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)] - remaining_locks = self._remaining_locks(chunk_filepath) - if remaining_locks > 0: # Can't delete this, something has it - if _DEBUG: - print(f"Skip delete {chunk_filepath} by {self._rank or 0}, current lock count: {remaining_locks}") - return + # remaining_locks = self._remaining_locks(chunk_filepath) + # if remaining_locks > 0: # Can't delete this, something has it + # if _DEBUG: + # print(f"Skip delete {chunk_filepath} by {self._rank or 0}, current lock count: {remaining_locks}") + # return - if _DEBUG: - with open(chunk_filepath + ".tmb", "w+") as tombstone_file: - tombstone_file.write(f"Deleted {chunk_filepath} by {self._rank or 0}. Debug: {can_delete_chunk}") + # if _DEBUG: + # with open(chunk_filepath + ".tmb", "w+") as tombstone_file: + # tombstone_file.write(f"Deleted {chunk_filepath} by {self._rank or 0}. Debug: {can_delete_chunk}") - self._item_loader.delete(chunk_index, chunk_filepath) + self._item_loader.safe_delete( + chunk_index, chunk_filepath, delete_original_file=self._delete_chunks_when_processed + ) - if _DEBUG: - print(f"Deleted {chunk_filepath} by {self._rank or 0}. Debug: {can_delete_chunk}") + # if _DEBUG: + # print(f"Deleted {chunk_filepath} by {self._rank or 0}. Debug: {can_delete_chunk}") - for lock_extension in [".lock", ".cnt.lock"]: - try: - locak_chunk_path = chunk_filepath + lock_extension - if os.path.exists(locak_chunk_path): - os.remove(locak_chunk_path) - except FileNotFoundError: - pass + # for lock_extension in [".lock", ".cnt.lock"]: + # try: + # locak_chunk_path = chunk_filepath + lock_extension + # if os.path.exists(locak_chunk_path): + # os.remove(locak_chunk_path) + # except FileNotFoundError: + # pass def stop(self) -> None: """Receive the list of the chunk indices to download for the current epoch.""" - self._to_download_queue.put(_END_TOKEN) + # self._to_download_queue.put(_END_TOKEN) + if self._delete_chunks_when_processed and not self._delete_queue_received_none: + # for chnk_idx in self._chunks_index_to_be_deleted: + # self._apply_delete(chnk_idx) + # read from delete queue until None is received and delete the chunks + total_waiting_time = 0 + while not self._delete_queue_received_none: # parallelly it can be set true by thread's run method + try: + chunk_index = self._to_delete_queue.get(timeout=_DEFAULT_TIMEOUT) + if chunk_index is None: + self._delete_queue_received_none = True + break + self._apply_delete(chunk_index) + total_waiting_time = 0 + except Empty: + total_waiting_time += _DEFAULT_TIMEOUT + if total_waiting_time > _LONG_DEFAULT_TIMEOUT * 2: # wait for 10 seconds + print("Timeout waiting for delete queue to be empty (None)") + break + self.force_stop() def force_stop(self) -> None: self._force_stop_event.set() def _maybe_delete_chunks(self) -> None: - reached_pre_download = self._pre_download_counter == self._max_pre_download + # reached_pre_download = self._pre_download_counter == self._max_pre_download + # reached_max_pre_download = ( + # self.current_downloading_chunk_index < self.current_reading_chunk_index + self._max_pre_download + # ) - # we have already pre-downloaded some chunks, we just need to wait for them to be processed. 
- chunk_index = _get_from_queue( - self._to_delete_queue, timeout=_LONG_DEFAULT_TIMEOUT if reached_pre_download else _DEFAULT_TIMEOUT - ) + # should_start_deleting_chunks = self._can_delete_chunk() + # if not should_start_deleting_chunks: + # return - if chunk_index is None: + # we have already pre-downloaded some chunks, we just need to wait for them to be processed. + if self._delete_queue_received_none: return + while True: + try: + chunk_index_to_be_deleted = self._to_delete_queue.get(timeout=_DEFAULT_TIMEOUT) + + if chunk_index_to_be_deleted is None: + self._delete_queue_received_none = True + return + # self._pre_download_counter -= 1 + + # Store the current chunk index + # self._chunks_index_to_be_deleted.append(chunk_index) - # Store the current chunk index - self._chunks_index_to_be_deleted.append(chunk_index) + # Get the current cache size and decide whether we need to start cleanup. Otherwise, keep track of it + # Delete the oldest chunk + self._apply_delete(chunk_index_to_be_deleted) + except Empty: + # Timeout waiting for delete queue to be empty + break + except Exception as e: + raise RuntimeError(f"Error while deleting chunks: {e}") from e - # Get the current cache size and decide whether we need to start cleanup. Otherwise, keep track of it - while self._max_cache_size and self._chunks_index_to_be_deleted and self._can_delete_chunk(): - # Delete the oldest chunk - self._apply_delete(self._chunks_index_to_be_deleted.pop(0)) - # Decrement the pre-download counter - self._pre_download_counter -= 1 return def _can_delete_chunk(self) -> bool: if self._delete_chunks_when_processed: - return self._pre_download_counter >= self._max_pre_download - 1 + # return self._pre_download_counter >= self._max_pre_download - 1 + # if we have downloaded all chunks, we can delete the oldest one + if self.current_downloading_chunk_index == len(self._chunks_order) - 1: + return True + return self.current_downloading_chunk_index >= (self.current_reading_chunk_index + self._max_pre_download) + + return False # if complete dataset can be stored in the cache, we don't need to delete any chunk return ( self._max_cache_size is not None and _get_folder_size(self._config._cache_dir, self._config) >= self._max_cache_size ) + def _can_download_chunk(self) -> bool: + return not self._can_delete_chunk() + def _pre_load_chunk(self, chunk_index: int) -> None: chunk_filepath, _, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)] self._item_loader.pre_load_chunk(chunk_index, chunk_filepath) @@ -206,7 +262,8 @@ def _force_download(self) -> None: chunk_filepath, _, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)] print(f"Requested force download for {chunk_filepath} by {self._rank}") - self._config.download_chunk_from_index(chunk_index, skip_lock=True) + # skip counter_file logic and directly download the chunk + self._config._downloader.download_chunk_from_index(chunk_index) # Preload item if possible to gain some time but only # if this is one of the pre-downloaded chunk @@ -224,27 +281,25 @@ def run(self) -> None: self._force_download() - if self._pre_download_counter < self._max_pre_download: - chunk_index = _get_from_queue(self._to_download_queue) - if chunk_index == _END_TOKEN: - if self._max_cache_size: - self._maybe_delete_chunks() - self._has_exited = True - return + if self._can_download_chunk() and (self.current_downloading_chunk_index < len(self._chunks_order) - 1): + self.current_downloading_chunk_index += 1 + # chunk_index = _get_from_queue(self._to_download_queue) + 
chunk_index = self._chunks_order[self.current_downloading_chunk_index] + # if chunk_index == _END_TOKEN: + # self._has_exited = True + # return if chunk_index is not None: self._config.download_chunk_from_index(chunk_index) # Preload item if possible to gain some time but only # if this is one of the pre-downloaded chunk - if self._pre_download_counter > 0: - self._pre_load_chunk(chunk_index) + # if self._pre_download_counter > 0: + # self._pre_load_chunk(chunk_index) # Avoid downloading too many chunks in advance at the risk of over using the disk space - self._pre_download_counter += 1 - - if self._max_cache_size: - self._maybe_delete_chunks() + # self._pre_download_counter += 1 + self._maybe_delete_chunks() # The BinaryReader operates as the inverse of the data optimization process: @@ -330,6 +385,24 @@ def _try_load_config(self) -> Optional[ChunksConfig]: ) return self._config + def prepare_downloader_thread(self, chunks_order: List[int]) -> None: + """Prepare the downloader thread and start downloading the first few chunks.""" + if self._config is None and self._try_load_config() is None: + raise Exception("The reader index isn't defined.") + + # Create and start the prepare chunks thread + if self._prepare_thread is None: + self._prepare_thread = PrepareChunksThread( + config=self._config, + item_loader=self._item_loader, + chunks_order=chunks_order, + distributed_env=self._distributed_env, + max_cache_size=self._max_cache_size, + max_pre_download=self._max_pre_download, + rank=self.rank, + ) + self._prepare_thread.start() + @property def config(self) -> ChunksConfig: if self._config is None: @@ -358,37 +431,42 @@ def read(self, index: ChunkedIndex) -> Any: if not isinstance(index, ChunkedIndex): raise ValueError("The Reader.read(...) method expects a chunked Index.") + if self._config is None or self._prepare_thread is None: + raise Exception( + "Reader's downloading thread is not started. Please call `reader.prepare_downloader_thread()` first." 
+ ) + # Load the config containing the index - if self._config is None and self._try_load_config() is None: - raise Exception("The reader index isn't defined.") + # if self._config is None and self._try_load_config() is None: + # raise Exception("The reader index isn't defined.") - if self._config and (self._config._remote_dir or self._config._compressor): + if self._config and (self._config._remote_dir or self._config._compressor): # noqa: SIM102 # Create and start the prepare chunks thread - if self._prepare_thread is None and self._config: - self._prepare_thread = PrepareChunksThread( - self._config, - self._item_loader, - self._distributed_env, - self._max_cache_size, - self._max_pre_download, - self._rank, - ) - # Attach the force download queue - self._item_loader._force_download_queue = self._prepare_thread._force_download_queue # type: ignore - self._prepare_thread.start() - if index.chunk_indexes: - self._prepare_thread.download(index.chunk_indexes) - self._chunks_queued_for_download = True + # if self._prepare_thread is None and self._config: + # self._prepare_thread = PrepareChunksThread( + # self._config, + # self._item_loader, + # self._distributed_env, + # self._max_cache_size, + # self._max_pre_download, + # self._rank, + # ) + # # Attach the force download queue + # self._item_loader._force_download_queue = self._prepare_thread._force_download_queue # type: ignore + # self._prepare_thread.start() + # if index.chunk_indexes: + # self._prepare_thread.download(index.chunk_indexes) + # self._chunks_queued_for_download = True # Only request individual chunk download if: # 1. We haven't already queued all chunks for the download # 2. We're processing a new chunk (different from the last one) - if not self._chunks_queued_for_download and index.chunk_index != self._last_chunk_index: - assert self._prepare_thread - self._prepare_thread.download([index.chunk_index]) + # if not self._chunks_queued_for_download and index.chunk_index != self._last_chunk_index: + # assert self._prepare_thread + # self._prepare_thread.download([index.chunk_index]) - if self._last_chunk_index is None: - self._last_chunk_index = index.chunk_index + if self._last_chunk_index is None or index.chunk_index != self._last_chunk_index: + self._prepare_thread.current_reading_chunk_index += 1 # Fetch the element chunk_filepath, begin, filesize_bytes = self.config[index] @@ -404,41 +482,29 @@ def read(self, index: ChunkedIndex) -> Any: # We need to request deletion after the latest element has been loaded. # Otherwise, this could trigger segmentation fault error depending on the item loader used. 
- if ( - self._config - and (self._config._remote_dir or self._config._compressor) - and index.chunk_index != self._last_chunk_index - ): - assert self._prepare_thread - assert self._last_chunk_index is not None - - # inform the chunk has been completely consumed - self._prepare_thread._decrement_local_lock(self._last_chunk_index) - self._prepare_thread.delete([self._last_chunk_index]) - - if index.chunk_index != self._last_chunk_index: + if self._last_chunk_index is None or index.chunk_index != self._last_chunk_index: # Close the memory-mapped file for the last chunk index if isinstance(self._item_loader, (TokensLoader, ParquetLoader)) and self._last_chunk_index is not None: self._item_loader.close(self._last_chunk_index) + if self._config and (self._config._remote_dir or self._config._compressor): + assert self._prepare_thread + if self._last_chunk_index is not None: + # inform the chunk has been completely consumed + # self._prepare_thread._decrement_local_lock(self._last_chunk_index) + self._prepare_thread.delete([self._last_chunk_index]) + # track the new chunk index as the latest one self._last_chunk_index = index.chunk_index if index.is_last_index and self._prepare_thread: # inform the thread it is time to stop - self._prepare_thread._decrement_local_lock(index.chunk_index) - self._prepare_thread.delete([index.chunk_index]) + # self._prepare_thread._decrement_local_lock(index.chunk_index) + self._item_loader.close(self._last_chunk_index) + self._prepare_thread.delete([index.chunk_index, None]) # send this chunk for deletion self._prepare_thread.stop() - if self._max_cache_size and self._prepare_thread.is_alive(): - try: - self._prepare_thread.join(timeout=_LONG_DEFAULT_TIMEOUT) - except Timeout: - logger.warning( - "The prepare chunks thread didn't exit properly. " - "This can happen if the chunk files are too large." - ) + self._prepare_thread.join() self._prepare_thread = None - self._item_loader.close(self._last_chunk_index) self._last_chunk_index = None self._chunks_queued_for_download = False @@ -516,7 +582,6 @@ def _get_folder_size(path: str, config: ChunksConfig) -> int: f"Ignoring '{filename}': " "This file doesn't appear to be a valid chunk file and has been excluded from the size calculation." ) - return size diff --git a/src/litdata/streaming/sampler.py b/src/litdata/streaming/sampler.py index 4c67b51d..273b87fd 100644 --- a/src/litdata/streaming/sampler.py +++ b/src/litdata/streaming/sampler.py @@ -22,7 +22,8 @@ @dataclass class ChunkedIndex: - """Represents an index within a chunked dataset. + """The docstring below is incorrect for `chunk_indexes`. + Represents an index within a chunked dataset. Attributes: index (int): The global index of the data point across all chunks. @@ -48,6 +49,7 @@ class ChunkedIndex: is_last_index: bool = False +#! TODO: This class is not being used. (Used in CacheDataLoader, but CacheDataLoader itself not being used) class CacheBatchSampler: def __init__( self, diff --git a/src/litdata/utilities/file_utils.py b/src/litdata/utilities/file_utils.py new file mode 100644 index 00000000..18479995 --- /dev/null +++ b/src/litdata/utilities/file_utils.py @@ -0,0 +1,55 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from contextlib import suppress
+
+from filelock import FileLock, Timeout
+
+
+def increment_file_count(file_path: str) -> int:
+    """Increment the usage count stored in the chunk's `.cnt` sidecar file and return the new count."""
+    countpath = file_path + ".cnt"
+    curr_count = 0  # fallback if the lock cannot be acquired within the timeout
+    with suppress(Timeout), FileLock(countpath + ".lock", timeout=1):
+        try:
+            with open(countpath) as count_f:
+                curr_count = int(count_f.read().strip())
+        except Exception:
+            curr_count = 0
+        curr_count += 1
+        with open(countpath, "w+") as count_f:
+            count_f.write(str(curr_count))
+
+    return curr_count
+
+
+def decrement_file_count(file_path: str) -> int:
+    """Decrement the usage count stored in the chunk's `.cnt` sidecar file and return the remaining count."""
+    countpath = file_path + ".cnt"
+    curr_count = 0  # fallback if the lock cannot be acquired within the timeout
+
+    with suppress(Timeout), FileLock(countpath + ".lock", timeout=1):
+        try:
+            with open(countpath) as count_f:
+                curr_count = int(count_f.read().strip())
+        except Exception as e:
+            raise ValueError(f"Count file not found when trying to decrement_file_count: {countpath}.") from e
+        curr_count -= 1
+
+        if curr_count <= 0:
+            # remove the count file if it reaches zero
+            with suppress(FileNotFoundError):
+                os.remove(countpath)
+            return 0
+        with open(countpath, "w+") as count_f:
+            count_f.write(str(curr_count))
+
+    return curr_count
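To make the intent of the new reference-counting helpers concrete, here is a minimal usage sketch (not part of the patch; the temporary directory and the chunk filename are illustrative only). It mirrors how `config.download_chunk_from_index` takes a reference before a chunk is read and how `safe_delete` releases it, removing the chunk only once the count reaches zero:

    import os
    import tempfile

    from litdata.utilities.file_utils import decrement_file_count, increment_file_count

    with tempfile.TemporaryDirectory() as cache_dir:
        chunk_path = os.path.join(cache_dir, "chunk-0-0.bin")  # illustrative chunk name
        with open(chunk_path, "wb") as f:
            f.write(b"\x00" * 16)  # stand-in for real chunk bytes

        # Two consumers (e.g. two dataloader workers) register interest in the same chunk.
        assert increment_file_count(chunk_path) == 1  # first consumer triggers the actual download
        assert increment_file_count(chunk_path) == 2  # later consumers only bump the counter

        # Each consumer releases its reference once it is done with the chunk.
        assert decrement_file_count(chunk_path) == 1  # still referenced, keep the file
        if decrement_file_count(chunk_path) == 0:
            os.remove(chunk_path)  # last reference gone: safe to delete, as safe_delete does

Because the count lives in a `<chunk>.cnt` file guarded by a `FileLock`, the same bookkeeping works across processes sharing one cache directory, not just across threads of a single worker.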