Skip to content

Commit

Permalink
Add some type annotations, most notably to smbclient.open_file (#295)
Browse files Browse the repository at this point in the history
  • Loading branch information
mon authored Oct 18, 2024
1 parent b6affa7 commit a0b8544
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 7 deletions.
136 changes: 135 additions & 1 deletion src/smbclient/_os.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,140 @@ def makedirs(path, exist_ok=False, **kwargs):
create_queue.pop(-1)


# Taken from stdlib typeshed but removed the unused 'U' flag
OpenTextModeUpdating: t.TypeAlias = t.Literal[
"r+",
"+r",
"rt+",
"r+t",
"+rt",
"tr+",
"t+r",
"+tr",
"w+",
"+w",
"wt+",
"w+t",
"+wt",
"tw+",
"t+w",
"+tw",
"a+",
"+a",
"at+",
"a+t",
"+at",
"ta+",
"t+a",
"+ta",
"x+",
"+x",
"xt+",
"x+t",
"+xt",
"tx+",
"t+x",
"+tx",
]
OpenTextModeWriting: t.TypeAlias = t.Literal["w", "wt", "tw", "a", "at", "ta", "x", "xt", "tx"]
OpenTextModeReading: t.TypeAlias = t.Literal["r", "rt", "tr"]
OpenTextMode: t.TypeAlias = t.Literal[OpenTextModeUpdating, OpenTextModeWriting, OpenTextModeReading]
OpenBinaryModeUpdating: t.TypeAlias = t.Literal[
"rb+",
"r+b",
"+rb",
"br+",
"b+r",
"+br",
"wb+",
"w+b",
"+wb",
"bw+",
"b+w",
"+bw",
"ab+",
"a+b",
"+ab",
"ba+",
"b+a",
"+ba",
"xb+",
"x+b",
"+xb",
"bx+",
"b+x",
"+bx",
]
OpenBinaryModeWriting: t.TypeAlias = t.Literal["wb", "bw", "ab", "ba", "xb", "bx"]
OpenBinaryModeReading: t.TypeAlias = t.Literal["rb", "br"]
OpenBinaryMode: t.TypeAlias = t.Literal[OpenBinaryModeUpdating, OpenBinaryModeReading, OpenBinaryModeWriting]
FileType: t.TypeAlias = t.Literal["file", "dir", "pipe"]


# Text mode: always returns a TextIOWrapper
@t.overload
def open_file(
path,
mode: OpenTextMode = "r",
buffering=-1,
file_type: FileType = "file",
encoding=None,
errors=None,
newline=None,
share_access=None,
desired_access=None,
file_attributes=None,
**kwargs,
) -> io.TextIOWrapper[io.BufferedRandom | io.BufferedReader | io.BufferedWriter]: ...


# Otherwise return BufferedRandom, BufferedReader, or BufferedWriter
# NOTE: This incorrectly returns unbuffered opens as Buffered types, due to difficulties
# in annotating that case
@t.overload
def open_file(
path,
mode: OpenBinaryModeUpdating,
buffering=-1,
encoding=None,
errors=None,
newline=None,
share_access=None,
desired_access=None,
file_attributes=None,
file_type: FileType = "file",
**kwargs,
) -> io.BufferedRandom: ...
@t.overload
def open_file(
path,
mode: OpenBinaryModeReading,
buffering=-1,
encoding=None,
errors=None,
newline=None,
share_access=None,
desired_access=None,
file_attributes=None,
file_type: FileType = "file",
**kwargs,
) -> io.BufferedReader: ...
@t.overload
def open_file(
path,
mode: OpenBinaryModeWriting,
buffering=-1,
encoding=None,
errors=None,
newline=None,
share_access=None,
desired_access=None,
file_attributes=None,
file_type: FileType = "file",
**kwargs,
) -> io.BufferedWriter: ...


def open_file(
path,
mode="r",
Expand All @@ -325,7 +459,7 @@ def open_file(
share_access=None,
desired_access=None,
file_attributes=None,
file_type="file",
file_type: t.Literal["file", "dir", "pipe"] = "file",
**kwargs,
):
"""
Expand Down
11 changes: 6 additions & 5 deletions src/smbclient/_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
import ntpath
import uuid
from typing import Literal, Optional

from smbprotocol._text import to_text
from smbprotocol.connection import Capabilities, Connection
Expand Down Expand Up @@ -366,14 +367,14 @@ def get_smb_tree(


def register_session(
server,
username=None,
password=None,
server: str,
username: Optional[str] = None,
password: Optional[str] = None,
port=445,
encrypt=None,
encrypt: Optional[bool] = None,
connection_timeout=60,
connection_cache=None,
auth_protocol="negotiate",
auth_protocol: Literal["negotiate", "ntlm", "kerberos"] = "negotiate",
require_signing=True,
):
"""
Expand Down
10 changes: 9 additions & 1 deletion src/smbprotocol/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import random
from collections import OrderedDict
from typing import Literal, Optional

import spnego
from cryptography.hazmat.backends import default_backend
Expand Down Expand Up @@ -170,7 +171,14 @@ def __init__(self):


class Session:
def __init__(self, connection, username=None, password=None, require_encryption=True, auth_protocol="negotiate"):
def __init__(
self,
connection,
username: Optional[str] = None,
password: Optional[str] = None,
require_encryption=True,
auth_protocol: Literal["negotiate", "ntlm", "kerberos"] = "negotiate",
):
"""
[MS-SMB2] v53.0 2017-09-15
Expand Down

0 comments on commit a0b8544

Please sign in to comment.