
fix: streaming & deleting issue #538


Draft: wants to merge 16 commits into base: main
src/litdata/constants.py (1 addition, 0 deletions)
@@ -25,6 +25,7 @@
 _DEFAULT_CACHE_DIR = os.path.join(Path.home(), ".lightning", "chunks")
 _DEFAULT_LIGHTNING_CACHE_DIR = os.path.join("/cache", "chunks")
 _SUPPORTED_PROVIDERS = ("s3", "gs") # cloud providers supported by litdata for uploading (optimize, map, merge, etc)
+_SUPPORTED_DOWNLOADERS = ("s3", "gs", "azure", "hf") # cloud providers supported by litdata for streaming datasets

 # This is required for full pytree serialization / deserialization support
 _TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
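
Side note on how the new constant is consumed: `config.py` below gates remote access with `self._remote_dir.startswith(_SUPPORTED_DOWNLOADERS)`, which works because `str.startswith` accepts a tuple of prefixes. A minimal illustration (the remote directories are made up):

```python
_SUPPORTED_DOWNLOADERS = ("s3", "gs", "azure", "hf")

# Hypothetical remote directories, for illustration only.
print("s3://my-bucket/optimized-dataset".startswith(_SUPPORTED_DOWNLOADERS))   # True
print("hf://datasets/some-user/some-repo".startswith(_SUPPORTED_DOWNLOADERS))  # True
print("/teamspace/local-dataset".startswith(_SUPPORTED_DOWNLOADERS))           # False, so no downloader is used
```
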
src/litdata/streaming/config.py (17 additions, 11 deletions)
@@ -16,7 +16,7 @@
 from collections import defaultdict
 from typing import Any, Dict, List, Optional, Tuple

-from litdata.constants import _INDEX_FILENAME
+from litdata.constants import _INDEX_FILENAME, _SUPPORTED_DOWNLOADERS
 from litdata.debugger import ChromeTraceColors, _get_log_msg
 from litdata.streaming.compression import _COMPRESSORS, Compressor
 from litdata.streaming.downloader import get_downloader
@@ -25,6 +25,7 @@
 from litdata.streaming.serializers import Serializer
 from litdata.utilities._pytree import tree_unflatten, treespec_loads
 from litdata.utilities.dataset_utilities import load_index_file
+from litdata.utilities.file_utils import increment_file_count

 logger = logging.getLogger("litdata.streaming.config")

@@ -121,30 +122,35 @@ def skip_chunk_indexes_deletion(self) -> Optional[List[int]]:
     def skip_chunk_indexes_deletion(self, skip_chunk_indexes_deletion: List[int]) -> None:
         self._skip_chunk_indexes_deletion = skip_chunk_indexes_deletion

-    def download_chunk_from_index(self, chunk_index: int, skip_lock: bool = False) -> None:
+    def download_chunk_from_index(self, chunk_index: int) -> None:
         assert self._chunks is not None
         chunk_filename = self._chunks[chunk_index]["filename"]

         local_chunkpath = os.path.join(self._cache_dir, chunk_filename)

         if os.path.exists(local_chunkpath):
-            self.try_decompress(local_chunkpath)
-            if self._downloader is not None and not skip_lock:
+            if self._downloader is not None and self._remote_dir.startswith(_SUPPORTED_DOWNLOADERS):
                 # We don't want to redownload the base, but we should mark
                 # it as having been requested by something
-                self._downloader._increment_local_lock(local_chunkpath.replace(f".{self._compressor_name}", ""))
-                pass
+                count = increment_file_count(local_chunkpath.replace(f".{self._compressor_name}", ""))
+                if count == 1:
+                    # weird, shouldn't happen
+                    # but if it does, we should start downloading the file
+                    self._downloader.download_chunk_from_index(chunk_index)
+            self.try_decompress(local_chunkpath)
             return

-        if self._downloader is None:
+        if (self._downloader is None) or (not self._remote_dir.startswith(_SUPPORTED_DOWNLOADERS)):
            return

-        if not skip_lock:
-            self._downloader._increment_local_lock(local_chunkpath.replace(f".{self._compressor_name}", ""))
-        self._downloader.download_chunk_from_index(chunk_index)
-        self.try_decompress(local_chunkpath)
+        curr_count = increment_file_count(local_chunkpath.replace(f".{self._compressor_name}", ""))
+
+        if curr_count == 1:
+            # this is the first time we are downloading this file
+            # so we should actually download it
+            self._downloader.download_chunk_from_index(chunk_index)
+            self.try_decompress(local_chunkpath)

     def try_decompress(self, local_chunkpath: str) -> None:
         if self._compressor is None:
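
`increment_file_count` (and the matching `decrement_file_count` used in `item_loader.py` below) come from `litdata/utilities/file_utils.py`, which this diff does not show. A hedged sketch of what they presumably look like, modeled on the removed `_increment_local_lock` (a `.cnt` counter file guarded by a `filelock.FileLock`); the real helpers may differ, for example by using a lock timeout:

```python
# Sketch only: assumed shape of litdata/utilities/file_utils.py, not the actual implementation.
from filelock import FileLock


def _read_count(countpath: str) -> int:
    try:
        with open(countpath) as f:
            return int(f.read().strip())
    except (FileNotFoundError, ValueError):
        return 0


def increment_file_count(filepath: str) -> int:
    """Increment the usage counter stored next to `filepath` and return the new count."""
    countpath = filepath + ".cnt"
    with FileLock(countpath + ".lock"):
        count = _read_count(countpath) + 1
        with open(countpath, "w") as f:
            f.write(str(count))
    return count


def decrement_file_count(filepath: str) -> int:
    """Decrement the usage counter and return the new count (never below zero)."""
    countpath = filepath + ".cnt"
    with FileLock(countpath + ".lock"):
        count = max(_read_count(countpath) - 1, 0)
        with open(countpath, "w") as f:
            f.write(str(count))
    return count
```

With this contract, a return value of 1 from `increment_file_count` means "first consumer of this chunk", which is exactly the condition `download_chunk_from_index` checks before starting a download.
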
src/litdata/streaming/dataloader.py (2 additions, 1 deletion)
@@ -135,7 +135,7 @@ def __call__(self, items: List[Any]) -> Any:


 class _SingleProcessDataLoaderIterPatch(_SingleProcessDataLoaderIter):
-    """This is ç to inform the cache is done chunking."""
+    """This is to inform the cache is done chunking."""

     def _next_data(self) -> Any:
         try:
@@ -253,6 +253,7 @@ def _next_data(self) -> Any:
             raise e


+#! TODO: This class is not being used anywhere.
 class CacheDataLoader(DataLoader):
     __doc__ = DataLoader.__doc__

src/litdata/streaming/dataset.py (30 additions, 21 deletions)
@@ -289,9 +289,13 @@ def __iter__(self) -> "StreamingDataset":
         )

         worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank
+        if worker_rank == 0:
+            logger.debug(f"workers_chunks: {workers_chunks}\nworkers_intervals: {workers_intervals}")
         self.worker_chunks = workers_chunks[worker_rank]
         self.worker_intervals = workers_intervals[worker_rank]

+        logger.debug("-" * 50 + "\n" + f"{worker_rank=}; {self.worker_chunks=}; {self.worker_intervals=}\n" + "-" * 50)
+
         # The max number of samples to return from `__next__` (in worker)
         self.stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals)
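
As a worked example of the flattened worker index computed above (all numbers hypothetical): with `global_rank = 5` and 4 dataloader workers per process, the worker with in-process rank 2 reads slot 5 * 4 + 2 = 22 of `workers_chunks` and `workers_intervals`:

```python
# Hypothetical values, for illustration only.
global_rank = 5             # rank of this process in the distributed job
worker_world_size = 4       # dataloader workers per process
worker_rank_in_process = 2  # this worker's rank inside its process

flat_worker_rank = global_rank * worker_world_size + worker_rank_in_process
assert flat_worker_rank == 22  # index into workers_chunks / workers_intervals
```
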

@@ -302,28 +306,30 @@ def __iter__(self) -> "StreamingDataset":
         # Find the chunks shared across all workers of the current node.
         # For each shared chunk, find the rank and worker to use the chunk last and prevent
         # premature deletion for the other workers.
-        node_size = self.distributed_env.world_size // self.distributed_env.num_nodes
-        first_rank_this_node = (self.distributed_env.global_rank // node_size) * node_size
-        num_workers_per_node = node_size * self.num_workers
-        worker_start = first_rank_this_node * num_workers_per_node
-        worker_end = worker_start + num_workers_per_node
-        local_rank = self.distributed_env.global_rank % node_size
-
-        chunks_indexes_skip_deletion = _find_chunks_per_workers_on_which_to_skip_deletion(
-            self.num_workers,
-            self.batch_size,
-            workers_chunks[worker_start:worker_end],
-            workers_intervals[worker_start:worker_end],
-        )
-        worker_node_rank_to_chunk_indexes = _map_node_worker_rank_to_chunk_indexes_to_not_delete(
-            chunks_indexes_skip_deletion
-        )
-        worker_rank_local_node = local_rank * self.num_workers + self.worker_env.rank
-        if worker_rank_local_node in worker_node_rank_to_chunk_indexes:
-            self.cache._reader.config.skip_chunk_indexes_deletion = worker_node_rank_to_chunk_indexes[
-                worker_rank_local_node
-            ]
+        if False:
+            #! TODO: fix skip_chunk_deletion (iops is much slower than in memory access)
+            node_size = self.distributed_env.world_size // self.distributed_env.num_nodes
+            first_rank_this_node = (self.distributed_env.global_rank // node_size) * node_size
+            num_workers_per_node = node_size * self.num_workers
+            worker_start = first_rank_this_node * num_workers_per_node
+            worker_end = worker_start + num_workers_per_node
+            local_rank = self.distributed_env.global_rank % node_size
+
+            chunks_indexes_skip_deletion = _find_chunks_per_workers_on_which_to_skip_deletion(
+                self.num_workers,
+                self.batch_size,
+                workers_chunks[worker_start:worker_end],
+                workers_intervals[worker_start:worker_end],
+            )
+            worker_node_rank_to_chunk_indexes = _map_node_worker_rank_to_chunk_indexes_to_not_delete(
+                chunks_indexes_skip_deletion
+            )
+
+            worker_rank_local_node = local_rank * self.num_workers + self.worker_env.rank
+            if worker_rank_local_node in worker_node_rank_to_chunk_indexes:
+                self.cache._reader.config.skip_chunk_indexes_deletion = worker_node_rank_to_chunk_indexes[
+                    worker_rank_local_node
+                ]

         self.num_chunks = len(self.worker_chunks)
         self.upcoming_indexes = []
@@ -334,6 +340,9 @@ def __iter__(self) -> "StreamingDataset":
         self.has_triggered_download = False
         self.last_time = time()

+        # start the downloader thread
+        self.cache._reader.prepare_downloader_thread(self.worker_chunks)
+
         return self

     def _resume(self, workers_chunks: List[List[int]], workers_intervals: List[Any]) -> None:
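
The iterator now hands its chunk list to `self.cache._reader.prepare_downloader_thread(...)`, whose reader-side implementation is not part of this diff. A rough sketch of the idea, assuming it simply walks the worker's chunks in iteration order on one background thread and funnels them through `download_chunk_from_index`; apart from the method name, everything below is hypothetical:

```python
import threading
from queue import Queue
from typing import List, Optional


class _ReaderWithPrefetch:
    """Sketch of a reader that pre-downloads a worker's chunks on a background thread."""

    def __init__(self, config) -> None:
        self._config = config  # exposes download_chunk_from_index(chunk_index)
        self._queue: Queue = Queue()
        self._thread: Optional[threading.Thread] = None

    def prepare_downloader_thread(self, worker_chunks: List[int]) -> None:
        for chunk_index in worker_chunks:
            self._queue.put(chunk_index)
        self._queue.put(None)  # sentinel: no more chunks for this worker
        self._thread = threading.Thread(target=self._download_loop, daemon=True)
        self._thread.start()

    def _download_loop(self) -> None:
        while (chunk_index := self._queue.get()) is not None:
            # Bumps the chunk's usage counter and downloads the chunk if it is not
            # already cached (see config.download_chunk_from_index above).
            self._config.download_chunk_from_index(chunk_index)
```

The point of prefetching on a separate thread is that the chunk needed by `__next__` is usually already on disk, so the training loop is less likely to stall on network I/O.
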
src/litdata/streaming/downloader.py (0 additions, 14 deletions)
@@ -46,20 +46,6 @@ def __init__(
         self._chunks = chunks
         self._storage_options = storage_options or {}

-    def _increment_local_lock(self, chunkpath: str) -> None:
-        logger.debug(_get_log_msg({"name": f"increment_local_lock_for_{chunkpath}", "ph": "B"}))
-        countpath = chunkpath + ".cnt"
-        with suppress(Timeout), FileLock(countpath + ".lock", timeout=1):
-            try:
-                with open(countpath) as count_f:
-                    curr_count = int(count_f.read().strip())
-            except Exception:
-                curr_count = 0
-            curr_count += 1
-            with open(countpath, "w+") as count_f:
-                count_f.write(str(curr_count))
-        logger.debug(_get_log_msg({"name": f"increment_local_lock_for_{chunkpath}", "ph": "E"}))
-
     def download_chunk_from_index(self, chunk_index: int) -> None:
         logger.debug(_get_log_msg({"name": f"download_chunk_from_index_{chunk_index}", "ph": "B"}))

src/litdata/streaming/item_loader.py (13 additions, 2 deletions)
@@ -36,6 +36,7 @@
 from litdata.streaming.serializers import Serializer
 from litdata.utilities._pytree import PyTree, tree_unflatten
 from litdata.utilities.encryption import Encryption, EncryptionLevel
+from litdata.utilities.file_utils import decrement_file_count

 Interval = namedtuple("Interval", ["chunk_start", "roi_start_idx", "roi_end_idx", "chunk_end"])

@@ -111,6 +112,15 @@ def load_item_from_chunk(
     def delete(self, chunk_index: int, chunk_filepath: str) -> None:
         """Delete a chunk from the local filesystem."""

+    def safe_delete(self, chunk_index: int, chunk_filepath: str, delete_original_file: bool = False) -> None:
+        """Decrement the file count and delete the chunk file if the count reaches 0."""
+        if os.path.exists(chunk_filepath + ".cnt"):
+            curr_count = decrement_file_count(chunk_filepath)
+            if curr_count == 0 and delete_original_file:
+                self.delete(chunk_index, chunk_filepath)
+        else:
+            self.delete(chunk_index, chunk_filepath)
+
     @abstractmethod
     def encode_data(self, data: List[bytes], sizes: List[int], flattened: List[Any]) -> Any:
         pass
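
For illustration, assuming the counter helpers sketched earlier: every consumer of a chunk increments the counter when it requests the file and decrements it when it is done, so `safe_delete` only removes the chunk once the last consumer has released it (the path below is hypothetical):

```python
chunk = "/cache/chunks/chunk-0-0.bin"    # hypothetical cached chunk

increment_file_count(chunk)              # worker A requests the chunk -> count == 1, A downloads it
increment_file_count(chunk)              # worker B reuses the cached copy -> count == 2, no re-download

assert decrement_file_count(chunk) == 1  # A is done, but B still needs the file -> keep it
assert decrement_file_count(chunk) == 0  # B is done -> safe_delete() may now unlink the chunk
```
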
@@ -260,7 +270,8 @@ def _load_data(self, fp: Union[FileIO, BytesIO], offset: int) -> bytes:
         pair = fp.read(8)
         begin, end = np.frombuffer(pair, np.uint32)

-        fp.seek(begin) # move the file pointer to the offset_start where the item starts
+        # move the file pointer to the offset_start where the item starts
+        fp.seek(begin)
         return fp.read(end - begin) # read the item

     def mds_deserialize(self, raw_item_data: bytes, chunk_index: int) -> "PyTree":
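
The two `uint32` values read above are the item's begin and end byte offsets inside the chunk. A simplified, self-contained illustration of the mechanism (this is not litdata's exact chunk layout, which also carries a header):

```python
from io import BytesIO

import numpy as np

# Toy layout: a table of uint32 offsets followed by the item payloads.
offsets = np.array([12, 17, 25], dtype=np.uint32)    # 3 offsets * 4 bytes = 12-byte table
chunk = BytesIO(offsets.tobytes() + b"hello" + b"item two")

item_index = 0
chunk.seek(item_index * 4)                            # jump to this item's slot in the table
begin, end = np.frombuffer(chunk.read(8), np.uint32)  # item start and next item's start
chunk.seek(begin)
print(chunk.read(end - begin))                        # b"hello"
```
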
@@ -737,7 +748,7 @@ def _get_item_with_low_memory(self, chunk_index: int, chunk_filepath: str, row_i

         # Return the specific row from the dataframe
         # Note: The `named=True` argument is used to return the row as a dictionary
-        return row_group_df.row(row_index_within_group, named=True) # type: ignore
+        return row_group_df.row(row_index_within_group, named=True)

     def _get_item(self, chunk_index: int, chunk_filepath: str, index: int) -> Any:
         """Retrieve a dataframe row from a parquet chunk by loading the entire chunk into memory.