diff --git a/.gitignore b/.gitignore index e51a4491..471aade6 100644 --- a/.gitignore +++ b/.gitignore @@ -192,5 +192,9 @@ docker.env # ignore Infisical config file .infisical.json +# ignore google drive api credentials +credentials.json +token.json + # ignore default docq extensions .docq-extensions.json diff --git a/misc/docker.env.template b/misc/docker.env.template index a803e9dc..45e64359 100644 --- a/misc/docker.env.template +++ b/misc/docker.env.template @@ -11,6 +11,10 @@ DOCQ_SMTP_LOGIN="SMTP-LOGIN" # The username for logging in to the SMTP service DOCQ_SMTP_KEY="SMTP-MASTER-PASSWORD" # The password for logging in to the SMTP service DOCQ_SMTP_FROM="Docq.AI Support " # A custom sender email +# GOOGLE_DRIVE_API +DOCQ_GOOGLE_APPLICATION_CREDENTIALS=credentials.json # Credentials containing Docq project configs +DOCQ_GOOGLE_AUTH_REDIRECT_URL=http://localhost:8501/authorize_gdrive # The URL configured for redirect in google console. + # SERVER SETTINGS DOCQ_SERVER_ADDRESS = "http://localhost:8501" # Web address for the docq server, used for generating verification urls. diff --git a/misc/secrets.toml.template b/misc/secrets.toml.template index a720ac49..5373a981 100644 --- a/misc/secrets.toml.template +++ b/misc/secrets.toml.template @@ -9,6 +9,10 @@ DOCQ_SMTP_LOGIN = "SMTP-LOGIN" # The username for logging in to the SMTP service DOCQ_SMTP_KEY = "SMTP-MASTER-PASSWORD" # The password for logging in to the SMTP service DOCQ_SMTP_FROM = "Docq.AI Support " # A custom sender email +# GOOGLE_DRIVE_API +DOCQ_GOOGLE_APPLICATION_CREDENTIALS = "credentials.json" # Credentials containing Docq project configs +DOCQ_GOOGLE_AUTH_REDIRECT_URL = "http://localhost:8501/authorize_gdrive" # The URL configured for redirect in google console. + # SERVER SETTINGS DOCQ_SERVER_ADDRESS = "http://localhost:8501" # Web address for the docq server, used for generating verification urls. OTEL_SERVICE_NAME = "docq-" #for local dev "docq-dev-". 
Prod "docq-prod" diff --git a/poetry.lock b/poetry.lock index 5ba1465e..ecb87a3b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. [[package]] name = "accelerate" @@ -1171,6 +1171,102 @@ gitdb = ">=4.0.1,<5" [package.extras] test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest", "pytest-cov", "pytest-instafail", "pytest-subtests", "pytest-sugar"] +[[package]] +name = "google-api-core" +version = "2.14.0" +description = "Google API client core library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "google-api-core-2.14.0.tar.gz", hash = "sha256:5368a4502b793d9bbf812a5912e13e4e69f9bd87f6efb508460c43f5bbd1ce41"}, + {file = "google_api_core-2.14.0-py3-none-any.whl", hash = "sha256:de2fb50ed34d47ddbb2bd2dcf680ee8fead46279f4ed6b16de362aca23a18952"}, +] + +[package.dependencies] +google-auth = ">=2.14.1,<3.0.dev0" +googleapis-common-protos = ">=1.56.2,<2.0.dev0" +protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" +requests = ">=2.18.0,<3.0.0.dev0" + +[package.extras] +grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "grpcio-status (>=1.33.2,<2.0.dev0)", "grpcio-status (>=1.49.1,<2.0.dev0)"] +grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] +grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] + +[[package]] +name = "google-api-python-client" +version = "2.107.0" +description = "Google API Client Library for Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "google-api-python-client-2.107.0.tar.gz", hash = "sha256:ef6d4c1a17fe9ec0894fc6d4f61e751c4b859fb33f2ab5b881ceb0b80ba442ba"}, + {file = "google_api_python_client-2.107.0-py2.py3-none-any.whl", hash = 
"sha256:51d7bf676f41a77b00b7b9c72ace0c1db3dd5a4dd392a13ae897cf4f571a3539"}, +] + +[package.dependencies] +google-api-core = ">=1.31.5,<2.0.dev0 || >2.3.0,<3.0.0.dev0" +google-auth = ">=1.19.0,<3.0.0.dev0" +google-auth-httplib2 = ">=0.1.0" +httplib2 = ">=0.15.0,<1.dev0" +uritemplate = ">=3.0.1,<5" + +[[package]] +name = "google-auth" +version = "2.23.4" +description = "Google Authentication Library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "google-auth-2.23.4.tar.gz", hash = "sha256:79905d6b1652187def79d491d6e23d0cbb3a21d3c7ba0dbaa9c8a01906b13ff3"}, + {file = "google_auth-2.23.4-py2.py3-none-any.whl", hash = "sha256:d4bbc92fe4b8bfd2f3e8d88e5ba7085935da208ee38a134fc280e7ce682a05f2"}, +] + +[package.dependencies] +cachetools = ">=2.0.0,<6.0" +pyasn1-modules = ">=0.2.1" +rsa = ">=3.1.4,<5" + +[package.extras] +aiohttp = ["aiohttp (>=3.6.2,<4.0.0.dev0)", "requests (>=2.20.0,<3.0.0.dev0)"] +enterprise-cert = ["cryptography (==36.0.2)", "pyopenssl (==22.0.0)"] +pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"] +reauth = ["pyu2f (>=0.1.5)"] +requests = ["requests (>=2.20.0,<3.0.0.dev0)"] + +[[package]] +name = "google-auth-httplib2" +version = "0.1.1" +description = "Google Authentication Library: httplib2 transport" +optional = false +python-versions = "*" +files = [ + {file = "google-auth-httplib2-0.1.1.tar.gz", hash = "sha256:c64bc555fdc6dd788ea62ecf7bccffcf497bf77244887a3f3d7a5a02f8e3fc29"}, + {file = "google_auth_httplib2-0.1.1-py2.py3-none-any.whl", hash = "sha256:42c50900b8e4dcdf8222364d1f0efe32b8421fb6ed72f2613f12f75cc933478c"}, +] + +[package.dependencies] +google-auth = "*" +httplib2 = ">=0.19.0" + +[[package]] +name = "google-auth-oauthlib" +version = "1.1.0" +description = "Google Authentication Library" +optional = false +python-versions = ">=3.6" +files = [ + {file = "google-auth-oauthlib-1.1.0.tar.gz", hash = "sha256:83ea8c3b0881e453790baff4448e8a6112ac8778d1de9da0b68010b843937afb"}, + {file = 
"google_auth_oauthlib-1.1.0-py2.py3-none-any.whl", hash = "sha256:089c6e587d36f4803ac7e0720c045c6a8b1fd1790088b8424975b90d0ee61c12"}, +] + +[package.dependencies] +google-auth = ">=2.15.0" +requests-oauthlib = ">=0.7.0" + +[package.extras] +tool = ["click (>=6.0.0)"] + [[package]] name = "googleapis-common-protos" version = "1.61.0" @@ -1341,6 +1437,20 @@ opentelemetry-exporter-otlp = "1.20.0" opentelemetry-instrumentation = "0.41b0" opentelemetry-sdk = "1.20.0" +[[package]] +name = "httplib2" +version = "0.22.0" +description = "A comprehensive HTTP client library." +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "httplib2-0.22.0-py3-none-any.whl", hash = "sha256:14ae0a53c1ba8f3d37e9e27cf37eabb0fb9980f435ba405d546948b009dd64dc"}, + {file = "httplib2-0.22.0.tar.gz", hash = "sha256:d7a10bc5ef5ab08322488bde8c726eeee5c8618723fdb399597ec58f3d82df81"}, +] + +[package.dependencies] +pyparsing = {version = ">=2.4.2,<3.0.0 || >3.0.0,<3.0.1 || >3.0.1,<3.0.2 || >3.0.2,<3.0.3 || >3.0.3,<4", markers = "python_version > \"3.0\""} + [[package]] name = "huggingface-hub" version = "0.19.0" @@ -1699,16 +1809,6 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -2145,6 +2245,22 @@ files = [ {file = "numpy-1.25.2.tar.gz", hash = "sha256:fd608e19c8d7c55021dffd43bfe5492fab8cc105cc8986f813f8c3c048b38760"}, ] +[[package]] +name = "oauthlib" +version = "3.2.2" +description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic" +optional = false +python-versions = 
">=3.6" +files = [ + {file = "oauthlib-3.2.2-py3-none-any.whl", hash = "sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca"}, + {file = "oauthlib-3.2.2.tar.gz", hash = "sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918"}, +] + +[package.extras] +rsa = ["cryptography (>=3.0.0)"] +signals = ["blinker (>=1.4.0)"] +signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"] + [[package]] name = "onnx" version = "1.15.0" @@ -2251,15 +2367,15 @@ wandb = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1 [[package]] name = "opendal" -version = "0.38.1" +version = "0.41.0" description = "OpenDAL Python Binding" optional = false python-versions = ">=3.7" files = [ - {file = "opendal-0.38.1-cp37-abi3-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:95e008199d33387a2cdcb004d314f9a3f132ff6ceb5cafafb90e5722bdd53fe5"}, - {file = "opendal-0.38.1-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a6b775f04b240d05ba59ffea995cc509fe6f0dadc3fcf893f439d7d66e4190e"}, - {file = "opendal-0.38.1-cp37-abi3-win_amd64.whl", hash = "sha256:279353e91b754ed6f93a6b5f26928b77cdc16be7adbc3c1cc60364a114ffb490"}, - {file = "opendal-0.38.1.tar.gz", hash = "sha256:9214bcc64671494734a37f03cf1fe040dd21469fad2bd476b29e4d4d11923cfb"}, + {file = "opendal-0.41.0-cp37-abi3-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:db45b4915ad275f40aa27de2413299fcaf07fec7f77b7f8ab2f98e81a9298c3c"}, + {file = "opendal-0.41.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:314c1cb4140aa9b6da32c9e0a193247628dbd7ebcda427a709bafbbbc1351f0c"}, + {file = "opendal-0.41.0-cp37-abi3-win_amd64.whl", hash = "sha256:ba5abec90218220a41d64d910114f4864d11593be3f51475f37cb67c9b5e4a03"}, + {file = "opendal-0.41.0.tar.gz", hash = "sha256:d5761ed82653a453650233e58372cda40df7708b6f29144414e05675127469b5"}, ] [package.extras] @@ -3215,6 +3331,31 @@ files = [ 
[package.dependencies] numpy = ">=1.16.6" +[[package]] +name = "pyasn1" +version = "0.5.0" +description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +files = [ + {file = "pyasn1-0.5.0-py2.py3-none-any.whl", hash = "sha256:87a2121042a1ac9358cabcaf1d07680ff97ee6404333bacca15f76aa8ad01a57"}, + {file = "pyasn1-0.5.0.tar.gz", hash = "sha256:97b7290ca68e62a832558ec3976f15cbf911bf5d7c7039d8b861c2a0ece69fde"}, +] + +[[package]] +name = "pyasn1-modules" +version = "0.3.0" +description = "A collection of ASN.1-based protocols modules" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +files = [ + {file = "pyasn1_modules-0.3.0-py2.py3-none-any.whl", hash = "sha256:d3ccd6ed470d9ffbc716be08bd90efbd44d0734bc9303818f7336070984a162d"}, + {file = "pyasn1_modules-0.3.0.tar.gz", hash = "sha256:5bd01446b736eb9d31512a30d46c1ac3395d676c6f3cafa4c03eb54b9925631c"}, +] + +[package.dependencies] +pyasn1 = ">=0.4.6,<0.6.0" + [[package]] name = "pycparser" version = "2.21" @@ -3425,6 +3566,20 @@ files = [ {file = "Pympler-1.0.1.tar.gz", hash = "sha256:993f1a3599ca3f4fcd7160c7545ad06310c9e12f70174ae7ae8d4e25f6c5d3fa"}, ] +[[package]] +name = "pyparsing" +version = "3.1.1" +description = "pyparsing module - Classes and methods to define and execute parsing grammars" +optional = false +python-versions = ">=3.6.8" +files = [ + {file = "pyparsing-3.1.1-py3-none-any.whl", hash = "sha256:32c7c0b711493c72ff18a981d24f28aaf9c1fb7ed5e9667c9e84e3db623bdbfb"}, + {file = "pyparsing-3.1.1.tar.gz", hash = "sha256:ede28a1a32462f5a9705e07aea48001a08f7cf81a021585011deba701581a0db"}, +] + +[package.extras] +diagrams = ["jinja2", "railroad-diagrams"] + [[package]] name = "pypdf" version = "3.17.0" @@ -3837,6 +3992,24 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = 
"requests-oauthlib" +version = "1.3.1" +description = "OAuthlib authentication support for Requests." +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "requests-oauthlib-1.3.1.tar.gz", hash = "sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a"}, + {file = "requests_oauthlib-1.3.1-py2.py3-none-any.whl", hash = "sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5"}, +] + +[package.dependencies] +oauthlib = ">=3.0.0" +requests = ">=2.0.0" + +[package.extras] +rsa = ["oauthlib[signedtoken] (>=3.0.0)"] + [[package]] name = "responses" version = "0.18.0" @@ -3981,6 +4154,20 @@ files = [ {file = "rpds_py-0.12.0.tar.gz", hash = "sha256:7036316cc26b93e401cedd781a579be606dad174829e6ad9e9c5a0da6e036f80"}, ] +[[package]] +name = "rsa" +version = "4.9" +description = "Pure-Python RSA implementation" +optional = false +python-versions = ">=3.6,<4" +files = [ + {file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"}, + {file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"}, +] + +[package.dependencies] +pyasn1 = ">=0.1.3" + [[package]] name = "ruff" version = "0.1.2" @@ -4426,7 +4613,7 @@ files = [ ] [package.dependencies] -greenlet = {version = "!=0.4.17", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""} +greenlet = {version = "!=0.4.17", markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\""} typing-extensions = ">=4.2.0" [package.extras] @@ -4973,6 +5160,17 @@ tzdata = {version = "*", 
markers = "platform_system == \"Windows\""} [package.extras] devenv = ["black", "check-manifest", "flake8", "pyroma", "pytest (>=4.3)", "pytest-cov", "pytest-mock (>=3.3)", "zest.releaser"] +[[package]] +name = "uritemplate" +version = "4.1.1" +description = "Implementation of RFC 6570 URI Templates" +optional = false +python-versions = ">=3.6" +files = [ + {file = "uritemplate-4.1.1-py2.py3-none-any.whl", hash = "sha256:830c08b8d99bdd312ea4ead05994a38e8936266f84b9a7878232db50b044e02e"}, + {file = "uritemplate-4.1.1.tar.gz", hash = "sha256:4346edfc5c3b79f694bccd6d6099a322bbeb628dbf2cd86eea55a456ce5124f0"}, +] + [[package]] name = "urllib3" version = "1.26.18" @@ -5420,4 +5618,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "9c3a98749501c186b314c6f73ad63ba5f768f69d4ac2090790e2bafed89ba9b7" +content-hash = "f193db94ef3a811dafdf8bac6e52390ea5f5bd6a94e78e3000656ec175f2eb84" diff --git a/pyproject.toml b/pyproject.toml index 77a8e4b3..2fdd7edc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ pypdf = "^3.9.0" docx2txt = "^0.8" argon2-cffi = "^21.3.0" azure-core = "^1.27.1" -opendal = "^0.38.1" +opendal = "^0.41.0" llama-index = "0.8.34" transformers = "4.33.2" optimum = { extras = ["exporters"], version = "1.13.2" } @@ -43,6 +43,9 @@ opentelemetry-instrumentation-sqlalchemy = "0.41b0" opentelemetry-instrumentation-tornado = "0.41b0" opentelemetry-instrumentation-tortoiseorm = "0.41b0" opentelemetry-instrumentation-urllib3 = "0.41b0" +google-auth-oauthlib = "^1.1.0" +google-api-python-client = "^2.104.0" +google-auth-httplib2 = "^0.1.1" [tool.poetry.group.dev.dependencies] pre-commit = "^2.18.1" @@ -79,7 +82,7 @@ check = true [tool.ruff] -ignore = ["E501"] +ignore = ["E501", "ANN401"] line-length = 120 select = [ "C9", diff --git a/source/docq/data_source/googledrive.py b/source/docq/data_source/googledrive.py new file mode 100644 index 
00000000..07bdd681 --- /dev/null +++ b/source/docq/data_source/googledrive.py @@ -0,0 +1,104 @@ +"""Google drive datasource.""" +import json +import logging as log +from datetime import datetime +from typing import Any, List, Self + +from llama_index import Document + +from .. import services +from ..domain import ConfigKey, SpaceKey +from ..support.store import get_index_dir +from .main import DocumentMetadata, FileStorageServiceKeys, SpaceDataSourceFileBased +from .support.opendal_reader.base import GoogleDriveReader, OpendalReader + + +class GDrive(SpaceDataSourceFileBased): + """Space data source for Google Drive.""" + + def __init__(self: Self) -> None: + """Initialize the data source.""" + super().__init__("Google Drive") + self.credential = f"{FileStorageServiceKeys.GOOGLE_DRIVE.name}-credential" + self.root_path = f"{FileStorageServiceKeys.GOOGLE_DRIVE.name}-root_path" + + def list_folders(self: Self, configs: Any, state: dict) -> tuple[list, bool]: + """List google drive folders.""" + __creds = configs.get(self.credential) if configs else state.get(self.credential, None) + creds = services.google_drive.validate_credentials(__creds) + + return (services.google_drive.list_folders(creds), False) if creds else ([], False) + + @property + def disabled(self: Self) -> bool: + """Disable the data source.""" + return not services.google_drive.api_enabled() + + def get_config_keys(self: Self) -> List[ConfigKey]: + """Get the config keys for google drive.""" + return [ + ConfigKey( + self.credential, + "Credential", + is_secret=True, + ref_link="https://docqai.github.io/docq/user-guide/config-spaces/#data-source-google-drive", + options={ + "type": "credential", + "handler": services.google_drive.get_auth_url, + "btn_label": "Sign in with Google", + } + ), + ConfigKey( + self.root_path, + "Select a folder", + ref_link="https://docqai.github.io/docq/user-guide/config-spaces/#data-source-google-drive", + options={ + "type": "root_path", + "handler": self.list_folders, 
+ "format_function": lambda x: x['name'], + } + ), + ] + + def load(self: Self, space: SpaceKey, configs: dict) -> list[Document] | None: + """Load the documents from google drive.""" + + def lambda_metadata(x: str) -> dict: + return { + str(DocumentMetadata.FILE_PATH.name).lower(): x, + str(DocumentMetadata.SPACE_ID.name).lower(): space.id_, + str(DocumentMetadata.SPACE_TYPE.name).lower(): space.type_.name, + str(DocumentMetadata.DATA_SOURCE_NAME.name).lower(): self.get_name(), + str(DocumentMetadata.DATA_SOURCE_TYPE.name).lower(): self.__class__.__base__.__name__, + str(DocumentMetadata.SOURCE_URI.name).lower(): x, + str(DocumentMetadata.INDEXED_ON.name).lower(): datetime.timestamp(datetime.now().utcnow()), + } + + root_path = configs[self.root_path] + + options = { + "root": root_path["name"], + "access_token": json.dumps(configs[self.credential]), + } + + try: + loader = OpendalReader( + scheme="gdrive", + file_metadata=lambda_metadata, + **options, + ) + except Exception as e: + log.error("Failed to load google drive with opendal reader: %s", e) + loader = GoogleDriveReader( + file_metadata=lambda_metadata, + root=root_path["name"], + access_token=configs[self.credential], + selected_folder_id=root_path["id"] + ) + + documents = loader.load_data() + file_list = loader.get_document_list() + log.debug("Loaded %s documents from google drive", len(file_list)) + persist_path = get_index_dir(space) + self._save_document_list(file_list, persist_path, self._DOCUMENT_LIST_FILENAME) + return documents diff --git a/source/docq/data_source/list.py b/source/docq/data_source/list.py index d5ecca27..752baa9a 100644 --- a/source/docq/data_source/list.py +++ b/source/docq/data_source/list.py @@ -4,6 +4,7 @@ from .aws_s3 import AwsS3 from .azure_blob import AzureBlob +from .googledrive import GDrive from .knowledge_base_scraper import KnowledgeBaseScraper from .manual_upload import ManualUpload from .web_scraper import WebScraper @@ -17,3 +18,4 @@ class SpaceDataSources(Enum): 
AWS_S3 = AwsS3() WEB_SCRAPER = WebScraper() KNOWLEDGE_BASE_SCRAPER = KnowledgeBaseScraper() + GOOGLE_DRIVE = GDrive() diff --git a/source/docq/data_source/main.py b/source/docq/data_source/main.py index c319e620..3a8f9cd3 100644 --- a/source/docq/data_source/main.py +++ b/source/docq/data_source/main.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from dataclasses import asdict from enum import Enum -from typing import List, Self +from typing import Any, Callable, List, Literal, Self from llama_index import Document from opentelemetry import trace @@ -27,6 +27,15 @@ class DocumentMetadata(Enum): SOURCE_URI = "Source URI" +class FileStorageServiceKeys(Enum): + """File storage service keys.""" + + GOOGLE_DRIVE = "Google Drive" + ONEDRIVE = "OneDrive" + DROPBOX = "Dropbox" + BOX = "Box" + + trace = trace.get_tracer("docq.api.data_source") @@ -41,20 +50,25 @@ def get_name(self: Self) -> str: """Get the name of the data source.""" return self.name + @property + def disabled(self: Self) -> bool: + """Disable the data source.""" + return False + @abstractmethod - def get_config_keys(self) -> List[ConfigKey]: + def get_config_keys(self: Self) -> List[ConfigKey]: """Get the list of config keys.""" pass @abstractmethod @trace.start_as_current_span("SpaceDataSource.load") - def load(self, space: SpaceKey, configs: dict) -> List[Document]: + def load(self: Self, space: SpaceKey, configs: dict) -> List[Document]: """Load the documents from the data source.""" pass @abstractmethod @trace.start_as_current_span("SpaceDataSource.get_document_list") - def get_document_list(self, space: SpaceKey, configs: dict) -> List[DocumentListItem]: + def get_document_list(self: Self, space: SpaceKey, configs: dict) -> List[DocumentListItem]: """Returns a list of tuples containing the name, creation time, and size (Mb) of each document in the specified space's cnfigured data source. 
Args: @@ -72,13 +86,13 @@ class SpaceDataSourceFileBased(SpaceDataSource): _DOCUMENT_LIST_FILENAME = "document_list.json" - def get_document_list(self, space: SpaceKey, configs: dict) -> List[DocumentListItem]: + def get_document_list(self: Self, space: SpaceKey, configs: dict) -> List[DocumentListItem]: """Get the list of documents.""" persist_path = get_index_dir(space) return self._load_document_list(persist_path, self._DOCUMENT_LIST_FILENAME) @trace.start_as_current_span("SpaceDataSourceFileBased._save_document_list") - def _save_document_list(self, document_list: List[DocumentListItem], persist_path: str, filename: str) -> None: + def _save_document_list(self: Self, document_list: List[DocumentListItem], persist_path: str, filename: str) -> None: path = os.path.join(persist_path, filename) try: data = [asdict(item) for item in document_list] @@ -94,7 +108,7 @@ def _save_document_list(self, document_list: List[DocumentListItem], persist_pat log.error("Failed to save space index document list to '%s': %s", path, e, stack_info=True) @trace.start_as_current_span("SpaceDataSourceFileBased._load_document_list") - def _load_document_list(self, persist_path: str, filename: str) -> List[DocumentListItem]: + def _load_document_list(self: Self, persist_path: str, filename: str) -> List[DocumentListItem]: path = os.path.join(persist_path, filename) with open(path, "r") as f: data = json.load(f) @@ -104,7 +118,7 @@ def _load_document_list(self, persist_path: str, filename: str) -> List[Document @trace.start_as_current_span("SpaceDataSourceFileBased._add_exclude_metadata_keys") def _add_exclude_metadata_keys( - self, documents: List[Document], embed_keys: List[str], llm_keys: List[str] + self: Self, documents: List[Document], embed_keys: List[str], llm_keys: List[str] ) -> List[Document]: """Exclude metadata keys from embedding and LLM.""" if documents is None: diff --git a/source/docq/data_source/support/opendal_reader/base.py 
b/source/docq/data_source/support/opendal_reader/base.py index 7573979f..ec77a0a2 100644 --- a/source/docq/data_source/support/opendal_reader/base.py +++ b/source/docq/data_source/support/opendal_reader/base.py @@ -30,7 +30,7 @@ import tempfile from datetime import datetime from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Type, Union, cast +from typing import Any, Callable, Dict, List, Optional, Self, Type, Union, cast import opendal from llama_index.readers.base import BaseReader @@ -43,8 +43,9 @@ from llama_index.readers.file.slides_reader import PptxReader from llama_index.readers.file.tabular_reader import PandasCSVReader from llama_index.readers.file.video_audio_reader import VideoAudioReader -from llama_index.readers.schema.base import Document +from llama_index.schema import Document +from .... import services from ....domain import DocumentListItem DEFAULT_FILE_READER_CLS: Dict[str, Type[BaseReader]] = { @@ -63,17 +64,39 @@ ".ipynb": IPYNBReader, } +FILE_MIME_EXTENSION_MAP: Dict[str, str] = { + "application/pdf": ".pdf", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx", + "application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx", + "application/vnd.google-apps.document": ".gdoc", + "application/vnd.google-apps.presentation": ".gslides", + "application/vnd.google-apps.spreadsheet": ".gsheet", + "image/jpeg": ".jpg", + "image/png": ".png", + "image/jpg": ".jpg", + "audio/mpeg": ".mp3", + "audio/mp3": ".mp3", + "video/mp4": ".mp4", + "video/mpeg": ".mp4", + "text/csv": ".csv", + "application/epub+zip": ".epub", + "text/markdown": ".md", + "application/x-ipynb+json": ".ipynb", + "application/mbox": ".mbox", +} + class OpendalReader(BaseReader): """General reader for any opendal operator.""" def __init__( - self, + self: Self, scheme: str, path: str = "/", file_extractor: 
Optional[Dict[str, Union[str, BaseReader]]] = None, file_metadata: Optional[Callable[[str], Dict]] = None, - **kwargs: Optional[dict[str, any]], + **kwargs: Optional[dict[str, Any]], ) -> None: """Initialize opendal operator, along with credentials if needed. @@ -102,7 +125,7 @@ def __init__( self.documents: List[Document] = [] - def load_data(self) -> List[Document]: + def load_data(self: Self) -> List[Document]: """Load file(s) from OpenDAL.""" # TODO: think about the private and secure aspect of this temp folder. # NOTE: the following code cleans up the temp folder when existing the context. @@ -124,7 +147,78 @@ def load_data(self) -> List[Document]: return self.documents - def get_document_list(self) -> List[DocumentListItem]: + def get_document_list(self: Self) -> List[DocumentListItem]: + """Get a list of all documents in the index. Documents are 1:1 with files.""" + dl: List[DocumentListItem] = [] + try: + for df in self.downloaded_files: + dl.append(DocumentListItem(link=df[0], indexed_on=df[2], size=df[3])) + except Exception as e: + log.exception("Converting Document list to DocumentListItem list failed: %s", e) + + return dl + + +# TODO: To be removed once opendal starts supporting Google Drive. +class GoogleDriveReader(BaseReader): + """Google Drive reader.""" + + def __init__( + self: Self, + access_token: dict, + root: str, + selected_folder_id: Optional[str] = None, + path: str = "/", + file_extractor: Optional[Dict[str, Union[str, BaseReader]]] = None, + file_metadata: Optional[Callable[[str], Dict]] = None, + ) -> None: + """Initialize Google Drive reader. + + Args: + path (str): the path of the data. If none is provided, + this loader will iterate through the entire bucket. If path ends with `/`, this loader will iterate through the entire dir. Otherwise, this loader will load the file.
+ access_token (dict): the access token for the google drive service + root (str): the root folder to start the iteration + selected_folder_id (Optional[str] = None): the selected folder id + file_extractor (Optional[Dict[str, BaseReader]]): A mapping of file + extension to a BaseReader class that specifies how to convert that file + to text. NOTE: this isn't implemented yet. + file_metadata (Optional[Callable[[str], Dict]]): A function that takes a source file path and returns a dictionary of metadata to be added to the Document object. + """ + super().__init__() + self.path = path + self.file_extractor = file_extractor if file_extractor is not None else {} + self.supported_suffix = list(DEFAULT_FILE_READER_CLS.keys()) + self.access_token = access_token + self.root = root + self.file_metadata = file_metadata + self.selected_folder_id = selected_folder_id + self.documents: List[Document] = [] + self.downloaded_files = [] + + def load_data(self: Self) -> List[Document]: + """Load file(s) from Google Drive.""" + service = services.google_drive.get_drive_service(self.access_token) + id_ = self.selected_folder_id if self.selected_folder_id is not None else "root" + folder_content = service.files().list( + q=f"'{id_}' in parents and trashed=false", + fields="files(id, name, parents, mimeType, modifiedTime, webViewLink, webContentLink, size, fullFileExtension)", + ).execute() + files = folder_content.get("files", []) + with tempfile.TemporaryDirectory() as temp_dir: + self.downloaded_files = asyncio.run( + download_from_gdrive(files, temp_dir, service) + ) + + self.documents = asyncio.run( + extract_files( + self.downloaded_files, file_extractor=self.file_extractor, file_metadata=self.file_metadata + ) + ) + + return self.documents + + def get_document_list(self: Self) -> List[DocumentListItem]: """Get a list of all documents in the index. 
A document is a list are 1:1 with a file.""" dl: List[DocumentListItem] = [] try: @@ -136,6 +230,30 @@ def get_document_list(self) -> List[DocumentListItem]: return dl +async def download_from_gdrive(files: List[dict], temp_dir: str, service: Any,) -> List[tuple[str, str, int, int]]: + """Download files from Google Drive.""" + downloaded_files: List[tuple[str, str, int, int]] = [] + + for file in files: + if file["mimeType"] == "application/vnd.google-apps.folder": + # TODO: Implement recursive folder download + continue + suffix = FILE_MIME_EXTENSION_MAP.get(file["mimeType"], None) + if suffix not in DEFAULT_FILE_READER_CLS: + continue + + file_path = f"{temp_dir}/{file['name']}" + indexed_on = datetime.timestamp(datetime.now().utcnow()) + await asyncio.to_thread( + services.google_drive.download_file, service, file["id"], file_path, file["mimeType"] + ) + downloaded_files.append( + (file["webViewLink"], file_path, int(indexed_on), int(file["size"])) + ) + + return downloaded_files + + async def download_file_from_opendal(op: Any, temp_dir: str, path: str) -> tuple[str, int, int]: """Download file from OpenDAL.""" import opendal @@ -144,7 +262,7 @@ async def download_file_from_opendal(op: Any, temp_dir: str, path: str) -> tuple op = cast(opendal.AsyncOperator, op) suffix = Path(path).suffix - filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" + filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore file_size = 0 indexed_on = datetime.timestamp(datetime.now().utcnow()) async with op.open_reader(path) as r: @@ -153,7 +271,7 @@ async def download_file_from_opendal(op: Any, temp_dir: str, path: str) -> tuple w.write(b) file_size = len(b) - return (filepath, indexed_on, file_size) + return (filepath, int(indexed_on), file_size) async def download_dir_from_opendal( diff --git a/source/docq/domain.py b/source/docq/domain.py index 8bf22389..2e5011f0 100644 --- a/source/docq/domain.py +++ b/source/docq/domain.py @@ 
-4,7 +4,7 @@ import sys from dataclasses import dataclass from datetime import datetime -from typing import Any, Optional +from typing import Any, Optional, Self from .config import OrganisationFeatureType, SpaceType @@ -24,10 +24,12 @@ class FeatureKey: type_: OrganisationFeatureType id_: int - def __str__(self) -> str: + def __str__(self: Self) -> str: + """Returns the string representation of the feature key.""" return _join_properties(_SEPARATOR_FOR_STR, self.type_.name, self.id_) - def value(self) -> str: + def value(self: Self) -> str: + """Feature key value.""" return _join_properties(_SEPARATOR_FOR_VALUE, self.type_.name, self.id_) @@ -41,10 +43,12 @@ class SpaceKey: summary: Optional[str] = None """The organisation ID that owns the space.""" - def __str__(self) -> str: + def __str__(self: Self) -> str: + """Returns the string representation of the space key.""" return _join_properties(_SEPARATOR_FOR_STR, self.type_.name, self.org_id, self.id_) - def value(self) -> str: + def value(self: Self) -> str: + """Space key value.""" return _join_properties(_SEPARATOR_FOR_VALUE, self.type_.name, self.org_id, self.id_) @@ -57,6 +61,7 @@ class ConfigKey: is_optional: bool = False is_secret: bool = False ref_link: Optional[str] = None + options: Optional[dict] = None @dataclass diff --git a/source/docq/manage_documents.py b/source/docq/manage_documents.py index 20b2eb45..492c4cba 100644 --- a/source/docq/manage_documents.py +++ b/source/docq/manage_documents.py @@ -5,6 +5,7 @@ import unicodedata from datetime import datetime from mimetypes import guess_type +from typing import Optional from llama_index.schema import NodeWithScore from streamlit import runtime @@ -42,10 +43,17 @@ def delete_all(space: SpaceKey) -> None: reindex(space) +def _is_web_address(uri: str) -> bool: + """Return true if the uri is a web address.""" + return uri.startswith("http://") or uri.startswith("https://") + def _get_download_link(filename: str, path: str) -> str: """Return the download 
link for the file if runtime exists, otherwise return an empty string.""" - if runtime.exists() and os.path.isfile(path): + if _is_web_address(path): + return path + + elif runtime.exists() and os.path.isfile(path): return runtime.get_instance().media_file_mgr.add( path_or_data=path, mimetype=guess_type(path)[0] or "application/octet-stream", @@ -68,8 +76,8 @@ def _parse_metadata(metadata: dict) -> tuple: s_type = metadata.get(str(DocumentMetadata.DATA_SOURCE_TYPE.name).lower()) uri = metadata.get(str(DocumentMetadata.SOURCE_URI.name).lower()) if s_type == "SpaceDataSourceWebBased": - website = _remove_ascii_control_characters(metadata.get("source_website")) - page_title = _remove_ascii_control_characters(metadata.get("page_title")) + website = _remove_ascii_control_characters(metadata.get("source_website", "")) + page_title = _remove_ascii_control_characters(metadata.get("page_title", "")) return website, page_title, uri, s_type else: file_name = metadata.get("file_name") @@ -77,7 +85,7 @@ def _parse_metadata(metadata: dict) -> tuple: return file_name, page_label, uri, s_type -def _classify_file_sources(name: str, uri: str, page: str, sources: dict = None) -> str: +def _classify_file_sources(name: str, uri: str, page: str, sources: Optional[dict] = None) -> dict: """Classify file sources for easy grouping.""" if sources is None: sources = {} @@ -88,7 +96,7 @@ def _classify_file_sources(name: str, uri: str, page: str, sources: dict = None) return sources -def _classify_web_sources(website: str, uri: str, page_title: str, sources: dict = None) -> str: +def _classify_web_sources(website: str, uri: str, page_title: str, sources: Optional[dict] = None) -> dict: """Classify web sources for easy grouping.""" if sources is None: sources = {} diff --git a/source/docq/services/__init__.py b/source/docq/services/__init__.py new file mode 100644 index 00000000..6579ba77 --- /dev/null +++ b/source/docq/services/__init__.py @@ -0,0 +1,11 @@ +"""Docq services.""" +from . 
import google_drive, smtp_service + +__all__ = [ + "google_drive", + "smtp_service" +] + +def _init() -> None: + """Initialize all default services.""" + google_drive._init() diff --git a/source/docq/services/google_drive.py b/source/docq/services/google_drive.py new file mode 100644 index 00000000..82c2bb4a --- /dev/null +++ b/source/docq/services/google_drive.py @@ -0,0 +1,173 @@ +"""Google drive service.""" + +import json +import logging as log +import os +from typing import Any, Optional, Union + +from google.auth.external_account_authorized_user import Credentials as ExtCredentials +from google.auth.transport.requests import Request +from google.oauth2.credentials import Credentials +from google_auth_oauthlib.flow import InstalledAppFlow +from googleapiclient.discovery import build +from googleapiclient.http import MediaIoBaseDownload + +CREDENTIALS_KEY = "DOCQ_GOOGLE_APPLICATION_CREDENTIALS" +REDIRECT_URL_KEY = "DOCQ_GOOGLE_AUTH_REDIRECT_URL" + +GOOGLE_APPLICATION_CREDS_PATH = os.environ.get(CREDENTIALS_KEY) +FLOW_REDIRECT_URI = os.environ.get(REDIRECT_URL_KEY) + +SCOPES = [ + 'https://www.googleapis.com/auth/drive.readonly', + 'https://www.googleapis.com/auth/userinfo.email', + 'openid' +] + +KEY = "google_drive-API" +VALID_CREDENTIALS = "valid_credentials" +INVALID_CREDENTIALS = "invalid_credentials" +AUTH_WRONG_EMAIL = "auth_wrong_email" +AUTH_URL = "auth_url" +AUTH_ERROR = "auth_error" + +CREDENTIALS = Union[Credentials, ExtCredentials] + +def _init() -> None: + """Initialize.""" + if not GOOGLE_APPLICATION_CREDS_PATH: + return log.info("services.google_drive -- Google application credentials not found. API disabled.") + if not os.path.exists(GOOGLE_APPLICATION_CREDS_PATH): + return log.info("services.google_drive -- Google application credentials file not found. API disabled.") + if not FLOW_REDIRECT_URI: + return log.info("services.google_drive -- Google auth redirect url not found. 
API disabled.") + + +def get_flow() -> InstalledAppFlow: + """Get Google Drive flow.""" + flow = InstalledAppFlow.from_client_secrets_file( + GOOGLE_APPLICATION_CREDS_PATH, SCOPES + ) + flow.redirect_uri = FLOW_REDIRECT_URI + return flow + + +def get_credentials(creds: Optional[dict]) -> CREDENTIALS: + """Get credentials from user info.""" + _creds = Credentials.from_authorized_user_info(creds, SCOPES) + if _creds.expired and _creds.refresh_token: + _creds.refresh(Request()) + return _creds + + +def refresh_credentials(creds: CREDENTIALS) -> CREDENTIALS: + """Refresh credentials.""" + creds.refresh(Request()) + return creds + + +def validate_credentials(creds: Optional[str]) -> Optional[dict]: + """Validate credentials.""" + if not creds: + return None + try: + _creds = get_credentials(json.loads(creds)) + if _creds.valid: + return json.loads(_creds.to_json()) + return None + except Exception as e: + log.error("Failed to validate credentials: %s", e) + return None + + +def get_gdrive_authorized_email(creds: CREDENTIALS) -> str: + """Get user email.""" + service = build('oauth2', 'v2', credentials=creds) + return service.userinfo().get().execute()['email'] + + +def get_auth_url_params(email: Optional[str] = None, state: Optional[str] = None) -> dict: + """Get authorization url params.""" + authorization_params = { + "access_type": "offline", + "prompt": "consent", + "state": state if state else "", + } + if email: + authorization_params["login_hint"] = email + return authorization_params + + +def list_folders(creds: dict) -> list[dict]: + """List folders.""" + _creds = get_credentials(creds) + drive = build('drive', 'v3', credentials=_creds) + folders = drive.files().list( + q="mimeType='application/vnd.google-apps.folder'", + fields="files(id, name, parents, mimeType, modifiedTime)", + ).execute() + return folders.get('files', []) + + +def get_drive_service(creds: dict | str) -> Any: + """Get drive service.""" + _cred_dict = {} + _cred_dict = json.loads(creds) if 
isinstance(creds, str) else creds + _creds = get_credentials(_cred_dict) + return build('drive', 'v3', credentials=_creds) + + +def _export_gdrive_docs(service: Any, file_id: str) -> Any: + """Export google docs.""" + return service.files().export(fileId=file_id, mimeType="application/pdf") + + +def download_file(service: Any, file_id: str, file_name: str, mime: str) -> bool: + """Download file.""" + try: + if "google-apps" in mime: + request = _export_gdrive_docs(service, file_id) + file_name = f"{file_name}.pdf" + else: + request = service.files().get_media(fileId=file_id) + with open(file_name, "wb") as fh: + downloader, done = MediaIoBaseDownload(fh, request), False + while done is False: + status, done = downloader.next_chunk() + log.debug("Download - %s", f"{file_name}: {str(status.progress() * 100)}%") + return True + except Exception as e: + log.error("Failed to download file: %s", e) + return False + + +def get_auth_url(data: dict) -> Optional[dict]: + """Get auth url for google drive api.""" + try: + flow = get_flow() + code = data.get("code", None) + if code is not None: + flow.fetch_token(code=code) + creds = flow.credentials + return {"credential": creds.to_json()} + else: + email = data.get("email", None) + state = data.get("state", None) + authorization_params = get_auth_url_params(email, state) + authorization_url, state = flow.authorization_url( + **authorization_params, + ) + return {"auth_url": authorization_url } + except Exception as e: + log.error("Failed to get auth url: %s", e) + return None + + +def api_enabled() -> bool: + """Check if google drive API is enabled.""" + credentials_path = os.environ.get(CREDENTIALS_KEY) + return credentials_path is not None and os.path.isfile( + credentials_path) and all([ + credentials_path, + os.environ.get(REDIRECT_URL_KEY) + ]) diff --git a/source/docq/services/smtp_service.py b/source/docq/services/smtp_service.py index 361a5c27..637e4083 100644 --- a/source/docq/services/smtp_service.py +++ 
b/source/docq/services/smtp_service.py @@ -8,11 +8,12 @@ from email.mime.application import MIMEApplication from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText +from typing import Optional from urllib.parse import quote_plus SENDER_EMAIL_KEY = "DOCQ_SMTP_LOGIN" SMTP_PORT_KEY = "DOCQ_SMTP_PORT" -SMTP_PASSWORD_KEY = "DOCQ_SMTP_KEY" +SMTP_LOGIN_KEY = "DOCQ_SMTP_KEY" SMTP_SERVER_KEY = "DOCQ_SMTP_SERVER" SERVER_ADDRESS_KEY = "DOCQ_SERVER_ADDRESS" SMTP_SENDER_EMAIL_KEY = "DOCQ_SMTP_FROM" @@ -23,9 +24,9 @@ """ -def _get_verification_email_template(**kwargs: dict) -> str: +def _get_verification_email_template(**kwargs: str) -> str: """Get email template.""" - template = VERIFICATION_EMAIL_TEMPLATE + template:str = VERIFICATION_EMAIL_TEMPLATE for key, value in kwargs.items(): template = template.replace("{{ " + key + " }}", value) return template @@ -40,7 +41,7 @@ def _send_email( smtp_port: int, username: str, password: str, - attachments: list[str] = None, + attachments: Optional[list[str]] = None, ) -> None: """Send an email.""" try: @@ -78,11 +79,11 @@ def _generate_verification_url(user_id: int) -> str: def send_verification_email(reciever_email: str, name: str, user_id: int) -> None: """Send verification email.""" - username = os.environ.get(SENDER_EMAIL_KEY) - smtp_port = os.environ.get(SMTP_PORT_KEY) - smtp_password = os.environ.get(SMTP_PASSWORD_KEY) - smtp_server = os.environ.get(SMTP_SERVER_KEY) - sender_email = os.environ.get(SMTP_SENDER_EMAIL_KEY) + username: str = os.environ.get(SENDER_EMAIL_KEY, "") + smtp_port: int = int(os.environ.get(SMTP_PORT_KEY, 0)) + smtp_password: str = os.environ.get(SMTP_LOGIN_KEY, "") + smtp_server: str = os.environ.get(SMTP_SERVER_KEY, "") + sender_email: str = os.environ.get(SMTP_SENDER_EMAIL_KEY, "") subject = "Docq.AI Sign-up - Email Verification" message = _get_verification_email_template( @@ -109,7 +110,7 @@ def mailer_ready() -> bool: [ os.environ.get(SENDER_EMAIL_KEY), 
os.environ.get(SMTP_PORT_KEY), - os.environ.get(SMTP_PASSWORD_KEY), + os.environ.get(SMTP_LOGIN_KEY), os.environ.get(SMTP_SERVER_KEY), os.environ.get(SERVER_ADDRESS_KEY), os.environ.get(SMTP_SENDER_EMAIL_KEY), diff --git a/source/docq/setup.py b/source/docq/setup.py index 4b5ef0ef..f985576c 100644 --- a/source/docq/setup.py +++ b/source/docq/setup.py @@ -13,6 +13,7 @@ manage_spaces, manage_user_groups, manage_users, + services, ) from .support import auth_utils, llm, metadata_extractors, store @@ -34,6 +35,7 @@ def init() -> None: manage_settings._init() manage_spaces._init() manage_users._init() + services._init() store._init() manage_organisations._init_default_org_if_necessary() manage_users._init_admin_if_necessary() diff --git a/web/utils/handlers.py b/web/utils/handlers.py index c7d460cd..281743c2 100644 --- a/web/utils/handlers.py +++ b/web/utils/handlers.py @@ -5,6 +5,7 @@ import logging as log import math import random +import re from datetime import datetime from typing import Any, List, Optional, Tuple from urllib.parse import unquote_plus @@ -30,6 +31,7 @@ from docq.services.smtp_service import mailer_ready, send_verification_email from docq.support.auth_utils import reset_cache_and_cookie_auth_session from opentelemetry import baggage, trace +from streamlit.components.v1 import html from .constants import ( MAX_NUMBER_OF_PERSONAL_DOCS, @@ -650,6 +652,7 @@ def _prepare_space_data_source(prefix: str) -> Tuple[str, dict]: def handle_update_space_details(id_: int) -> bool: + """Update space details.""" ds_type, ds_configs = _prepare_space_data_source(f"update_space_details_{id_}_") org_id = get_selected_org_id() result = manage_spaces.update_shared_space( @@ -697,9 +700,10 @@ def get_space_data_source(space: SpaceKey) -> Tuple[str, dict]: def list_space_data_source_choices() -> List[Tuple[str, str, List[domain.ConfigKey]]]: + """List space data source choices.""" return [ (key, value.value.get_name(), value.value.get_config_keys()) - for key, value in 
SpaceDataSources.__members__.items() + for key, value in SpaceDataSources.__members__.items() if not value.value.disabled ] @@ -875,6 +879,14 @@ def get_query_param(key: str, type_: type = int) -> Any | None: return None +def handle_check_str_is_email(str_: Optional[str]) -> bool: + """Check if a string is a valid email.""" + if str_ is None: + return False + email_regex = r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$" + return bool(re.match(email_regex, str_)) + + def handle_public_session() -> None: """Handle public session.""" session_id = get_query_param("session_id", str) @@ -902,3 +914,31 @@ def handle_public_session() -> None: space_group_id=-1, public_session_id=-1, ) + + +def handle_get_user_email() -> Optional[str]: + """Handle get username and check if it is an email. + + Returns: + Optional[str]: The username if it is an email, otherwise None. + """ + _email = get_username() + if handle_check_str_is_email(_email): + return _email + return None + + +def handle_redirect_to_url(url: str, key: str) -> None: + """Redirect to url.""" + html(f""" + + """, height=0 + ) diff --git a/web/utils/layout.py b/web/utils/layout.py index da9fd329..7f718b2c 100644 --- a/web/utils/layout.py +++ b/web/utils/layout.py @@ -1,8 +1,10 @@ """Layout components for the web app.""" - +import base64 import logging as log +import random import re -from typing import List, Tuple +from typing import Callable, List, Optional, Tuple +from urllib.parse import quote_plus, unquote_plus import docq import streamlit as st @@ -16,7 +18,7 @@ SystemFeatureType, SystemSettingsKey, ) -from docq.domain import DocumentListItem, FeatureKey, SpaceKey +from docq.domain import ConfigKey, DocumentListItem, FeatureKey, SpaceKey from docq.extensions import ExtensionContext from docq.model_selection.main import ( ModelUsageSettingsCollection, @@ -64,6 +66,7 @@ handle_fire_extensions_callbacks, handle_get_gravatar_url, handle_get_system_settings, + handle_get_user_email, handle_list_documents, 
handle_list_orgs, handle_login, @@ -71,6 +74,7 @@ handle_manage_space_permissions, handle_org_selection_change, handle_public_session, + handle_redirect_to_url, handle_reindex_space, handle_resend_email_verification, handle_update_org, @@ -845,34 +849,199 @@ def organisation_settings_ui() -> None: st.divider() -def _render_space_data_source_config_input_fields(data_source: Tuple, prefix: str, configs: dict = None) -> None: - for key in data_source[2]: - input_type = "password" if key.is_secret else "default" - st.text_input( - f"{key.name}{'' if key.is_optional else ' *'}", - value=configs.get(key.key) if configs else "", - key=prefix + "ds_config_" + key.key, - type=input_type, - help=key.ref_link, - autocomplete="off", # disable autofill by password manager etc. +def _get_create_space_config_input_values() -> str: + """Get values for space creation from session state.""" + space_name = st.session_state.get("create_space_name", "") + space_summary = st.session_state.get("create_space_summary", "") + ds_type = st.session_state.get("create_space_ds_type", ["", ""]) + space_configs = f"{space_name}::{space_summary}::{ds_type[1]}" + return quote_plus(base64.b64encode(space_configs.encode())) + + +def _get_random_key(prefix: str) -> str: + return prefix + str(random.randint(0, 1000000)) # noqa E501 + + +def _get_credential_request_params() -> dict: + return { + "code": st.experimental_get_query_params().get("code", [None])[0], + "email": handle_get_user_email(), + "state": _get_create_space_config_input_values() + } + + +def _render_file_storage_credential_request(configkey: ConfigKey, key: str, configs: Optional[dict]) -> None: + """Renders the credential request input field with an action button.""" + saved_credentials = configs.get(configkey.key) if configs else st.session_state.get(key, None) + global opacity + opacity = 0.4 if saved_credentials else 1.0 + st.markdown(f""" + + """, unsafe_allow_html=True) + st.write(f"{configkey.name}{'' if configkey.is_optional 
else ' *'}") + text_box, btn = st.columns([3, 1]) + + new_credentials, auth_url = None, None + + if saved_credentials is None: + params = _get_credential_request_params() + handler = configkey.options.get("handler", None) if configkey.options else None + response = handler(params) if handler else {} + new_credentials = response.get("credential") if response else None + auth_url = response.get("auth_url") if response else None + + opacity = 0.4 if bool(new_credentials or saved_credentials) else 1.0 + + text_box.text_input( + configkey.name, + value='*' * 64 if bool(new_credentials or saved_credentials) else "", + key=_get_random_key("_input_key"), + disabled=True, label_visibility="collapsed" + ) + + btn.button( + configkey.options.get("btn_label", "Get Credential") if configkey.options else "Get Credentials", + disabled=bool(saved_credentials or new_credentials), + key=_get_random_key("_btn_key"), + on_click=lambda: handle_redirect_to_url(auth_url, "gdrive") if auth_url else None, + ) + + if not bool(saved_credentials or new_credentials): + st.stop() + + if new_credentials is not None: + st.session_state[key] = new_credentials or saved_credentials + st.session_state[configkey.key] = new_credentials or saved_credentials + st.experimental_set_query_params() + + +def fetch_file_storage_root_folders(_configkey: ConfigKey, configs: Optional[dict]) -> tuple[list[dict], bool]: + """List File Storage System root foldesrs.""" + saved_settings = configs.get(_configkey.key) if configs else None + options = _configkey.options if _configkey.options else {} + if saved_settings is not None: + return [saved_settings], True + + else: + with st.spinner("Loading Options..."): + handler: Callable = options.get("handler", None) + if handler: + return handler(configs, st.session_state) + return [], False + + +def _set_options(configkey: ConfigKey, key: str, configs: Optional[dict]) -> None: + st.session_state[key] = ( + *fetch_file_storage_root_folders(configkey, configs), + 
configs.get(configkey.key) if configs else None + ) + + +def _render_file_storage_root_path_options(configkey: ConfigKey, key: str, configs: Optional[dict]) -> None: + """Renders the dynamic options for a config key.""" + temp_key = f"{key}_temp" + + if temp_key not in st.session_state: + _set_options(configkey, temp_key, configs) + + options, disabled, selected = st.session_state[temp_key] + if not options: + _set_options(configkey, temp_key, configs) + options, disabled, selected = st.session_state[temp_key] + + fmt_func = configkey.options.get("format_function", None) if configkey.options else None + + st.selectbox( + f"{configkey.name}{'' if configkey.is_optional else ' *'}", + options=options, + key=key, + format_func= fmt_func if fmt_func else lambda x: x, + help=configkey.ref_link, + disabled=disabled, + index=options.index(selected) if bool(selected and options) else 0, + ) + + +def _handle_custom_input_field(configkey: ConfigKey, key: str, configs: Optional[dict]) -> None: + """Handle Ui interactions.""" + if configkey.options and configkey.options.get("type") == "credential": + _render_file_storage_credential_request( + configkey, + key, + configs + ) + elif configkey.options and configkey.options.get("type") == "root_path": + _render_file_storage_root_path_options( + configkey, + key, + configs ) + else: + log.error("Unknown custom input field type: %s", str(configkey.options)) + + +def _render_space_data_source_config_input_fields(data_source: Tuple, prefix: str, configs: Optional[dict] = None) -> None: + config_key_list: List[ConfigKey] = data_source[2] + + for configkey in config_key_list: + _input_key = prefix + "ds_config_" + configkey.key + if configkey.options and configkey.options.get("type", None): + _handle_custom_input_field(configkey, _input_key, configs) + + else: + input_type = "password" if configkey.is_secret else "default" + st.text_input( + f"{configkey.name}{'' if configkey.is_optional else ' *'}", + value=configs.get(configkey.key) if 
configs else "", + key=_input_key, + type=input_type, + help=configkey.ref_link, + autocomplete="off", # disable autofill by password manager etc. + ) + + +def _get_create_space_form_values() -> Tuple[str, str, str]: + """Get default values for space creation from query string.""" + space_config = st.experimental_get_query_params().get("state", [None])[0] + if space_config: + try: + space_config = unquote_plus(space_config) + space_name, space_summary, ds_type = base64.b64decode(space_config).decode("utf-8").split("::") + st.session_state["create_space_defaults"] = (space_name, space_summary, ds_type) + except Exception as e: + st.session_state["create_space_defaults"] = ("", "", "") + log.error("Error parsing space config from query string: %s", e) + return st.session_state.get("create_space_defaults", ("", "", "")) def create_space_ui(expanded: bool = False) -> None: """Create a new space.""" data_sources = list_space_data_source_choices() - with st.expander("### + New Space", expanded=expanded): - st.text_input("Name", value="", key="create_space_name") - st.text_input("Summary", value="", key="create_space_summary") + space_name, space_summary, ds_type = _get_create_space_form_values() + _prefill_form = bool(space_name or space_summary or ds_type) + with st.expander("### + New Space", expanded=expanded or _prefill_form): + st.text_input("Name", value=space_name if space_name else "", key="create_space_name") + st.text_input("Summary", value=space_summary if space_summary else "", key="create_space_summary") ds = st.selectbox( "Data Source", options=data_sources, key="create_space_ds_type", format_func=lambda x: x[1], + index=[x[1] for x in data_sources].index(ds_type) if ds_type else 0, ) if ds: _render_space_data_source_config_input_fields(ds, "create_space_") if st.button("Create Space"): + st.session_state["create_space_defaults"] = ("", "", "") handle_create_space() @@ -900,7 +1069,7 @@ def _render_edit_space_details_form(space_data: Tuple, data_source: 
Tuple) -> No has_edit_perm = org_id == get_selected_org_id() if has_edit_perm: - with st.form(key=f"update_space_details_{id_}"): + with st.expander("Edit space", expanded=True): st.text_input("Name", value=name, key=f"update_space_details_{id_}_name") st.text_input("Summary", value=summary, key=f"update_space_details_{id_}_summary") st.checkbox("Is Archived", value=archived, key=f"update_space_details_{id_}_archived") @@ -913,7 +1082,8 @@ def _render_edit_space_details_form(space_data: Tuple, data_source: Tuple) -> No format_func=lambda x: x[1], ) _render_space_data_source_config_input_fields(data_source, f"update_space_details_{id_}_", ds_configs) - st.form_submit_button("Save", on_click=handle_update_space_details, args=(id_,)) + if st.button("Save", key=_get_random_key("_save_btn_key")): + handle_update_space_details(id_) def _render_edit_space_details(space_data: Tuple, data_source: Tuple) -> None: @@ -930,7 +1100,7 @@ def _render_manage_space_permissions_form(space_data: Tuple) -> None: if has_edit_perm: permissions = get_shared_space_permissions(id_) - with st.form(key=f"manage_space_permissions_{id_}"): + with st.expander("Manage Space Permissions", expanded=True): st.checkbox( "Public Access", value=permissions[SpaceAccessType.PUBLIC], @@ -950,7 +1120,8 @@ def _render_manage_space_permissions_form(space_data: Tuple) -> None: key=f"manage_space_permissions_{id_}_{SpaceAccessType.GROUP.name}", format_func=lambda x: x[1], ) - st.form_submit_button("Save", on_click=handle_manage_space_permissions, args=(id_,)) + if st.button("Save"): + handle_manage_space_permissions(id_) def _render_manage_space_permissions(space_data: Tuple) -> None: @@ -1014,7 +1185,7 @@ def _editor_view(q_param: str) -> None: _render_manage_space_permissions_form(s) -def admin_docs_ui(q_param: str = None) -> None: +def admin_docs_ui(q_param: Optional[str] = None) -> None: """Manage Documents UI.""" spaces = list_shared_spaces() if spaces: @@ -1038,9 +1209,10 @@ def admin_docs_ui(q_param: 
str = None) -> None: index=default_sid if default_sid else 0, ) - if selected: + if selected and q_param: st.experimental_set_query_params(**{q_param: selected[0]}) - _editor_view(q_param) + if q_param: + _editor_view(q_param) def org_selection_ui() -> None: @@ -1135,11 +1307,11 @@ def _validate_password(password: str, generator: DeltaGenerator) -> bool: return True -def validate_signup_form(form: str = "user-signup") -> None: +def validate_signup_form(form: str = "user-signup") ->bool: """Handle validation of the signup form.""" - name = st.session_state.get(f"{form}-name", None) - email = st.session_state.get(f"{form}-email", None) - password = st.session_state.get(f"{form}-password", None) + name: str = st.session_state.get(f"{form}-name", None) + email: str = st.session_state.get(f"{form}-email", None) + password: str = st.session_state.get(f"{form}-password", None) validator = st.session_state[f"{form}-validator"] if not _validate_name(name, validator):