Skip to content

Commit

Permalink
Shutil.copyfile will open source with read share (#265)
Browse files Browse the repository at this point in the history
Opens the source file used in shutil.copyfile with share_access="r" to
ensure that it can be copied even if something else already has it
opened with read access and grants further opens.
  • Loading branch information
jborean93 authored Jan 29, 2024
1 parent a207f77 commit ba9e9bc
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
* Added default timeout for disconnect operations for 60 seconds to ensure the process doesn't hang forever when closing a broken connection
* `smbprotocol.connection.Connection.disconnect()` now waits (with a timeout) for the message processing threads to be stopped before returning.
* Do not set the SMB SessionId and TreeId in the headers to `0xFFFFFFFF` for related compound requests
+ Ensures the source file for `shutil.copyfile` is opened with `share_access="r"` for better compatibility with files already opened by something else

## 1.12.0 - 2023-11-09

Expand Down
10 changes: 8 additions & 2 deletions src/smbclient/shutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,19 @@ def copyfile(src, dst, follow_symlinks=True, **kwargs):
symlink_func = symlink
src_open = open_file
src_kwargs = kwargs

# Open the soruce with read file sharing to allow copying
# files already opened.
# https://github.com/jborean93/smbprotocol/issues/258
src_open_kwargs = kwargs.copy()
src_open_kwargs["share_access"] = "r"
else:
src_root = None
islink_func = os.path.islink
readlink_func = os.readlink
symlink_func = os.symlink
src_open = open
src_kwargs = {}
src_kwargs = src_open_kwargs = {}

norm_dst = ntpath.normpath(dst)
if is_remote_path(norm_dst):
Expand Down Expand Up @@ -178,7 +184,7 @@ def copyfile(src, dst, follow_symlinks=True, **kwargs):
return dst

# Finally we are copying across different roots so we just chunk the data using copyfileobj
with src_open(src, mode="rb", **src_kwargs) as src_fd, dst_open(dst, mode="wb", **dst_kwargs) as dst_fd:
with src_open(src, mode="rb", **src_open_kwargs) as src_fd, dst_open(dst, mode="wb", **dst_kwargs) as dst_fd:
copyfileobj(src_fd, dst_fd, MAX_PAYLOAD_SIZE)

return dst
Expand Down
18 changes: 18 additions & 0 deletions tests/test_smbclient_shutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import ctypes
import ntpath
import os
import os.path
import re
import shutil
import stat
Expand Down Expand Up @@ -305,6 +306,23 @@ def test_copyfile_remote_to_local(smb_share, tmpdir):
assert fd.read() == "content"


def test_copyfile_remote_to_local_read_share(smb_share, tmpdir):
test_dir = tmpdir.mkdir("test")
src_filename = "%s\\source.txt" % smb_share
dst_filename = os.path.join(test_dir, "target.txt")

with open_file(src_filename, mode="w") as fd:
fd.write("content")

with open_file(src_filename, mode="r", share_access="r") as fd:
actual = copyfile(src_filename, dst_filename)

assert actual == dst_filename

with open(dst_filename) as fd:
assert fd.read() == "content"


def test_copyfile_local_to_local(tmpdir):
test_dir = tmpdir.mkdir("test")
src_filename = "%s\\source.txt" % test_dir
Expand Down

0 comments on commit ba9e9bc

Please sign in to comment.