From 5d25ab3c02b79049279a2f8d6dbc8778ac1f77a8 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Sat, 5 Apr 2025 17:39:33 +0530 Subject: [PATCH 01/12] wip --- src/litdata/streaming/dataloader.py | 2 +- src/litdata/streaming/dataset.py | 49 ++++++++++------------ src/litdata/streaming/reader.py | 9 +++- src/litdata/utilities/dataset_utilities.py | 3 ++ src/litdata/utilities/env.py | 5 ++- 5 files changed, 39 insertions(+), 29 deletions(-) diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py index 9652419f..e96caba5 100644 --- a/src/litdata/streaming/dataloader.py +++ b/src/litdata/streaming/dataloader.py @@ -134,7 +134,7 @@ def __call__(self, items: List[Any]) -> Any: class _SingleProcessDataLoaderIterPatch(_SingleProcessDataLoaderIter): - """This is รง to inform the cache is done chunking.""" + """This is to inform the cache is done chunking.""" def _next_data(self) -> Any: try: diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index b8e3c456..f52848b9 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -33,10 +33,6 @@ from litdata.utilities.env import _DistributedEnv, _is_in_dataloader_worker, _WorkerEnv from litdata.utilities.format import _convert_bytes_to_int from litdata.utilities.hf_dataset import index_hf_dataset -from litdata.utilities.shuffle import ( - _find_chunks_per_workers_on_which_to_skip_deletion, - _map_node_worker_rank_to_chunk_indexes_to_not_delete, -) logger = logging.getLogger(__name__) @@ -253,6 +249,7 @@ def get_len(self, num_workers: int, batch_size: int) -> int: return self.shuffler.get_len(self.distributed_env, self.num_workers, self.batch_size, self.current_epoch) def __iter__(self) -> "StreamingDataset": + print("jai maata di, chalo shuru ho jao") # When the StreamingDataset is used within map or optimize, let's refetch the distributed env. if os.getenv("DATA_OPTIMIZER_GLOBAL_RANK"): self.distributed_env = _DistributedEnv.detect() @@ -285,28 +282,28 @@ def __iter__(self) -> "StreamingDataset": # Find the chunks shared across all workers of the current node. # For each shared chunk, find the rank and worker to use the chunk last and prevent # premature deletion for the other workers. 
- node_size = self.distributed_env.world_size // self.distributed_env.num_nodes - first_rank_this_node = (self.distributed_env.global_rank // node_size) * node_size - num_workers_per_node = node_size * self.num_workers - worker_start = first_rank_this_node * num_workers_per_node - worker_end = worker_start + num_workers_per_node - local_rank = self.distributed_env.global_rank % node_size - - chunks_indexes_skip_deletion = _find_chunks_per_workers_on_which_to_skip_deletion( - self.num_workers, - self.batch_size, - workers_chunks[worker_start:worker_end], - workers_intervals[worker_start:worker_end], - ) - worker_node_rank_to_chunk_indexes = _map_node_worker_rank_to_chunk_indexes_to_not_delete( - chunks_indexes_skip_deletion - ) - - worker_rank_local_node = local_rank * self.num_workers + self.worker_env.rank - if worker_rank_local_node in worker_node_rank_to_chunk_indexes: - self.cache._reader.config.skip_chunk_indexes_deletion = worker_node_rank_to_chunk_indexes[ - worker_rank_local_node - ] + # node_size = self.distributed_env.world_size // self.distributed_env.num_nodes + # first_rank_this_node = (self.distributed_env.global_rank // node_size) * node_size + # num_workers_per_node = node_size * self.num_workers + # worker_start = first_rank_this_node * num_workers_per_node + # worker_end = worker_start + num_workers_per_node + # local_rank = self.distributed_env.global_rank % node_size + + # chunks_indexes_skip_deletion = _find_chunks_per_workers_on_which_to_skip_deletion( + # self.num_workers, + # self.batch_size, + # workers_chunks[worker_start:worker_end], + # workers_intervals[worker_start:worker_end], + # ) + # worker_node_rank_to_chunk_indexes = _map_node_worker_rank_to_chunk_indexes_to_not_delete( + # chunks_indexes_skip_deletion + # ) + + # worker_rank_local_node = local_rank * self.num_workers + self.worker_env.rank + # if worker_rank_local_node in worker_node_rank_to_chunk_indexes: + # self.cache._reader.config.skip_chunk_indexes_deletion = worker_node_rank_to_chunk_indexes[ + # worker_rank_local_node + # ] self.num_chunks = len(self.worker_chunks) self.current_indexes = [] diff --git a/src/litdata/streaming/reader.py b/src/litdata/streaming/reader.py index bfc4e1cc..dd78d251 100644 --- a/src/litdata/streaming/reader.py +++ b/src/litdata/streaming/reader.py @@ -130,6 +130,7 @@ def _apply_delete(self, chunk_index: int) -> None: """Inform the item loader of the chunk to delete.""" # TODO: Fix the can_delete method can_delete_chunk = self._config.can_delete(chunk_index) + print(f"apply delete called -> {chunk_index} {can_delete_chunk=}; by {self._rank or 0}") chunk_filepath, _, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)] remaining_locks = self._remaining_locks(chunk_filepath) @@ -184,6 +185,12 @@ def _maybe_delete_chunks(self) -> None: return def _can_delete_chunk(self) -> bool: + print( + "can delete chunk called", + self._delete_chunks_when_processed, + self._pre_download_counter, + self._max_pre_download, + ) if self._delete_chunks_when_processed: return self._pre_download_counter >= self._max_pre_download - 1 return ( @@ -495,7 +502,7 @@ def _get_folder_size(path: str, config: ChunksConfig) -> int: f"Ignoring '{filename}': " "This file doesn't appear to be a valid chunk file and has been excluded from the size calculation." 
) - + print(f"Total size of files in '{path}': {size} bytes") return size diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py index a191a929..0059223a 100644 --- a/src/litdata/utilities/dataset_utilities.py +++ b/src/litdata/utilities/dataset_utilities.py @@ -4,6 +4,7 @@ import os import shutil import tempfile +import time from typing import Any, Dict, List, Optional, Tuple import numpy as np @@ -62,6 +63,8 @@ def subsample_streaming_dataset( downloader = get_downloader(input_dir.url, input_dir.path, [], storage_options) downloader.download_file(os.path.join(input_dir.url, _INDEX_FILENAME), cache_index_filepath) + time.sleep(0.5) # Give some time for the file to be created + if not os.path.exists(input_dir.path): raise FileNotFoundError(f"The provided dataset path `{input_dir.path}` does not exist.") diff --git a/src/litdata/utilities/env.py b/src/litdata/utilities/env.py index 74f72c46..30b27f1d 100644 --- a/src/litdata/utilities/env.py +++ b/src/litdata/utilities/env.py @@ -80,7 +80,10 @@ def _instantiate_in_map_or_optimize(cls) -> "_DistributedEnv": return cls(world_size=num_workers * num_nodes, global_rank=int(global_rank), num_nodes=num_nodes) def __repr__(self) -> str: - return f"{self.__class__.__name__}(world_size: {self.world_size}, global_rank: {self.global_rank}\n)" + return ( + f"{self.__class__.__name__}(world_size: {self.world_size}, global_rank: {self.global_rank}," + + f" num_nodes: {self.num_nodes})" + ) def __str__(self) -> str: return repr(self) From 8b99ba4de75e9afe56817e1f441f979527e0418e Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Sun, 6 Apr 2025 08:44:48 +0530 Subject: [PATCH 02/12] update --- src/litdata/streaming/dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index f52848b9..b702adfe 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -379,6 +379,7 @@ def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any: return self.cache[index] def __next__(self) -> Any: + print(f"next called in dataset, global_index: {self.global_index}, index: {self.index}") # Prevent to create more batch on a given process if self.global_index >= self.stop_length: self.current_epoch += 1 From f279b632026b2ecac2a740b6b3f5e974fe96485a Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Sun, 6 Apr 2025 09:58:38 +0530 Subject: [PATCH 03/12] rename variables, original name doesn't carry enough info --- src/litdata/streaming/dataset.py | 75 ++++++++++++++++++-------------- src/litdata/streaming/reader.py | 12 ++--- 2 files changed, 48 insertions(+), 39 deletions(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index b702adfe..4ca2535a 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -150,13 +150,15 @@ def __init__( self.cache: Optional[Cache] = None self.worker_env: Optional[_WorkerEnv] = None - self.worker_chunks: List[int] = [] - self.worker_intervals: List[List[int]] = [] - self.current_indexes: List[int] = [] - self.chunk_index = 0 - self.num_chunks: Optional[int] = None - self.global_index = 0 - self.index = 0 + self.worker_chunks: List[int] = [] # chunk indexes on which the current worker will work + self.worker_intervals: List[List[int]] = [] # chunk index intervals for the current worker + self.upcoming_indexes: List[int] = [] # contains list of upcoming indexes to be processed + self.worker_chunks_index = 0 # which index of the array 
`self.worker_chunks` we are currently working on + self.num_chunks: Optional[int] = None # total number of chunks that the current worker will work on + self.global_index = 0 # total number of samples processed by the current worker up until now + + # number of samples processed by the current worker in the current chunk + self.consumed_sample_count_in_curr_chunk = 0 self.has_triggered_download = False self.min_items_per_replica: Optional[int] = None self.current_epoch = 1 @@ -249,7 +251,6 @@ def get_len(self, num_workers: int, batch_size: int) -> int: return self.shuffler.get_len(self.distributed_env, self.num_workers, self.batch_size, self.current_epoch) def __iter__(self) -> "StreamingDataset": - print("jai maata di, chalo shuru ho jao") # When the StreamingDataset is used within map or optimize, let's refetch the distributed env. if os.getenv("DATA_OPTIMIZER_GLOBAL_RANK"): self.distributed_env = _DistributedEnv.detect() @@ -306,10 +307,10 @@ def __iter__(self) -> "StreamingDataset": # ] self.num_chunks = len(self.worker_chunks) - self.current_indexes = [] - self.chunk_index = 0 + self.upcoming_indexes = [] + self.worker_chunks_index = 0 self.global_index = 0 - self.index = 0 + self.consumed_sample_count_in_curr_chunk = 0 self.has_triggered_download = False self.last_time = time() @@ -344,25 +345,25 @@ def _resume(self, workers_chunks: List[List[int]], workers_intervals: List[Any]) worker_local_rank = self.worker_env.rank self.num_chunks = len(workers_intervals[worker_rank]) - self.chunk_index = chunks_index[worker_local_rank] + self.worker_chunks_index = chunks_index[worker_local_rank] self.worker_chunks = workers_chunks[worker_rank] self.worker_intervals = workers_intervals[worker_rank] # replay the indexes for the current chunks - interval = self.worker_intervals[self.chunk_index] + interval = self.worker_intervals[self.worker_chunks_index] current_indexes = np.arange(interval[1], interval[2]) # re-shuffle the indexes - current_indexes = self.shuffler(current_indexes, self.num_chunks, self.current_epoch, self.chunk_index) + current_indexes = self.shuffler(current_indexes, self.num_chunks, self.current_epoch, self.worker_chunks_index) # skip any indexes already consumed current_indexes = current_indexes[indexes[worker_local_rank] :] - self.current_indexes = current_indexes + self.upcoming_indexes = current_indexes self.global_index = indexes[worker_local_rank] # bump the chunk_index - self.chunk_index += 1 + self.worker_chunks_index += 1 def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any: if self.cache is None: @@ -379,50 +380,58 @@ def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any: return self.cache[index] def __next__(self) -> Any: - print(f"next called in dataset, global_index: {self.global_index}, index: {self.index}") # Prevent to create more batch on a given process if self.global_index >= self.stop_length: + # global_index: total number of samples processed by the current worker across all chunks + # stop_length: max number of samples that the current worker will process + # if they are equal, means, worker has processed all the chunks + print("dame tu cosita aha ah ah ah") self.current_epoch += 1 self.reset_state_dict() raise StopIteration # Lazily re-populate the interval to reduce memory usage. 
- if len(self.current_indexes) == 0: - if self.chunk_index == self.num_chunks: - self.current_epoch += 1 - self.reset_state_dict() - raise StopIteration + if len(self.upcoming_indexes) == 0: + # if upcoming_indexes is empty, means we have processed all the indexes in the current chunk + # we need to move to the next chunk + # we don't need to account for `what if it's the last chunk` + # bcoz in that case, `self.global_index >= self.stop_length` will be true (check above) - # reset index - self.index = 0 + # reset consumed_sample_count_in_curr_chunk as we are moving to the next chunk + self.consumed_sample_count_in_curr_chunk = 0 - interval = self.worker_intervals[self.chunk_index] + interval = self.worker_intervals[self.worker_chunks_index] current_indexes = np.arange(interval[1], interval[2]) assert self.shuffler is not None assert self.num_chunks is not None - self.current_indexes = self.shuffler(current_indexes, self.num_chunks, self.current_epoch, self.chunk_index) + self.upcoming_indexes = self.shuffler( + current_indexes, self.num_chunks, self.current_epoch, self.worker_chunks_index + ) - self.chunk_index += 1 + self.worker_chunks_index += 1 # Get the first index - index = self.current_indexes.pop(0) + index = self.upcoming_indexes.pop(0) # Call the `__getitem__` method. data = self.__getitem__( ChunkedIndex( index=index, - chunk_index=self.worker_chunks[self.chunk_index - 1], + chunk_index=self.worker_chunks[self.worker_chunks_index - 1], # We provide the chunks indexes only one the first - chunk_indexes=None if self.has_triggered_download else self.worker_chunks[self.chunk_index - 1 :], - is_last_index=(self.chunk_index) == len(self.worker_intervals) and len(self.current_indexes) == 0, + chunk_indexes=None + if self.has_triggered_download + else self.worker_chunks[self.worker_chunks_index - 1 :], + is_last_index=(self.worker_chunks_index) == len(self.worker_intervals) + and len(self.upcoming_indexes) == 0, ) ) self.has_triggered_download = True self.global_index += 1 - self.index += 1 - + self.consumed_sample_count_in_curr_chunk += 1 + print(f"data: {data}") return data def state_dict(self, num_samples_yielded: int, num_workers: int, batch_size: int) -> Dict[str, Any]: diff --git a/src/litdata/streaming/reader.py b/src/litdata/streaming/reader.py index dd78d251..3bc8bf03 100644 --- a/src/litdata/streaming/reader.py +++ b/src/litdata/streaming/reader.py @@ -185,12 +185,12 @@ def _maybe_delete_chunks(self) -> None: return def _can_delete_chunk(self) -> bool: - print( - "can delete chunk called", - self._delete_chunks_when_processed, - self._pre_download_counter, - self._max_pre_download, - ) + # print( + # "can delete chunk called", + # self._delete_chunks_when_processed, + # self._pre_download_counter, + # self._max_pre_download, + # ) if self._delete_chunks_when_processed: return self._pre_download_counter >= self._max_pre_download - 1 return ( From 51a37fcbf5293a899e09473c418839e8966bd6de Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Sun, 6 Apr 2025 10:39:52 +0530 Subject: [PATCH 04/12] some more updates --- src/litdata/streaming/dataset.py | 48 ++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 4ca2535a..cc40ecfe 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -153,7 +153,10 @@ def __init__( self.worker_chunks: List[int] = [] # chunk indexes on which the current worker will work self.worker_intervals: List[List[int]] = 
[] # chunk index intervals for the current worker self.upcoming_indexes: List[int] = [] # contains list of upcoming indexes to be processed - self.worker_chunks_index = 0 # which index of the array `self.worker_chunks` we are currently working on + + # which index of the array `self.worker_chunks` will we work on after this chunk is completely consumed + self.worker_next_chunks_index = 0 + self.num_chunks: Optional[int] = None # total number of chunks that the current worker will work on self.global_index = 0 # total number of samples processed by the current worker up until now @@ -308,7 +311,7 @@ def __iter__(self) -> "StreamingDataset": self.num_chunks = len(self.worker_chunks) self.upcoming_indexes = [] - self.worker_chunks_index = 0 + self.worker_next_chunks_index = 0 self.global_index = 0 self.consumed_sample_count_in_curr_chunk = 0 @@ -345,16 +348,18 @@ def _resume(self, workers_chunks: List[List[int]], workers_intervals: List[Any]) worker_local_rank = self.worker_env.rank self.num_chunks = len(workers_intervals[worker_rank]) - self.worker_chunks_index = chunks_index[worker_local_rank] + self.worker_next_chunks_index = chunks_index[worker_local_rank] self.worker_chunks = workers_chunks[worker_rank] self.worker_intervals = workers_intervals[worker_rank] # replay the indexes for the current chunks - interval = self.worker_intervals[self.worker_chunks_index] + interval = self.worker_intervals[self.worker_next_chunks_index] current_indexes = np.arange(interval[1], interval[2]) # re-shuffle the indexes - current_indexes = self.shuffler(current_indexes, self.num_chunks, self.current_epoch, self.worker_chunks_index) + current_indexes = self.shuffler( + current_indexes, self.num_chunks, self.current_epoch, self.worker_next_chunks_index + ) # skip any indexes already consumed current_indexes = current_indexes[indexes[worker_local_rank] :] @@ -363,7 +368,7 @@ def _resume(self, workers_chunks: List[List[int]], workers_intervals: List[Any]) self.global_index = indexes[worker_local_rank] # bump the chunk_index - self.worker_chunks_index += 1 + self.worker_next_chunks_index += 1 def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any: if self.cache is None: @@ -380,7 +385,7 @@ def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any: return self.cache[index] def __next__(self) -> Any: - # Prevent to create more batch on a given process + # check if we have reached the end of the dataset (i.e., all the chunks have been processed) if self.global_index >= self.stop_length: # global_index: total number of samples processed by the current worker across all chunks # stop_length: max number of samples that the current worker will process @@ -392,24 +397,31 @@ def __next__(self) -> Any: # Lazily re-populate the interval to reduce memory usage. 
if len(self.upcoming_indexes) == 0: - # if upcoming_indexes is empty, means we have processed all the indexes in the current chunk - # we need to move to the next chunk + # if upcoming_indexes is empty, means either: + # - it's the start, or, + # - we have processed all the indexes in the current chunk + # + # we need to move to the next chunk (or first chunk if it's the start) # we don't need to account for `what if it's the last chunk` # bcoz in that case, `self.global_index >= self.stop_length` will be true (check above) + if self.worker_next_chunks_index >= self.num_chunks: + raise ValueError("should not have happened!") - # reset consumed_sample_count_in_curr_chunk as we are moving to the next chunk + # reset consumed_sample_count_in_curr_chunk as we are switching to a new chunk self.consumed_sample_count_in_curr_chunk = 0 - interval = self.worker_intervals[self.worker_chunks_index] + # `next_worker_chunks_index` is the index of the chunk that we will be working on now + interval = self.worker_intervals[self.worker_next_chunks_index] + current_indexes = np.arange(interval[1], interval[2]) assert self.shuffler is not None assert self.num_chunks is not None self.upcoming_indexes = self.shuffler( - current_indexes, self.num_chunks, self.current_epoch, self.worker_chunks_index + current_indexes, self.num_chunks, self.current_epoch, self.worker_next_chunks_index ) - self.worker_chunks_index += 1 + self.worker_next_chunks_index += 1 # bump the chunk_index # Get the first index index = self.upcoming_indexes.pop(0) @@ -418,19 +430,19 @@ def __next__(self) -> Any: data = self.__getitem__( ChunkedIndex( index=index, - chunk_index=self.worker_chunks[self.worker_chunks_index - 1], + chunk_index=self.worker_chunks[self.worker_next_chunks_index - 1], # We provide the chunks indexes only one the first chunk_indexes=None if self.has_triggered_download - else self.worker_chunks[self.worker_chunks_index - 1 :], - is_last_index=(self.worker_chunks_index) == len(self.worker_intervals) + else self.worker_chunks[self.worker_next_chunks_index - 1 :], + is_last_index=(self.worker_next_chunks_index) == len(self.worker_intervals) and len(self.upcoming_indexes) == 0, ) ) self.has_triggered_download = True - self.global_index += 1 - self.consumed_sample_count_in_curr_chunk += 1 + self.global_index += 1 # total number of samples processed by the current worker + self.consumed_sample_count_in_curr_chunk += 1 # number of samples processed in the current chunk print(f"data: {data}") return data From 36e75c7430e4405b48822307f15cf0a3757abb29 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Sun, 6 Apr 2025 11:14:06 +0530 Subject: [PATCH 05/12] something is wrong, I can feel it. --- src/litdata/streaming/reader.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/litdata/streaming/reader.py b/src/litdata/streaming/reader.py index 3bc8bf03..55cf6579 100644 --- a/src/litdata/streaming/reader.py +++ b/src/litdata/streaming/reader.py @@ -82,6 +82,7 @@ def __init__( def download(self, chunk_indexes: List[int]) -> None: """Receive the list of the chunk indices to download for the current epoch.""" + print(f"thread: got indexes to download -> {chunk_indexes=};") for chunk_index in chunk_indexes: self._to_download_queue.put(chunk_index) @@ -353,6 +354,7 @@ def read(self, index: ChunkedIndex) -> Any: Prefetching should reduce the wait time to be the batch available. """ + print(f"reader read called -> {index=}") if not isinstance(index, ChunkedIndex): raise ValueError("The Reader.read(...) 
method expects a chunked Index.") From e30b96926c4519be69946874d9d1559e703c37f8 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Mon, 7 Apr 2025 15:27:07 +0530 Subject: [PATCH 06/12] works for multiple worker and sometimes in ddp too --- src/litdata/constants.py | 1 + src/litdata/streaming/config.py | 28 ++- src/litdata/streaming/dataset.py | 9 +- src/litdata/streaming/downloader.py | 12 - src/litdata/streaming/item_loader.py | 17 +- src/litdata/streaming/reader.py | 350 ++++++++++++++++----------- src/litdata/utilities/file_utils.py | 58 +++++ 7 files changed, 312 insertions(+), 163 deletions(-) create mode 100644 src/litdata/utilities/file_utils.py diff --git a/src/litdata/constants.py b/src/litdata/constants.py index 018271a5..5bbc489c 100644 --- a/src/litdata/constants.py +++ b/src/litdata/constants.py @@ -25,6 +25,7 @@ _DEFAULT_CACHE_DIR = os.path.join(Path.home(), ".lightning", "chunks") _DEFAULT_LIGHTNING_CACHE_DIR = os.path.join("/cache", "chunks") _SUPPORTED_PROVIDERS = ("s3", "gs") # cloud providers supported by litdata for uploading (optimize, map, merge, etc) +_SUPPORTED_DOWNLOADERS = ("s3", "gs", "azure", "hf") # cloud providers supported by litdata for streaming datasets # This is required for full pytree serialization / deserialization support _TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0") diff --git a/src/litdata/streaming/config.py b/src/litdata/streaming/config.py index 833824f4..1ff15923 100644 --- a/src/litdata/streaming/config.py +++ b/src/litdata/streaming/config.py @@ -15,7 +15,7 @@ from collections import defaultdict from typing import Any, Dict, List, Optional, Tuple -from litdata.constants import _INDEX_FILENAME +from litdata.constants import _INDEX_FILENAME, _SUPPORTED_DOWNLOADERS from litdata.streaming.compression import _COMPRESSORS, Compressor from litdata.streaming.downloader import get_downloader from litdata.streaming.item_loader import BaseItemLoader, Interval, PyTreeLoader, TokensLoader @@ -23,6 +23,7 @@ from litdata.streaming.serializers import Serializer from litdata.utilities._pytree import tree_unflatten, treespec_loads from litdata.utilities.dataset_utilities import load_index_file +from litdata.utilities.file_utils import increment_file_count class ChunksConfig: @@ -117,7 +118,7 @@ def skip_chunk_indexes_deletion(self) -> Optional[List[int]]: def skip_chunk_indexes_deletion(self, skip_chunk_indexes_deletion: List[int]) -> None: self._skip_chunk_indexes_deletion = skip_chunk_indexes_deletion - def download_chunk_from_index(self, chunk_index: int, skip_lock: bool = False) -> None: + def download_chunk_from_index(self, chunk_index: int, rank: int = 0) -> None: assert self._chunks is not None chunk_filename = self._chunks[chunk_index]["filename"] @@ -125,22 +126,27 @@ def download_chunk_from_index(self, chunk_index: int, skip_lock: bool = False) - if os.path.exists(local_chunkpath): self.try_decompress(local_chunkpath) - if self._downloader is not None and not skip_lock: + if self._downloader is not None and self._remote_dir.startswith(_SUPPORTED_DOWNLOADERS): # We don't want to redownload the base, but we should mark # it as having been requested by something - self._downloader._increment_local_lock(local_chunkpath.replace(f".{self._compressor_name}", "")) - pass + count = increment_file_count(local_chunkpath.replace(f".{self._compressor_name}", ""), rank) + if count == 1: + # weird, shouldn't happen + # but if it does, we should start downloading the file + self._downloader.download_chunk_from_index(chunk_index) + 
self.try_decompress(local_chunkpath) return - if self._downloader is None: + if (self._downloader is None) or (not self._remote_dir.startswith(_SUPPORTED_DOWNLOADERS)): return - if not skip_lock: - self._downloader._increment_local_lock(local_chunkpath.replace(f".{self._compressor_name}", "")) + curr_count = increment_file_count(local_chunkpath.replace(f".{self._compressor_name}", ""), rank) - self._downloader.download_chunk_from_index(chunk_index) - - self.try_decompress(local_chunkpath) + if curr_count == 1: + # this is the first time we are downloading this file + # so we should download it + self._downloader.download_chunk_from_index(chunk_index) + self.try_decompress(local_chunkpath) def try_decompress(self, local_chunkpath: str) -> None: if self._compressor is None: diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index cc40ecfe..a39b16a5 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -273,9 +273,13 @@ def __iter__(self) -> "StreamingDataset": ) worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank + if worker_rank == 0: + print(f"workers_chunks: {workers_chunks}\nworkers_intervals: {workers_intervals}") self.worker_chunks = workers_chunks[worker_rank] self.worker_intervals = workers_intervals[worker_rank] + print("-" * 50 + "\n" + f"{worker_rank=}; {self.worker_chunks=}; {self.worker_intervals=}\n" + "-" * 50) + # The max number of samples to return from `__next__` (in worker) self.stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals) @@ -317,6 +321,7 @@ def __iter__(self) -> "StreamingDataset": self.has_triggered_download = False self.last_time = time() + self.cache._reader.prepare_downloader_thread(self.worker_chunks) return self @@ -435,15 +440,13 @@ def __next__(self) -> Any: chunk_indexes=None if self.has_triggered_download else self.worker_chunks[self.worker_next_chunks_index - 1 :], - is_last_index=(self.worker_next_chunks_index) == len(self.worker_intervals) - and len(self.upcoming_indexes) == 0, + is_last_index=(self.worker_next_chunks_index) == self.num_chunks and len(self.upcoming_indexes) == 0, ) ) self.has_triggered_download = True self.global_index += 1 # total number of samples processed by the current worker self.consumed_sample_count_in_curr_chunk += 1 # number of samples processed in the current chunk - print(f"data: {data}") return data def state_dict(self, num_samples_yielded: int, num_workers: int, batch_size: int) -> Dict[str, Any]: diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index eb5e501e..58f02f55 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -42,18 +42,6 @@ def __init__( self._chunks = chunks self._storage_options = storage_options or {} - def _increment_local_lock(self, chunkpath: str) -> None: - countpath = chunkpath + ".cnt" - with suppress(Timeout), FileLock(countpath + ".lock", timeout=1): - try: - with open(countpath) as count_f: - curr_count = int(count_f.read().strip()) - except Exception: - curr_count = 0 - curr_count += 1 - with open(countpath, "w+") as count_f: - count_f.write(str(curr_count)) - def download_chunk_from_index(self, chunk_index: int) -> None: chunk_filename = self._chunks[chunk_index]["filename"] local_chunkpath = os.path.join(self._cache_dir, chunk_filename) diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index f2a779e6..25140949 100644 --- 
a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -35,6 +35,7 @@ from litdata.streaming.serializers import Serializer from litdata.utilities._pytree import PyTree, tree_unflatten from litdata.utilities.encryption import Encryption, EncryptionLevel +from litdata.utilities.file_utils import decrement_file_count Interval = namedtuple("Interval", ["chunk_start", "roi_start_idx", "roi_end_idx", "chunk_end"]) @@ -110,6 +111,16 @@ def load_item_from_chunk( def delete(self, chunk_index: int, chunk_filepath: str) -> None: """Delete a chunk from the local filesystem.""" + def safe_delete(self, chunk_index: int, chunk_filepath: str, rank: int) -> None: + """Decrement the file count and delete the chunk file if the count reaches 0.""" + if os.path.exists(chunk_filepath + ".cnt"): + curr_count = decrement_file_count(chunk_filepath, rank) + + if curr_count == 0: + self.delete(chunk_index, chunk_filepath) + else: + self.delete(chunk_index, chunk_filepath) + @abstractmethod def encode_data(self, data: List[bytes], sizes: List[int], flattened: List[Any]) -> Any: pass @@ -250,7 +261,8 @@ def _load_data(self, fp: Union[FileIO, BytesIO], offset: int) -> bytes: pair = fp.read(8) begin, end = np.frombuffer(pair, np.uint32) - fp.seek(begin) # move the file pointer to the offset_start where the item starts + # move the file pointer to the offset_start where the item starts + fp.seek(begin) return fp.read(end - begin) # read the item def mds_deserialize(self, raw_item_data: bytes, chunk_index: int) -> "PyTree": @@ -675,7 +687,8 @@ def _get_item_with_low_memory(self, chunk_index: int, chunk_filepath: str, row_i # Return the specific row from the dataframe # Note: The `named=True` argument is used to return the row as a dictionary - return row_group_df.row(row_index_within_group, named=True) # type: ignore + # type: ignore + return row_group_df.row(row_index_within_group, named=True) def _get_item(self, chunk_index: int, chunk_filepath: str, index: int) -> Any: """Retrieve a dataframe row from a parquet chunk by loading the entire chunk into memory. 
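
The counter-file mechanism wired in above (increment on download in config.py, decrement-then-delete in the new safe_delete) amounts to per-chunk reference counting shared across worker processes. A minimal sketch of that flow, using the helpers added in litdata/utilities/file_utils.py; the acquire_chunk/release_chunk wrappers and the bare download/os.remove calls are hypothetical stand-ins for the downloader and item-loader code paths, not part of the patch:

    # Reference-counting sketch: the first worker to request a chunk downloads it,
    # the last worker to release it deletes it. Counts live in "<chunk>.cnt" files
    # guarded by a FileLock, so the coordination works across processes.
    import os
    from typing import Callable

    from litdata.utilities.file_utils import decrement_file_count, increment_file_count


    def acquire_chunk(chunk_filepath: str, download: Callable[[str], None]) -> None:
        # increment_file_count returns the new count; 1 means no other worker has
        # requested this chunk yet, so this worker performs the actual download.
        if increment_file_count(chunk_filepath) == 1:
            download(chunk_filepath)


    def release_chunk(chunk_filepath: str) -> None:
        # decrement_file_count removes the ".cnt" file and returns 0 once the last
        # reference is gone; only then is the chunk file itself deleted.
        if decrement_file_count(chunk_filepath) == 0 and os.path.exists(chunk_filepath):
            os.remove(chunk_filepath)

This is the same increment/decrement pairing that download_chunk_from_index and safe_delete implement in the hunks above and below.
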
diff --git a/src/litdata/streaming/reader.py b/src/litdata/streaming/reader.py index 55cf6579..7d68931a 100644 --- a/src/litdata/streaming/reader.py +++ b/src/litdata/streaming/reader.py @@ -15,13 +15,10 @@ import logging import os import warnings -from contextlib import suppress from queue import Empty, Queue from threading import Event, Thread from typing import Any, Dict, List, Optional, Tuple, Union -from filelock import FileLock, Timeout - from litdata.constants import _DEBUG from litdata.streaming.config import ChunksConfig, Interval from litdata.streaming.item_loader import BaseItemLoader, ParquetLoader, PyTreeLoader, TokensLoader @@ -51,6 +48,7 @@ def __init__( self, config: ChunksConfig, item_loader: BaseItemLoader, + chunks_order: List[int], distributed_env: _DistributedEnv, max_cache_size: Optional[int] = None, max_pre_download: int = 2, @@ -59,15 +57,19 @@ def __init__( super().__init__(daemon=True) self._config = config self._item_loader = item_loader + self._chunks_order = chunks_order # order in which chunks are to be downloaded + self.current_downloading_chunk_index = -1 + self.current_reading_chunk_index = -1 self._max_pre_download = max_pre_download self._pre_download_counter = 0 self._distributed_env = distributed_env - self._chunks_index_to_be_deleted: List[int] = [] + # self._chunks_index_to_be_deleted: List[int] = [] self._max_cache_size = max_cache_size self._parent_cache_dir = os.path.dirname(self._config._cache_dir) self._to_download_queue: Queue = Queue() self._to_delete_queue: Queue = Queue() + self._delete_queue_received_none: bool = False self._force_stop_event = Event() # TODO: Find a real fix to this problem @@ -78,127 +80,183 @@ def __init__( # Check whether a dataset slice fits on the node num_bytes_per_nodes = self._config.num_bytes // self._distributed_env.num_nodes self._delete_chunks_when_processed = num_bytes_per_nodes > max_cache_size if max_cache_size else False + + # if self._delete_chunks_when_processed: + # print(f"clearing cache dir {self._parent_cache_dir} because the dataset is too large to fit in memory") + # # means we can't keep all chunks in the cache directory, so we should clear it to minimize the size + # # clear the cache directory except the index.json file + # for root, _, files in os.walk(self._parent_cache_dir): + # for file in files: + # if file != _INDEX_FILENAME: + # with contextlib.suppress(FileNotFoundError): + # os.remove(os.path.join(root, file)) self._has_exited = False - def download(self, chunk_indexes: List[int]) -> None: - """Receive the list of the chunk indices to download for the current epoch.""" - print(f"thread: got indexes to download -> {chunk_indexes=};") - for chunk_index in chunk_indexes: - self._to_download_queue.put(chunk_index) + # def download(self, chunk_indexes: List[int]) -> None: + # """Receive the list of the chunk indices to download for the current epoch.""" + # print(f"thread: got indexes to download -> {chunk_indexes=};") + # for chunk_index in chunk_indexes: + # self._to_download_queue.put(chunk_index) def delete(self, chunk_indexes: List[int]) -> None: """Receive the list of the chunk indices to delete for the current epoch.""" for chunk_index in chunk_indexes: + print(f"โšก๏ธ {self._rank=} asked to delete chunk {chunk_index=}") self._to_delete_queue.put(chunk_index) - def _remaining_locks(self, chunkpath: str) -> int: - countpath = chunkpath + ".cnt" - if not os.path.exists(countpath): - return 0 - with open(countpath) as count_f: - try: - return int(count_f.read().strip()) - except Exception: 
- return 1 - - def _decrement_local_lock(self, chunk_index: int) -> int: - """Remove a count from the local lock, return the remaining count.""" - chunk_filepath, _, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)] - - countpath = chunk_filepath + ".cnt" - with suppress(Timeout), FileLock(countpath + ".lock", timeout=3): - if not os.path.exists(countpath): - return 0 - with open(countpath) as count_f: - try: - curr_count = int(count_f.read().strip()) - except Exception: - curr_count = 1 - curr_count -= 1 - if curr_count <= 0: - with contextlib.suppress(FileNotFoundError, PermissionError): - os.remove(countpath) - - with contextlib.suppress(FileNotFoundError, PermissionError): - os.remove(countpath + ".lock") - else: - with open(countpath, "w+") as count_f: - count_f.write(str(curr_count)) - return curr_count - return 0 + # def _remaining_locks(self, chunkpath: str) -> int: + # countpath = chunkpath + ".cnt" + # if not os.path.exists(countpath): + # return 0 + # with open(countpath) as count_f: + # try: + # return int(count_f.read().strip()) + # except Exception: + # return 1 + + # def _decrement_local_lock(self, chunk_index: int) -> int: + # """Remove a count from the local lock, return the remaining count.""" + # chunk_filepath, _, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)] + + # countpath = chunk_filepath + ".cnt" + # with suppress(Timeout), FileLock(countpath + ".lock", timeout=3): + # if not os.path.exists(countpath): + # return 0 + # with open(countpath) as count_f: + # try: + # curr_count = int(count_f.read().strip()) + # except Exception: + # curr_count = 1 + # curr_count -= 1 + # if curr_count <= 0: + # with contextlib.suppress(FileNotFoundError, PermissionError): + # os.remove(countpath) + + # with contextlib.suppress(FileNotFoundError, PermissionError): + # os.remove(countpath + ".lock") + # else: + # with open(countpath, "w+") as count_f: + # count_f.write(str(curr_count)) + # return curr_count + # return 0 def _apply_delete(self, chunk_index: int) -> None: """Inform the item loader of the chunk to delete.""" # TODO: Fix the can_delete method - can_delete_chunk = self._config.can_delete(chunk_index) - print(f"apply delete called -> {chunk_index} {can_delete_chunk=}; by {self._rank or 0}") + # can_delete_chunk = self._config.can_delete(chunk_index) + # print(f"apply delete called -> {chunk_index} {can_delete_chunk=}; by {self._rank or 0}") chunk_filepath, _, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)] - remaining_locks = self._remaining_locks(chunk_filepath) - if remaining_locks > 0: # Can't delete this, something has it - if _DEBUG: - print(f"Skip delete {chunk_filepath} by {self._rank or 0}, current lock count: {remaining_locks}") - return + # remaining_locks = self._remaining_locks(chunk_filepath) + # if remaining_locks > 0: # Can't delete this, something has it + # if _DEBUG: + # print(f"Skip delete {chunk_filepath} by {self._rank or 0}, current lock count: {remaining_locks}") + # return - if _DEBUG: - with open(chunk_filepath + ".tmb", "w+") as tombstone_file: - tombstone_file.write(f"Deleted {chunk_filepath} by {self._rank or 0}. Debug: {can_delete_chunk}") + # if _DEBUG: + # with open(chunk_filepath + ".tmb", "w+") as tombstone_file: + # tombstone_file.write(f"Deleted {chunk_filepath} by {self._rank or 0}. 
Debug: {can_delete_chunk}") - self._item_loader.delete(chunk_index, chunk_filepath) + self._item_loader.safe_delete(chunk_index, chunk_filepath, self._rank) - if _DEBUG: - print(f"Deleted {chunk_filepath} by {self._rank or 0}. Debug: {can_delete_chunk}") + # if _DEBUG: + # print(f"Deleted {chunk_filepath} by {self._rank or 0}. Debug: {can_delete_chunk}") - for lock_extension in [".lock", ".cnt.lock"]: - try: - locak_chunk_path = chunk_filepath + lock_extension - if os.path.exists(locak_chunk_path): - os.remove(locak_chunk_path) - except FileNotFoundError: - pass + # for lock_extension in [".lock", ".cnt.lock"]: + # try: + # locak_chunk_path = chunk_filepath + lock_extension + # if os.path.exists(locak_chunk_path): + # os.remove(locak_chunk_path) + # except FileNotFoundError: + # pass def stop(self) -> None: """Receive the list of the chunk indices to download for the current epoch.""" - self._to_download_queue.put(_END_TOKEN) + # self._to_download_queue.put(_END_TOKEN) + if self._delete_chunks_when_processed and not self._delete_queue_received_none: + # for chnk_idx in self._chunks_index_to_be_deleted: + # self._apply_delete(chnk_idx) + # read from delete queue until None is received and delete the chunks + total_waiting_time = 0 + print(f"{self._rank=} stopping prepare_chunks_thread") + while True: + try: + chunk_index = self._to_delete_queue.get(timeout=_DEFAULT_TIMEOUT) + if chunk_index is None: + break + self._apply_delete(chunk_index) + total_waiting_time = 0 + except Empty: + total_waiting_time += _DEFAULT_TIMEOUT + if total_waiting_time > _LONG_DEFAULT_TIMEOUT: + print("Timeout waiting for delete queue to be empty (None)") + break + + # extra cleanup + # if self._delete_chunks_when_processed: + # # clear the cache directory (except the index.json file) + # for root, _, files in os.walk(self._parent_cache_dir): + # for file in files: + # if file != _INDEX_FILENAME: + # with contextlib.suppress(FileNotFoundError): + # os.remove(os.path.join(root, file)) + self.force_stop() def force_stop(self) -> None: self._force_stop_event.set() def _maybe_delete_chunks(self) -> None: - reached_pre_download = self._pre_download_counter == self._max_pre_download + # reached_pre_download = self._pre_download_counter == self._max_pre_download + # reached_max_pre_download = ( + # self.current_downloading_chunk_index < self.current_reading_chunk_index + self._max_pre_download + # ) + + # should_start_deleting_chunks = self._can_delete_chunk() + # if not should_start_deleting_chunks: + # return # we have already pre-downloaded some chunks, we just need to wait for them to be processed. - chunk_index = _get_from_queue( - self._to_delete_queue, timeout=_LONG_DEFAULT_TIMEOUT if reached_pre_download else _DEFAULT_TIMEOUT - ) + while True: + try: + chunk_index_to_be_deleted = _get_from_queue(self._to_delete_queue, timeout=_DEFAULT_TIMEOUT) - if chunk_index is not None: - self._pre_download_counter -= 1 + if chunk_index_to_be_deleted is None: + self._delete_queue_received_none = True + return + # self._pre_download_counter -= 1 - # Store the current chunk index - self._chunks_index_to_be_deleted.append(chunk_index) + # Store the current chunk index + # self._chunks_index_to_be_deleted.append(chunk_index) - # Get the current cache size and decide whether we need to start cleanup. 
Otherwise, keep track of it - while self._max_cache_size and self._chunks_index_to_be_deleted and self._can_delete_chunk(): - # Delete the oldest chunk - self._apply_delete(self._chunks_index_to_be_deleted.pop(0)) + # Get the current cache size and decide whether we need to start cleanup. Otherwise, keep track of it + # Delete the oldest chunk + self._apply_delete(chunk_index_to_be_deleted) + except Empty: + # Timeout waiting for delete queue to be empty + # print(f"Timeout waiting for delete queue to be empty (None)") + break + except Exception as e: + raise RuntimeError(f"Error while deleting chunks: {e}") from e return def _can_delete_chunk(self) -> bool: - # print( - # "can delete chunk called", - # self._delete_chunks_when_processed, - # self._pre_download_counter, - # self._max_pre_download, - # ) if self._delete_chunks_when_processed: - return self._pre_download_counter >= self._max_pre_download - 1 + # return self._pre_download_counter >= self._max_pre_download - 1 + # if we have downloaded all chunks, we can delete the oldest one + if self.current_downloading_chunk_index == len(self._chunks_order) - 1: + return True + return self.current_downloading_chunk_index >= (self.current_reading_chunk_index + self._max_pre_download) + + return False # if complete dataset can be stored in the cache, we don't need to delete any chunk return ( self._max_cache_size is not None and _get_folder_size(self._config._cache_dir, self._config) >= self._max_cache_size ) + def _can_download_chunk(self) -> bool: + return not self._can_delete_chunk() + def _pre_load_chunk(self, chunk_index: int) -> None: chunk_filepath, _, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)] self._item_loader.pre_load_chunk(chunk_index, chunk_filepath) @@ -210,7 +268,8 @@ def _force_download(self) -> None: chunk_filepath, _, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)] print(f"Requested force download for {chunk_filepath} by {self._rank}") - self._config.download_chunk_from_index(chunk_index, skip_lock=True) + # skip counter_file logic and directly download the chunk + self._config._downloader.download_chunk_from_index(chunk_index) # Preload item if possible to gain some time but only # if this is one of the pre-downloaded chunk @@ -228,25 +287,26 @@ def run(self) -> None: self._force_download() - if self._pre_download_counter < self._max_pre_download: - chunk_index = _get_from_queue(self._to_download_queue) - if chunk_index == _END_TOKEN: - self._has_exited = True - return + if self._can_download_chunk(): + self.current_downloading_chunk_index += 1 + # chunk_index = _get_from_queue(self._to_download_queue) + chunk_index = self._chunks_order[self.current_downloading_chunk_index] + # if chunk_index == _END_TOKEN: + # self._has_exited = True + # return if chunk_index is not None: - self._config.download_chunk_from_index(chunk_index) + self._config.download_chunk_from_index(chunk_index, self._rank) # Preload item if possible to gain some time but only # if this is one of the pre-downloaded chunk - if self._pre_download_counter > 0: - self._pre_load_chunk(chunk_index) + # if self._pre_download_counter > 0: + # self._pre_load_chunk(chunk_index) # Avoid downloading too many chunks in advance at the risk of over using the disk space - self._pre_download_counter += 1 + # self._pre_download_counter += 1 - if self._max_cache_size: - self._maybe_delete_chunks() + self._maybe_delete_chunks() # The BinaryReader operates as the inverse of the data optimization process: @@ -332,6 +392,24 @@ def 
_try_load_config(self) -> Optional[ChunksConfig]: ) return self._config + def prepare_downloader_thread(self, chunks_order: List[int]) -> None: + """Prepare the downloader thread and start downloading the first few chunks.""" + if self._config is None and self._try_load_config() is None: + raise Exception("The reader index isn't defined.") + + # Create and start the prepare chunks thread + if self._prepare_thread is None: + self._prepare_thread = PrepareChunksThread( + config=self._config, + item_loader=self._item_loader, + chunks_order=chunks_order, + distributed_env=self._distributed_env, + max_cache_size=self._max_cache_size, + max_pre_download=self._max_pre_download, + rank=self.rank, + ) + self._prepare_thread.start() + @property def config(self) -> ChunksConfig: if self._config is None: @@ -354,41 +432,45 @@ def read(self, index: ChunkedIndex) -> Any: Prefetching should reduce the wait time to be the batch available. """ - print(f"reader read called -> {index=}") if not isinstance(index, ChunkedIndex): raise ValueError("The Reader.read(...) method expects a chunked Index.") + if self._config is None or self._prepare_thread is None: + raise Exception( + "Reader's downloading thread is not started. Please call `reader.prepare_downloader_thread()` first." + ) + # Load the config containing the index - if self._config is None and self._try_load_config() is None: - raise Exception("The reader index isn't defined.") + # if self._config is None and self._try_load_config() is None: + # raise Exception("The reader index isn't defined.") - if self._config and (self._config._remote_dir or self._config._compressor): + if self._config and (self._config._remote_dir or self._config._compressor): # noqa: SIM102 # Create and start the prepare chunks thread - if self._prepare_thread is None and self._config: - self._prepare_thread = PrepareChunksThread( - self._config, - self._item_loader, - self._distributed_env, - self._max_cache_size, - self._max_pre_download, - self._rank, - ) - # Attach the force download queue - self._item_loader._force_download_queue = self._prepare_thread._force_download_queue # type: ignore - self._prepare_thread.start() - if index.chunk_indexes: - self._prepare_thread.download(index.chunk_indexes) - self._chunks_queued_for_download = True + # if self._prepare_thread is None and self._config: + # self._prepare_thread = PrepareChunksThread( + # self._config, + # self._item_loader, + # self._distributed_env, + # self._max_cache_size, + # self._max_pre_download, + # self._rank, + # ) + # # Attach the force download queue + # self._item_loader._force_download_queue = self._prepare_thread._force_download_queue # type: ignore + # self._prepare_thread.start() + # if index.chunk_indexes: + # self._prepare_thread.download(index.chunk_indexes) + # self._chunks_queued_for_download = True # Only request individual chunk download if: # 1. We haven't already queued all chunks for the download # 2. 
We're processing a new chunk (different from the last one) - if not self._chunks_queued_for_download and index.chunk_index != self._last_chunk_index: - assert self._prepare_thread - self._prepare_thread.download([index.chunk_index]) + # if not self._chunks_queued_for_download and index.chunk_index != self._last_chunk_index: + # assert self._prepare_thread + # self._prepare_thread.download([index.chunk_index]) - if self._last_chunk_index is None: - self._last_chunk_index = index.chunk_index + if self._last_chunk_index is None or index.chunk_index != self._last_chunk_index: + self._prepare_thread.current_reading_chunk_index += 1 # Fetch the element chunk_filepath, begin, filesize_bytes = self.config[index] @@ -404,32 +486,30 @@ def read(self, index: ChunkedIndex) -> Any: # We need to request deletion after the latest element has been loaded. # Otherwise, this could trigger segmentation fault error depending on the item loader used. - if ( - self._config - and (self._config._remote_dir or self._config._compressor) - and index.chunk_index != self._last_chunk_index - ): - assert self._prepare_thread - assert self._last_chunk_index is not None - - # inform the chunk has been completely consumed - self._prepare_thread._decrement_local_lock(self._last_chunk_index) - self._prepare_thread.delete([self._last_chunk_index]) - - if index.chunk_index != self._last_chunk_index: + if self._last_chunk_index is None or index.chunk_index != self._last_chunk_index: # Close the memory-mapped file for the last chunk index if isinstance(self._item_loader, (TokensLoader, ParquetLoader)) and self._last_chunk_index is not None: self._item_loader.close(self._last_chunk_index) + if self._config and (self._config._remote_dir or self._config._compressor): + assert self._prepare_thread + if self._last_chunk_index is not None: + # inform the chunk has been completely consumed + # self._prepare_thread._decrement_local_lock(self._last_chunk_index) + self._prepare_thread.delete([self._last_chunk_index]) + # track the new chunk index as the latest one self._last_chunk_index = index.chunk_index if index.is_last_index and self._prepare_thread: # inform the thread it is time to stop - self._prepare_thread._decrement_local_lock(index.chunk_index) + # self._prepare_thread._decrement_local_lock(index.chunk_index) + self._item_loader.close(self._last_chunk_index) + print(f"๐Ÿ˜ˆ it's last index of this chunk, sent it deleting: {index.chunk_index}") + self._prepare_thread.delete([index.chunk_index]) # send this chunk for deletion + self._prepare_thread._to_delete_queue.put(None) # signal the end of the queue self._prepare_thread.stop() self._prepare_thread = None - self._item_loader.close(self._last_chunk_index) self._last_chunk_index = None self._chunks_queued_for_download = False diff --git a/src/litdata/utilities/file_utils.py b/src/litdata/utilities/file_utils.py new file mode 100644 index 00000000..e8dac96d --- /dev/null +++ b/src/litdata/utilities/file_utils.py @@ -0,0 +1,58 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import os +from contextlib import suppress + +from filelock import FileLock, Timeout + + +def increment_file_count(file_path: str, rank: int = 0) -> int: + """Increment the file count in the index file.""" + countpath = file_path + ".cnt" + with suppress(Timeout), FileLock(countpath + ".lock", timeout=1): + try: + with open(countpath) as count_f: + curr_count = int(count_f.read().strip()) + except Exception: + curr_count = 0 + curr_count += 1 + with open(countpath, "w+") as count_f: + count_f.write(str(curr_count)) + + print(f"โœ… {rank=} Incremented file count for {file_path} to => {curr_count}") + return curr_count + + +def decrement_file_count(file_path: str, rank: int = 0) -> int: + """Decrement the file count in the index file.""" + countpath = file_path + ".cnt" + + with suppress(Timeout), FileLock(countpath + ".lock", timeout=1): + try: + with open(countpath) as count_f: + curr_count = int(count_f.read().strip()) + except Exception as e: + raise ValueError(f"{rank=} Count file not found when trying to decrement_file_count: {countpath}.") from e + curr_count -= 1 + + if curr_count <= 0: + # remove the count file if it reaches zero + with suppress(FileNotFoundError): + os.remove(countpath) + print(f"โŒ {rank=} Decremented file count for {file_path} to => {curr_count}") + return 0 + with open(countpath, "w+") as count_f: + count_f.write(str(curr_count)) + + print(f"โŒ {rank=} Decremented file count for {file_path} to => {curr_count}") + return curr_count From afbd6cf8074a9ea76c82e5e4051ddafac5890fbd Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Mon, 7 Apr 2025 15:53:00 +0530 Subject: [PATCH 07/12] meow meow --- src/litdata/streaming/reader.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/litdata/streaming/reader.py b/src/litdata/streaming/reader.py index 7d68931a..ef017075 100644 --- a/src/litdata/streaming/reader.py +++ b/src/litdata/streaming/reader.py @@ -183,12 +183,13 @@ def stop(self) -> None: try: chunk_index = self._to_delete_queue.get(timeout=_DEFAULT_TIMEOUT) if chunk_index is None: + print(f"{self._rank=} received the none. bye bye") break self._apply_delete(chunk_index) total_waiting_time = 0 except Empty: total_waiting_time += _DEFAULT_TIMEOUT - if total_waiting_time > _LONG_DEFAULT_TIMEOUT: + if total_waiting_time > _LONG_DEFAULT_TIMEOUT * 6: # wait for 30 seconds print("Timeout waiting for delete queue to be empty (None)") break @@ -216,12 +217,15 @@ def _maybe_delete_chunks(self) -> None: # return # we have already pre-downloaded some chunks, we just need to wait for them to be processed. + if self._delete_queue_received_none: + return while True: try: - chunk_index_to_be_deleted = _get_from_queue(self._to_delete_queue, timeout=_DEFAULT_TIMEOUT) + chunk_index_to_be_deleted = self._to_delete_queue.get(timeout=_DEFAULT_TIMEOUT) if chunk_index_to_be_deleted is None: self._delete_queue_received_none = True + print(f"Received the none. 
bye bye {self._rank=}") return # self._pre_download_counter -= 1 @@ -505,7 +509,7 @@ def read(self, index: ChunkedIndex) -> Any: # inform the thread it is time to stop # self._prepare_thread._decrement_local_lock(index.chunk_index) self._item_loader.close(self._last_chunk_index) - print(f"๐Ÿ˜ˆ it's last index of this chunk, sent it deleting: {index.chunk_index}") + print(f"๐Ÿ˜ˆ {self._rank=} it's last index of this chunk, sent it deleting: {index.chunk_index}") self._prepare_thread.delete([index.chunk_index]) # send this chunk for deletion self._prepare_thread._to_delete_queue.put(None) # signal the end of the queue self._prepare_thread.stop() From 1d45d9508dd5dd1ae08a9ff0f153df6290021277 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Mon, 7 Apr 2025 17:06:05 +0530 Subject: [PATCH 08/12] time for imagenet test --- src/litdata/streaming/config.py | 8 ++-- src/litdata/streaming/dataloader.py | 1 + src/litdata/streaming/dataset.py | 59 ++++++++++++++++------------ src/litdata/streaming/item_loader.py | 5 +-- src/litdata/streaming/reader.py | 25 +++--------- src/litdata/streaming/sampler.py | 4 +- src/litdata/utilities/env.py | 6 ++- src/litdata/utilities/file_utils.py | 9 ++--- 8 files changed, 55 insertions(+), 62 deletions(-) diff --git a/src/litdata/streaming/config.py b/src/litdata/streaming/config.py index 1ff15923..bd0b2303 100644 --- a/src/litdata/streaming/config.py +++ b/src/litdata/streaming/config.py @@ -118,7 +118,7 @@ def skip_chunk_indexes_deletion(self) -> Optional[List[int]]: def skip_chunk_indexes_deletion(self, skip_chunk_indexes_deletion: List[int]) -> None: self._skip_chunk_indexes_deletion = skip_chunk_indexes_deletion - def download_chunk_from_index(self, chunk_index: int, rank: int = 0) -> None: + def download_chunk_from_index(self, chunk_index: int) -> None: assert self._chunks is not None chunk_filename = self._chunks[chunk_index]["filename"] @@ -129,7 +129,7 @@ def download_chunk_from_index(self, chunk_index: int, rank: int = 0) -> None: if self._downloader is not None and self._remote_dir.startswith(_SUPPORTED_DOWNLOADERS): # We don't want to redownload the base, but we should mark # it as having been requested by something - count = increment_file_count(local_chunkpath.replace(f".{self._compressor_name}", ""), rank) + count = increment_file_count(local_chunkpath.replace(f".{self._compressor_name}", "")) if count == 1: # weird, shouldn't happen # but if it does, we should start downloading the file @@ -140,11 +140,11 @@ def download_chunk_from_index(self, chunk_index: int, rank: int = 0) -> None: if (self._downloader is None) or (not self._remote_dir.startswith(_SUPPORTED_DOWNLOADERS)): return - curr_count = increment_file_count(local_chunkpath.replace(f".{self._compressor_name}", ""), rank) + curr_count = increment_file_count(local_chunkpath.replace(f".{self._compressor_name}", "")) if curr_count == 1: # this is the first time we are downloading this file - # so we should download it + # so we should actually download it self._downloader.download_chunk_from_index(chunk_index) self.try_decompress(local_chunkpath) diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py index e96caba5..c99b6e84 100644 --- a/src/litdata/streaming/dataloader.py +++ b/src/litdata/streaming/dataloader.py @@ -252,6 +252,7 @@ def _next_data(self) -> Any: raise e +#! TODO: This class is not being used anywhere. 
class CacheDataLoader(DataLoader): __doc__ = DataLoader.__doc__ diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index a39b16a5..68c96ab2 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -33,6 +33,10 @@ from litdata.utilities.env import _DistributedEnv, _is_in_dataloader_worker, _WorkerEnv from litdata.utilities.format import _convert_bytes_to_int from litdata.utilities.hf_dataset import index_hf_dataset +from litdata.utilities.shuffle import ( + _find_chunks_per_workers_on_which_to_skip_deletion, + _map_node_worker_rank_to_chunk_indexes_to_not_delete, +) logger = logging.getLogger(__name__) @@ -150,7 +154,7 @@ def __init__( self.cache: Optional[Cache] = None self.worker_env: Optional[_WorkerEnv] = None - self.worker_chunks: List[int] = [] # chunk indexes on which the current worker will work + self.worker_chunks: List[int] = [] # chunk indexes that the current worker will download, read & stream self.worker_intervals: List[List[int]] = [] # chunk index intervals for the current worker self.upcoming_indexes: List[int] = [] # contains list of upcoming indexes to be processed @@ -274,11 +278,11 @@ def __iter__(self) -> "StreamingDataset": worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank if worker_rank == 0: - print(f"workers_chunks: {workers_chunks}\nworkers_intervals: {workers_intervals}") + logger.debug(f"workers_chunks: {workers_chunks}\nworkers_intervals: {workers_intervals}") self.worker_chunks = workers_chunks[worker_rank] self.worker_intervals = workers_intervals[worker_rank] - print("-" * 50 + "\n" + f"{worker_rank=}; {self.worker_chunks=}; {self.worker_intervals=}\n" + "-" * 50) + logger.debug("-" * 50 + "\n" + f"{worker_rank=}; {self.worker_chunks=}; {self.worker_intervals=}\n" + "-" * 50) # The max number of samples to return from `__next__` (in worker) self.stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals) @@ -290,28 +294,30 @@ def __iter__(self) -> "StreamingDataset": # Find the chunks shared across all workers of the current node. # For each shared chunk, find the rank and worker to use the chunk last and prevent # premature deletion for the other workers. - # node_size = self.distributed_env.world_size // self.distributed_env.num_nodes - # first_rank_this_node = (self.distributed_env.global_rank // node_size) * node_size - # num_workers_per_node = node_size * self.num_workers - # worker_start = first_rank_this_node * num_workers_per_node - # worker_end = worker_start + num_workers_per_node - # local_rank = self.distributed_env.global_rank % node_size - - # chunks_indexes_skip_deletion = _find_chunks_per_workers_on_which_to_skip_deletion( - # self.num_workers, - # self.batch_size, - # workers_chunks[worker_start:worker_end], - # workers_intervals[worker_start:worker_end], - # ) - # worker_node_rank_to_chunk_indexes = _map_node_worker_rank_to_chunk_indexes_to_not_delete( - # chunks_indexes_skip_deletion - # ) - - # worker_rank_local_node = local_rank * self.num_workers + self.worker_env.rank - # if worker_rank_local_node in worker_node_rank_to_chunk_indexes: - # self.cache._reader.config.skip_chunk_indexes_deletion = worker_node_rank_to_chunk_indexes[ - # worker_rank_local_node - # ] + if False: + #! 
TODO: fix skip_chunk_deletion (iops is much slower than in memory access) + node_size = self.distributed_env.world_size // self.distributed_env.num_nodes + first_rank_this_node = (self.distributed_env.global_rank // node_size) * node_size + num_workers_per_node = node_size * self.num_workers + worker_start = first_rank_this_node * num_workers_per_node + worker_end = worker_start + num_workers_per_node + local_rank = self.distributed_env.global_rank % node_size + + chunks_indexes_skip_deletion = _find_chunks_per_workers_on_which_to_skip_deletion( + self.num_workers, + self.batch_size, + workers_chunks[worker_start:worker_end], + workers_intervals[worker_start:worker_end], + ) + worker_node_rank_to_chunk_indexes = _map_node_worker_rank_to_chunk_indexes_to_not_delete( + chunks_indexes_skip_deletion + ) + + worker_rank_local_node = local_rank * self.num_workers + self.worker_env.rank + if worker_rank_local_node in worker_node_rank_to_chunk_indexes: + self.cache._reader.config.skip_chunk_indexes_deletion = worker_node_rank_to_chunk_indexes[ + worker_rank_local_node + ] self.num_chunks = len(self.worker_chunks) self.upcoming_indexes = [] @@ -321,6 +327,8 @@ def __iter__(self) -> "StreamingDataset": self.has_triggered_download = False self.last_time = time() + + # start the downloader thread self.cache._reader.prepare_downloader_thread(self.worker_chunks) return self @@ -395,7 +403,6 @@ def __next__(self) -> Any: # global_index: total number of samples processed by the current worker across all chunks # stop_length: max number of samples that the current worker will process # if they are equal, means, worker has processed all the chunks - print("dame tu cosita aha ah ah ah") self.current_epoch += 1 self.reset_state_dict() raise StopIteration diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index 25140949..efb6041a 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -111,10 +111,10 @@ def load_item_from_chunk( def delete(self, chunk_index: int, chunk_filepath: str) -> None: """Delete a chunk from the local filesystem.""" - def safe_delete(self, chunk_index: int, chunk_filepath: str, rank: int) -> None: + def safe_delete(self, chunk_index: int, chunk_filepath: str) -> None: """Decrement the file count and delete the chunk file if the count reaches 0.""" if os.path.exists(chunk_filepath + ".cnt"): - curr_count = decrement_file_count(chunk_filepath, rank) + curr_count = decrement_file_count(chunk_filepath) if curr_count == 0: self.delete(chunk_index, chunk_filepath) @@ -687,7 +687,6 @@ def _get_item_with_low_memory(self, chunk_index: int, chunk_filepath: str, row_i # Return the specific row from the dataframe # Note: The `named=True` argument is used to return the row as a dictionary - # type: ignore return row_group_df.row(row_index_within_group, named=True) def _get_item(self, chunk_index: int, chunk_filepath: str, index: int) -> Any: diff --git a/src/litdata/streaming/reader.py b/src/litdata/streaming/reader.py index ef017075..3881e9d4 100644 --- a/src/litdata/streaming/reader.py +++ b/src/litdata/streaming/reader.py @@ -101,7 +101,6 @@ def __init__( def delete(self, chunk_indexes: List[int]) -> None: """Receive the list of the chunk indices to delete for the current epoch.""" for chunk_index in chunk_indexes: - print(f"โšก๏ธ {self._rank=} asked to delete chunk {chunk_index=}") self._to_delete_queue.put(chunk_index) # def _remaining_locks(self, chunkpath: str) -> int: @@ -178,29 +177,19 @@ def stop(self) -> None: 
# self._apply_delete(chnk_idx) # read from delete queue until None is received and delete the chunks total_waiting_time = 0 - print(f"{self._rank=} stopping prepare_chunks_thread") - while True: + while not self._delete_queue_received_none: # parallelly it can be set true by thread's run method try: chunk_index = self._to_delete_queue.get(timeout=_DEFAULT_TIMEOUT) if chunk_index is None: - print(f"{self._rank=} received the none. bye bye") + self._delete_queue_received_none = True break self._apply_delete(chunk_index) total_waiting_time = 0 except Empty: total_waiting_time += _DEFAULT_TIMEOUT - if total_waiting_time > _LONG_DEFAULT_TIMEOUT * 6: # wait for 30 seconds + if total_waiting_time > _LONG_DEFAULT_TIMEOUT * 2: # wait for 10 seconds print("Timeout waiting for delete queue to be empty (None)") break - - # extra cleanup - # if self._delete_chunks_when_processed: - # # clear the cache directory (except the index.json file) - # for root, _, files in os.walk(self._parent_cache_dir): - # for file in files: - # if file != _INDEX_FILENAME: - # with contextlib.suppress(FileNotFoundError): - # os.remove(os.path.join(root, file)) self.force_stop() def force_stop(self) -> None: @@ -225,7 +214,6 @@ def _maybe_delete_chunks(self) -> None: if chunk_index_to_be_deleted is None: self._delete_queue_received_none = True - print(f"Received the none. bye bye {self._rank=}") return # self._pre_download_counter -= 1 @@ -237,7 +225,6 @@ def _maybe_delete_chunks(self) -> None: self._apply_delete(chunk_index_to_be_deleted) except Empty: # Timeout waiting for delete queue to be empty - # print(f"Timeout waiting for delete queue to be empty (None)") break except Exception as e: raise RuntimeError(f"Error while deleting chunks: {e}") from e @@ -509,10 +496,9 @@ def read(self, index: ChunkedIndex) -> Any: # inform the thread it is time to stop # self._prepare_thread._decrement_local_lock(index.chunk_index) self._item_loader.close(self._last_chunk_index) - print(f"๐Ÿ˜ˆ {self._rank=} it's last index of this chunk, sent it deleting: {index.chunk_index}") - self._prepare_thread.delete([index.chunk_index]) # send this chunk for deletion - self._prepare_thread._to_delete_queue.put(None) # signal the end of the queue + self._prepare_thread.delete([index.chunk_index, None]) # send this chunk for deletion self._prepare_thread.stop() + self._prepare_thread.join() self._prepare_thread = None self._last_chunk_index = None self._chunks_queued_for_download = False @@ -588,7 +574,6 @@ def _get_folder_size(path: str, config: ChunksConfig) -> int: f"Ignoring '{filename}': " "This file doesn't appear to be a valid chunk file and has been excluded from the size calculation." ) - print(f"Total size of files in '{path}': {size} bytes") return size diff --git a/src/litdata/streaming/sampler.py b/src/litdata/streaming/sampler.py index 4c67b51d..273b87fd 100644 --- a/src/litdata/streaming/sampler.py +++ b/src/litdata/streaming/sampler.py @@ -22,7 +22,8 @@ @dataclass class ChunkedIndex: - """Represents an index within a chunked dataset. + """The docstring below is incorrect for `chunk_indexes`. + Represents an index within a chunked dataset. Attributes: index (int): The global index of the data point across all chunks. @@ -48,6 +49,7 @@ class ChunkedIndex: is_last_index: bool = False +#! TODO: This class is not being used. 
(Used in CacheDataLoader, but CacheDataLoader itself not being used) class CacheBatchSampler: def __init__( self, diff --git a/src/litdata/utilities/env.py b/src/litdata/utilities/env.py index 30b27f1d..0cf0dbcc 100644 --- a/src/litdata/utilities/env.py +++ b/src/litdata/utilities/env.py @@ -81,8 +81,10 @@ def _instantiate_in_map_or_optimize(cls) -> "_DistributedEnv": def __repr__(self) -> str: return ( - f"{self.__class__.__name__}(world_size: {self.world_size}, global_rank: {self.global_rank}," - + f" num_nodes: {self.num_nodes})" + f"{self.__class__.__name__}(" + f"world_size={self.world_size}, " + f"global_rank={self.global_rank}, " + f"num_nodes={self.num_nodes})" ) def __str__(self) -> str: diff --git a/src/litdata/utilities/file_utils.py b/src/litdata/utilities/file_utils.py index e8dac96d..18479995 100644 --- a/src/litdata/utilities/file_utils.py +++ b/src/litdata/utilities/file_utils.py @@ -16,7 +16,7 @@ from filelock import FileLock, Timeout -def increment_file_count(file_path: str, rank: int = 0) -> int: +def increment_file_count(file_path: str) -> int: """Increment the file count in the index file.""" countpath = file_path + ".cnt" with suppress(Timeout), FileLock(countpath + ".lock", timeout=1): @@ -29,11 +29,10 @@ def increment_file_count(file_path: str, rank: int = 0) -> int: with open(countpath, "w+") as count_f: count_f.write(str(curr_count)) - print(f"โœ… {rank=} Incremented file count for {file_path} to => {curr_count}") return curr_count -def decrement_file_count(file_path: str, rank: int = 0) -> int: +def decrement_file_count(file_path: str) -> int: """Decrement the file count in the index file.""" countpath = file_path + ".cnt" @@ -42,17 +41,15 @@ def decrement_file_count(file_path: str, rank: int = 0) -> int: with open(countpath) as count_f: curr_count = int(count_f.read().strip()) except Exception as e: - raise ValueError(f"{rank=} Count file not found when trying to decrement_file_count: {countpath}.") from e + raise ValueError(f"Count file not found when trying to decrement_file_count: {countpath}.") from e curr_count -= 1 if curr_count <= 0: # remove the count file if it reaches zero with suppress(FileNotFoundError): os.remove(countpath) - print(f"โŒ {rank=} Decremented file count for {file_path} to => {curr_count}") return 0 with open(countpath, "w+") as count_f: count_f.write(str(curr_count)) - print(f"โŒ {rank=} Decremented file count for {file_path} to => {curr_count}") return curr_count From 46bbc2c93bdefed06218be2c1e2653323002b48d Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Mon, 7 Apr 2025 17:17:07 +0530 Subject: [PATCH 09/12] update --- src/litdata/streaming/item_loader.py | 1 - src/litdata/streaming/reader.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index efb6041a..ea02c3ab 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -115,7 +115,6 @@ def safe_delete(self, chunk_index: int, chunk_filepath: str) -> None: """Decrement the file count and delete the chunk file if the count reaches 0.""" if os.path.exists(chunk_filepath + ".cnt"): curr_count = decrement_file_count(chunk_filepath) - if curr_count == 0: self.delete(chunk_index, chunk_filepath) else: diff --git a/src/litdata/streaming/reader.py b/src/litdata/streaming/reader.py index 3881e9d4..bb6dec4b 100644 --- a/src/litdata/streaming/reader.py +++ b/src/litdata/streaming/reader.py @@ -156,7 +156,7 @@ def _apply_delete(self, chunk_index: int) -> 
None: # with open(chunk_filepath + ".tmb", "w+") as tombstone_file: # tombstone_file.write(f"Deleted {chunk_filepath} by {self._rank or 0}. Debug: {can_delete_chunk}") - self._item_loader.safe_delete(chunk_index, chunk_filepath, self._rank) + self._item_loader.safe_delete(chunk_index, chunk_filepath) # if _DEBUG: # print(f"Deleted {chunk_filepath} by {self._rank or 0}. Debug: {can_delete_chunk}") @@ -287,7 +287,7 @@ def run(self) -> None: # return if chunk_index is not None: - self._config.download_chunk_from_index(chunk_index, self._rank) + self._config.download_chunk_from_index(chunk_index) # Preload item if possible to gain some time but only # if this is one of the pre-downloaded chunk From 4b1000a3bc95d9c4177db2ce5e0a75abb505760e Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Tue, 8 Apr 2025 04:31:14 +0000 Subject: [PATCH 10/12] update --- src/litdata/streaming/item_loader.py | 4 ++-- src/litdata/streaming/reader.py | 9 +++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index ea02c3ab..2e5f0d1b 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -111,11 +111,11 @@ def load_item_from_chunk( def delete(self, chunk_index: int, chunk_filepath: str) -> None: """Delete a chunk from the local filesystem.""" - def safe_delete(self, chunk_index: int, chunk_filepath: str) -> None: + def safe_delete(self, chunk_index: int, chunk_filepath: str, delete_original_file: bool = False) -> None: """Decrement the file count and delete the chunk file if the count reaches 0.""" if os.path.exists(chunk_filepath + ".cnt"): curr_count = decrement_file_count(chunk_filepath) - if curr_count == 0: + if curr_count == 0 and delete_original_file: self.delete(chunk_index, chunk_filepath) else: self.delete(chunk_index, chunk_filepath) diff --git a/src/litdata/streaming/reader.py b/src/litdata/streaming/reader.py index bb6dec4b..0af1814e 100644 --- a/src/litdata/streaming/reader.py +++ b/src/litdata/streaming/reader.py @@ -80,7 +80,7 @@ def __init__( # Check whether a dataset slice fits on the node num_bytes_per_nodes = self._config.num_bytes // self._distributed_env.num_nodes self._delete_chunks_when_processed = num_bytes_per_nodes > max_cache_size if max_cache_size else False - + # print(f"reader: {self._delete_chunks_when_processed=}") # if self._delete_chunks_when_processed: # print(f"clearing cache dir {self._parent_cache_dir} because the dataset is too large to fit in memory") # # means we can't keep all chunks in the cache directory, so we should clear it to minimize the size @@ -156,7 +156,9 @@ def _apply_delete(self, chunk_index: int) -> None: # with open(chunk_filepath + ".tmb", "w+") as tombstone_file: # tombstone_file.write(f"Deleted {chunk_filepath} by {self._rank or 0}. Debug: {can_delete_chunk}") - self._item_loader.safe_delete(chunk_index, chunk_filepath) + self._item_loader.safe_delete( + chunk_index, chunk_filepath, delete_original_file=self._delete_chunks_when_processed + ) # if _DEBUG: # print(f"Deleted {chunk_filepath} by {self._rank or 0}. 
Debug: {can_delete_chunk}") @@ -278,7 +280,7 @@ def run(self) -> None: self._force_download() - if self._can_download_chunk(): + if self._can_download_chunk() and (self.current_downloading_chunk_index < len(self._chunks_order) - 1): self.current_downloading_chunk_index += 1 # chunk_index = _get_from_queue(self._to_download_queue) chunk_index = self._chunks_order[self.current_downloading_chunk_index] @@ -296,7 +298,6 @@ def run(self) -> None: # Avoid downloading too many chunks in advance at the risk of over using the disk space # self._pre_download_counter += 1 - self._maybe_delete_chunks() From e80543ae5febf2f36f34a68c8c438b879e0a595a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 17 Apr 2025 16:37:29 +0000 Subject: [PATCH 11/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/streaming/config.py | 1 - src/litdata/streaming/reader.py | 1 - src/litdata/utilities/dataset_utilities.py | 1 - 3 files changed, 3 deletions(-) diff --git a/src/litdata/streaming/config.py b/src/litdata/streaming/config.py index 82155239..1907ab12 100644 --- a/src/litdata/streaming/config.py +++ b/src/litdata/streaming/config.py @@ -16,7 +16,6 @@ from collections import defaultdict from typing import Any, Dict, List, Optional, Tuple - from litdata.constants import _INDEX_FILENAME, _SUPPORTED_DOWNLOADERS from litdata.debugger import ChromeTraceColors, _get_log_msg from litdata.streaming.compression import _COMPRESSORS, Compressor diff --git a/src/litdata/streaming/reader.py b/src/litdata/streaming/reader.py index 23a39a61..95285358 100644 --- a/src/litdata/streaming/reader.py +++ b/src/litdata/streaming/reader.py @@ -140,7 +140,6 @@ def delete(self, chunk_indexes: List[int]) -> None: # return curr_count # return 0 - def _apply_delete(self, chunk_index: int) -> None: """Inform the item loader of the chunk to delete.""" # TODO: Fix the can_delete method diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py index 29f61ef6..3586f1a9 100644 --- a/src/litdata/utilities/dataset_utilities.py +++ b/src/litdata/utilities/dataset_utilities.py @@ -64,7 +64,6 @@ def subsample_streaming_dataset( downloader = get_downloader(input_dir.url, input_dir.path, [], storage_options) downloader.download_file(os.path.join(input_dir.url, _INDEX_FILENAME), cache_index_filepath) - time.sleep(0.5) # Give some time for the file to be created if not os.path.exists(input_dir.path): From 117bd4dec76451b5d520e743f7cae40f1626a60a Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Thu, 17 Apr 2025 22:11:22 +0530 Subject: [PATCH 12/12] update --- src/litdata/utilities/dataset_utilities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py index 3586f1a9..ec3bb73b 100644 --- a/src/litdata/utilities/dataset_utilities.py +++ b/src/litdata/utilities/dataset_utilities.py @@ -64,7 +64,7 @@ def subsample_streaming_dataset( downloader = get_downloader(input_dir.url, input_dir.path, [], storage_options) downloader.download_file(os.path.join(input_dir.url, _INDEX_FILENAME), cache_index_filepath) - time.sleep(0.5) # Give some time for the file to be created + time.sleep(0.5) # Give some time for the file to be available if not os.path.exists(input_dir.path): raise FileNotFoundError(f"The provided dataset path `{input_dir.path}` does not exist.")