37
37
logger = logging .getLogger ("fsspec" )
38
38
39
39
Fetcher = Callable [[int , int ], bytes ] # Maps (start, end) to bytes
40
+ MultiFetcher = Callable [list [[int , int ]], bytes ] # Maps [(start, end), ...] to the fetched bytes, one item per requested range
40
41
41
42
42
43
class BaseCache :
@@ -109,6 +110,26 @@ class MMapCache(BaseCache):
109
110
Ensure there is enough disc space in the temporary location.
110
111
111
112
This cache method might only work on POSIX systems
113
+
114
+ Parameters
115
+ ----------
116
+ blocksize: int
117
+ How far to read ahead in number of bytes
118
+ fetcher: Fetcher
119
+ Function of the form f(start, end) which gets bytes from remote as
120
+ specified
121
+ size: int
122
+ How big this file is
123
+ location: str
124
+ Where to create the temporary file. If None, a temporary file is
125
+ created using tempfile.TemporaryFile().
126
+ blocks: set[int]
127
+ Set of block numbers that have already been fetched. If None, an empty
128
+ set is created.
129
+ multi_fetcher: MultiFetcher
130
+ Function of the form f([(start, end)]) which gets bytes from remote
131
+ as specified. This function is used to fetch multiple blocks at once.
132
+ If not specified, the fetcher function is used instead.
112
133
"""
113
134
114
135
name = "mmap"
@@ -120,10 +141,12 @@ def __init__(
120
141
size : int ,
121
142
location : str | None = None ,
122
143
blocks : set [int ] | None = None ,
144
+ multi_fetcher : MultiFetcher | None = None ,
123
145
) -> None :
124
146
super ().__init__ (blocksize , fetcher , size )
125
147
self .blocks = set () if blocks is None else blocks
126
148
self .location = location
149
+ self .multi_fetcher = multi_fetcher
127
150
self .cache = self ._makefile ()
128
151
129
152
def _makefile (self ) -> mmap .mmap | bytearray :
@@ -164,6 +187,8 @@ def _fetch(self, start: int | None, end: int | None) -> bytes:
164
187
# Count the number of blocks already cached
165
188
self .hit_count += sum (1 for i in block_range if i in self .blocks )
166
189
190
+ ranges = []
191
+
167
192
# Consolidate needed blocks.
168
193
# Algorithm adapted from Python 2.x itertools documentation.
169
194
# We are grouping an enumerated sequence of blocks. By comparing when the difference
@@ -185,13 +210,27 @@ def _fetch(self, start: int | None, end: int | None) -> bytes:
185
210
logger .debug (
186
211
f"MMap get blocks { _blocks [0 ]} -{ _blocks [- 1 ]} ({ sstart } -{ send } )"
187
212
)
188
- self . cache [ sstart : send ] = self . fetcher ( sstart , send )
213
+ ranges . append (( sstart , send ) )
189
214
190
215
# Update set of cached blocks
191
216
self .blocks .update (_blocks )
192
217
# Update cache statistics with number of blocks we had to cache
193
218
self .miss_count += len (_blocks )
194
219
220
+ if not ranges :
221
+ return self .cache [start :end ]
222
+
223
+ if self .multi_fetcher :
224
+ logger .debug (f"MMap get blocks { ranges } " )
225
+ for idx , r in enumerate (self .multi_fetcher (ranges )):
226
+ (sstart , send ) = ranges [idx ]
227
+ logger .debug (f"MMap copy block ({ sstart } -{ send } " )
228
+ self .cache [sstart :send ] = r
229
+ else :
230
+ for sstart , send in ranges :
231
+ logger .debug (f"MMap get block ({ sstart } -{ send } " )
232
+ self .cache [sstart :send ] = self .fetcher (sstart , send )
233
+
195
234
return self .cache [start :end ]
196
235
197
236
def __getstate__ (self ) -> dict [str , Any ]:
0 commit comments