Skip to content

Commit

Permalink
ThreadSafeFile
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed May 19, 2024
1 parent f3d7de9 commit 249460f
Showing 1 changed file with 29 additions and 29 deletions.
58 changes: 29 additions & 29 deletions src/beignet/io/_thread_safe_file.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,51 @@
import operator
import threading
from pathlib import Path
from typing import Any, Callable, TypeVar, Union
from os import PathLike
from typing import Any, Callable, TypeVar

T = TypeVar("T")


class ThreadSafeFile:
"""
Share file objects (e.g., raw binary, buffered binary, and text) between
threads by storing the file object in thread-local storage (TLS).
"""

def __init__(
self,
path: Union[str, Path],
open_function: Callable[[Union[str, Path]], T] = operator.methodcaller(
"open"
),
close_function: Callable[[T], None] = operator.methodcaller("close"),
path: str | PathLike,
open: Callable[[str | PathLike], T] = operator.methodcaller("open"),
close: Callable[[T], None] = operator.methodcaller("close"),
) -> None:
"""
Share file objects (i.e., raw binary files, buffered binary files, and
text files) between threads by storing the file object in thread-local
storage (TLS).
"""
self._local = threading.local()

self._path = path
self._open_function = open_function
self._close_function = close_function

def __getattr__(self, name: str) -> Any:
return getattr(self.file, name)
self.path = path

@property
def file(self) -> T:
if not hasattr(self._local, "file"):
self._local.file = self._open_function(self._path)
self.open = open

self.close = close

return self._local.file
self.storage = threading.local()

def __del__(self) -> None:
if hasattr(self._local, "file"):
self._close_function(self._local.file)
if hasattr(self.storage, "file"):
self.close(self.storage.file)

del self._local.file
del self.storage.file

def __getattr__(self, name: str) -> object:
return getattr(self.file, name)

def __getstate__(self) -> dict[str, Any]:
return {k: v for k, v in self.__dict__.items() if k != "_local"}

def __setstate__(self, state: dict[str, Any]) -> None:
self.__dict__ = state

self._local = threading.local()
self.storage = threading.local()

@property
def file(self) -> T:
if not hasattr(self.storage, "file"):
self.storage.file = self.open(self.path)

return self.storage.file

0 comments on commit 249460f

Please sign in to comment.