Skip to content

Commit c710581

Browse files
monken and martindurant authored
Call cat_ranges in blockcache for async filesystems (#1336)
Co-authored-by: Martin Durant <[email protected]>
1 parent c0d1034 commit c710581

File tree

3 files changed

+90
-2
lines changed

3 files changed

+90
-2
lines changed

fsspec/caching.py

+40-1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
logger = logging.getLogger("fsspec")
3838

3939
Fetcher = Callable[[int, int], bytes] # Maps (start, end) to bytes
40+
# Maps a list of (start, end) ranges to a list of bytes objects, one per range
MultiFetcher = Callable[[list[tuple[int, int]]], list[bytes]]
4041

4142

4243
class BaseCache:
@@ -109,6 +110,26 @@ class MMapCache(BaseCache):
109110
Ensure there is enough disc space in the temporary location.
110111
111112
This cache method might only work on posix
113+
114+
Parameters
115+
----------
116+
blocksize: int
117+
How far to read ahead in numbers of bytes
118+
fetcher: Fetcher
119+
Function of the form f(start, end) which gets bytes from remote as
120+
specified
121+
size: int
122+
How big this file is
123+
location: str
124+
Where to create the temporary file. If None, a temporary file is
125+
created using tempfile.TemporaryFile().
126+
blocks: set[int]
127+
Set of block numbers that have already been fetched. If None, an empty
128+
set is created.
129+
multi_fetcher: MultiFetcher
130+
Function of the form f([(start, end)]) which gets bytes from remote
131+
as specified. This function is used to fetch multiple blocks at once.
132+
If not specified, the fetcher function is used instead.
112133
"""
113134

114135
name = "mmap"
@@ -120,10 +141,12 @@ def __init__(
120141
size: int,
121142
location: str | None = None,
122143
blocks: set[int] | None = None,
144+
multi_fetcher: MultiFetcher | None = None,
123145
) -> None:
124146
super().__init__(blocksize, fetcher, size)
125147
self.blocks = set() if blocks is None else blocks
126148
self.location = location
149+
self.multi_fetcher = multi_fetcher
127150
self.cache = self._makefile()
128151

129152
def _makefile(self) -> mmap.mmap | bytearray:
@@ -164,6 +187,8 @@ def _fetch(self, start: int | None, end: int | None) -> bytes:
164187
# Count the number of blocks already cached
165188
self.hit_count += sum(1 for i in block_range if i in self.blocks)
166189

190+
ranges = []
191+
167192
# Consolidate needed blocks.
168193
# Algorithm adapted from Python 2.x itertools documentation.
169194
# We are grouping an enumerated sequence of blocks. By comparing when the difference
@@ -185,13 +210,27 @@ def _fetch(self, start: int | None, end: int | None) -> bytes:
185210
logger.debug(
186211
f"MMap get blocks {_blocks[0]}-{_blocks[-1]} ({sstart}-{send})"
187212
)
188-
self.cache[sstart:send] = self.fetcher(sstart, send)
213+
ranges.append((sstart, send))
189214

190215
# Update set of cached blocks
191216
self.blocks.update(_blocks)
192217
# Update cache statistics with number of blocks we had to cache
193218
self.miss_count += len(_blocks)
194219

220+
if not ranges:
221+
return self.cache[start:end]
222+
223+
if self.multi_fetcher:
224+
logger.debug(f"MMap get blocks {ranges}")
225+
for idx, r in enumerate(self.multi_fetcher(ranges)):
226+
(sstart, send) = ranges[idx]
227+
logger.debug(f"MMap copy block ({sstart}-{send}")
228+
self.cache[sstart:send] = r
229+
else:
230+
for sstart, send in ranges:
231+
logger.debug(f"MMap get block ({sstart}-{send}")
232+
self.cache[sstart:send] = self.fetcher(sstart, send)
233+
195234
return self.cache[start:end]
196235

197236
def __getstate__(self) -> dict[str, Any]:

fsspec/implementations/cached.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,19 @@ def _open(
362362
)
363363
else:
364364
detail["blocksize"] = f.blocksize
365-
f.cache = MMapCache(f.blocksize, f._fetch_range, f.size, fn, blocks)
365+
366+
def _fetch_ranges(ranges):
367+
return self.fs.cat_ranges(
368+
[path] * len(ranges),
369+
[r[0] for r in ranges],
370+
[r[1] for r in ranges],
371+
**kwargs,
372+
)
373+
374+
multi_fetcher = None if self.compression else _fetch_ranges
375+
f.cache = MMapCache(
376+
f.blocksize, f._fetch_range, f.size, fn, blocks, multi_fetcher=multi_fetcher
377+
)
366378
close = f.close
367379
f.close = lambda: self.close_and_update(f, close)
368380
self.save_cache()

fsspec/tests/test_caches.py

+37
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from fsspec.caching import (
77
BlockCache,
88
FirstChunkCache,
9+
MMapCache,
910
ReadAheadCache,
1011
caches,
1112
register_cache,
@@ -147,6 +148,10 @@ def letters_fetcher(start, end):
147148
return string.ascii_letters[start:end].encode()
148149

149150

151+
def multi_letters_fetcher(ranges):
    """Return one encoded slice of ``string.ascii_letters`` per range.

    Parameters
    ----------
    ranges: iterable of (start, end) pairs
        Each pair selects ``string.ascii_letters[start:end]``.

    Returns
    -------
    list of bytes, in the same order as ``ranges``.
    """
    return [string.ascii_letters[start:end].encode() for start, end in ranges]
153+
154+
150155
not_parts_caches = {k: v for k, v in caches.items() if k != "parts"}
151156

152157

@@ -174,6 +179,38 @@ def test_cache_pickleable(Cache_imp):
174179
assert unpickled._fetch(0, 10) == b"0" * 10
175180

176181

182+
def test_first_cache():
    """Check when FirstChunkCache populates and reuses its cached first block."""
    c = FirstChunkCache(5, letters_fetcher, 52)
    assert c.cache is None
    # A read entirely past the first 5-byte block is expected to leave the cache empty.
    assert c._fetch(12, 15) == letters_fetcher(12, 15)
    assert c.cache is None
    # A read starting inside the first block is expected to store bytes 0-5.
    assert c._fetch(3, 10) == letters_fetcher(3, 10)
    assert c.cache == letters_fetcher(0, 5)
    # With the fetcher removed, a read inside the first block must still succeed.
    c.fetcher = None
    assert c._fetch(1, 4) == letters_fetcher(1, 4)
191+
192+
193+
def test_mmap_cache(mocker):
    """Check MMapCache results and fetch-call counts with and without a multi_fetcher."""
    fetcher = mocker.Mock(wraps=letters_fetcher)

    # Without a multi_fetcher, the single-range fetcher is expected to be
    # called once per consolidated run of missing blocks.
    c = MMapCache(5, fetcher, 52)
    assert c._fetch(6, 8) == letters_fetcher(6, 8)
    assert fetcher.call_count == 1
    assert c._fetch(17, 22) == letters_fetcher(17, 22)
    assert fetcher.call_count == 2
    assert c._fetch(1, 38) == letters_fetcher(1, 38)
    assert fetcher.call_count == 5

    # With a multi_fetcher, each _fetch that misses is expected to make one
    # multi_fetcher call, and the single-range fetcher count must stay at 5.
    multi_fetcher = mocker.Mock(wraps=multi_letters_fetcher)
    m = MMapCache(5, fetcher, size=52, multi_fetcher=multi_fetcher)
    assert m._fetch(6, 8) == letters_fetcher(6, 8)
    assert multi_fetcher.call_count == 1
    assert m._fetch(17, 22) == letters_fetcher(17, 22)
    assert multi_fetcher.call_count == 2
    assert m._fetch(1, 38) == letters_fetcher(1, 38)
    assert multi_fetcher.call_count == 3
    assert fetcher.call_count == 5
212+
213+
177214
@pytest.mark.parametrize(
178215
"size_requests",
179216
[[(0, 30), (0, 35), (51, 52)], [(0, 1), (1, 11), (1, 52)], [(0, 52), (11, 15)]],

0 commit comments

Comments
 (0)