Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Engine registration and custom options #170

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 94 additions & 106 deletions dj_database_url.py
Original file line number Diff line number Diff line change
@@ -1,92 +1,101 @@
# -*- coding: utf-8 -*-

import collections
import os
import urllib.parse as urlparse
import warnings

try:
from django import VERSION as DJANGO_VERSION
except ImportError:
DJANGO_VERSION = None


# Register database schemes in URLs.
urlparse.uses_netloc.append("postgres")
urlparse.uses_netloc.append("postgresql")
urlparse.uses_netloc.append("pgsql")
urlparse.uses_netloc.append("postgis")
urlparse.uses_netloc.append("mysql")
urlparse.uses_netloc.append("mysql2")
urlparse.uses_netloc.append("mysqlgis")
urlparse.uses_netloc.append("mysql-connector")
urlparse.uses_netloc.append("mssql")
urlparse.uses_netloc.append("mssqlms")
urlparse.uses_netloc.append("spatialite")
urlparse.uses_netloc.append("sqlite")
urlparse.uses_netloc.append("oracle")
urlparse.uses_netloc.append("oraclegis")
urlparse.uses_netloc.append("redshift")
urlparse.uses_netloc.append("cockroach")
Engine = collections.namedtuple("Engine", ["backend", "string_ports", "options"])

DEFAULT_ENV = "DATABASE_URL"

SCHEMES = {
"postgis": "django.contrib.gis.db.backends.postgis",
"mysql": "django.db.backends.mysql",
"mysql2": "django.db.backends.mysql",
"mysqlgis": "django.contrib.gis.db.backends.mysql",
"mysql-connector": "mysql.connector.django",
"mssql": "sql_server.pyodbc",
"mssqlms": "mssql",
"spatialite": "django.contrib.gis.db.backends.spatialite",
"sqlite": "django.db.backends.sqlite3",
"oracle": "django.db.backends.oracle",
"oraclegis": "django.contrib.gis.db.backends.oracle",
"redshift": "django_redshift_backend",
"cockroach": "django_cockroachdb",
}

# https://docs.djangoproject.com/en/2.0/releases/2.0/#id1
if DJANGO_VERSION and DJANGO_VERSION < (2, 0):
SCHEMES["postgres"] = "django.db.backends.postgresql_psycopg2"
SCHEMES["postgresql"] = "django.db.backends.postgresql_psycopg2"
SCHEMES["pgsql"] = "django.db.backends.postgresql_psycopg2"
else:
SCHEMES["postgres"] = "django.db.backends.postgresql"
SCHEMES["postgresql"] = "django.db.backends.postgresql"
SCHEMES["pgsql"] = "django.db.backends.postgresql"


def config(
env=DEFAULT_ENV, default=None, engine=None, conn_max_age=0, ssl_require=False
):
ENGINE_SCHEMES = {}


def register(backend, schemes=None, string_ports=False, options=None):
if schemes is None:
schemes = [backend.rsplit(".")[-1]]
elif isinstance(schemes, str):
schemes = [schemes]

for scheme in schemes:
urlparse.uses_netloc.append(scheme)
ENGINE_SCHEMES[scheme] = Engine(backend, string_ports, options or {})


# Support all the first-party Django engines out of the box.
register(
"django.db.backends.postgresql",
("postgres", "postgresql", "pgsql"),
options={
"currentSchema": lambda values: {
"options": "-c search_path={}".format(values[-1])
},
},
)
register(
"django.contrib.gis.db.backends.postgis",
options={
"currentSchema": lambda values: {
"options": "-c search_path={}".format(values[-1])
},
},
)
register("django.contrib.gis.db.backends.spatialite")
register(
"django.db.backends.mysql",
options={
"ssl-ca": lambda values: {"ssl": {"ca": values[-1]}},
},
)
register("django.contrib.gis.db.backends.mysql", "mysqlgis")
register("django.db.backends.oracle", string_ports=True)
register("django.contrib.gis.db.backends.oracle", "oraclegis")
register("django.db.backends.sqlite3", "sqlite")


def config(env=DEFAULT_ENV, default=None, **settings):
"""Returns configured DATABASE dictionary from DATABASE_URL."""
s = os.environ.get(env, default)

if s:
return parse(s, engine, conn_max_age, ssl_require)

return {}
s = os.environ.get(env, default)
return parse(s, **settings) if s else {}


def parse(url, engine=None, conn_max_age=0, ssl_require=False):
def parse(url, backend=None, **settings):
"""Parses a database URL."""

if url == "sqlite://:memory:":
# this is a special case, because if we pass this URL into
# urlparse, urlparse will choke trying to interpret "memory"
# as a port number
return {"ENGINE": SCHEMES["sqlite"], "NAME": ":memory:"}
return {"ENGINE": ENGINE_SCHEMES["sqlite"].backend, "NAME": ":memory:"}
# note: no other settings are required for sqlite

# otherwise parse the url as normal
parsed_config = {}

url = urlparse.urlparse(url)
engine = ENGINE_SCHEMES[url.scheme]
options = {}

if "engine" in settings:
# Keep compatibility with dj-database-url for `engine` kwarg.
backend = settings.pop("engine")

if "conn_max_age" in settings:
warnings.warn(
"The `conn_max_age` argument is deprecated. Use `CONN_MAX_AGE` instead."
)
settings["CONN_MAX_AGE"] = settings.pop("conn_max_age")

if "ssl_require" in settings:
warnings.warn(
"The `ssl_require` argument is deprecated. "
"Use `OPTIONS={'sslmode': 'require'}` instead."
)
if settings.pop("ssl_require"):
options["sslmode"] = "require"

# Split query strings from path.
path = url.path[1:]
if "?" in path and not url.query:
path, query = path.split("?", 2)
path, query = path.split("?", 1)
else:
path, query = path, url.query
query = urlparse.parse_qs(query)
Expand All @@ -107,53 +116,32 @@ def parse(url, engine=None, conn_max_age=0, ssl_require=False):
hostname = hostname.split(":", 1)[0]
hostname = hostname.replace("%2f", "/").replace("%2F", "/")

# Lookup specified engine.
engine = SCHEMES[url.scheme] if engine is None else engine

port = (
str(url.port)
if url.port
and engine in [SCHEMES["oracle"], SCHEMES["mssql"], SCHEMES["mssqlms"]]
else url.port
)

# Update with environment configuration.
parsed_config.update(
{
"NAME": urlparse.unquote(path or ""),
"USER": urlparse.unquote(url.username or ""),
"PASSWORD": urlparse.unquote(url.password or ""),
"HOST": hostname,
"PORT": port or "",
"CONN_MAX_AGE": conn_max_age,
}
)
port = str(url.port) if url.port and engine.string_ports else url.port

# Pass the query string into OPTIONS.
options = {}
for key, values in query.items():
if url.scheme == "mysql" and key == "ssl-ca":
options["ssl"] = {"ca": values[-1]}
continue

options[key] = values[-1]
if key in engine.options:
options.update(engine.options[key](values))
else:
options[key] = values[-1]

if ssl_require:
options["sslmode"] = "require"
# Allow passed OPTIONS to override query string options.
options.update(settings.pop("OPTIONS", {}))

# Support for Postgres Schema URLs
if "currentSchema" in options and engine in (
"django.contrib.gis.db.backends.postgis",
"django.db.backends.postgresql_psycopg2",
"django.db.backends.postgresql",
"django_redshift_backend",
):
options["options"] = "-c search_path={0}".format(options.pop("currentSchema"))
# Update with environment configuration.
config = {
"ENGINE": backend or engine.backend,
"NAME": urlparse.unquote(path or ""),
"USER": urlparse.unquote(url.username or ""),
"PASSWORD": urlparse.unquote(url.password or ""),
"HOST": hostname,
"PORT": port or "",
}

if options:
parsed_config["OPTIONS"] = options
config["OPTIONS"] = options

if engine:
parsed_config["ENGINE"] = engine
# Update the final config with any settings passed in explicitly.
config.update(**settings)

return parsed_config
return config
28 changes: 21 additions & 7 deletions test_dj_database_url.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,28 @@
import os
import unittest

try:
from django import VERSION as DJANGO_VERSION
except ImportError:
DJANGO_VERSION = None

import dj_database_url

dj_database_url.register("mysql.connector.django", "mysql-connector")
dj_database_url.register("sql_server.pyodbc", "mssql", string_ports=True)
dj_database_url.register("mssql", "mssqlms")
dj_database_url.register(
"django_redshift_backend",
"redshift",
options={
"currentSchema": lambda values: {
"options": "-c search_path={}".format(values[-1])
},
},
)
dj_database_url.register("django_cockroachdb", "cockroach")


POSTGIS_URL = "postgis://uf07k1i6d8ia0v:[email protected]:5431/d8r82722r2kuvn"

# Django deprecated the `django.db.backends.postgresql_psycopg2` in 2.0.
# https://docs.djangoproject.com/en/2.0/releases/2.0/#id1
EXPECTED_POSTGRES_ENGINE = "django.db.backends.postgresql"
if DJANGO_VERSION and DJANGO_VERSION < (2, 0):
EXPECTED_POSTGRES_ENGINE = "django.db.backends.postgresql_psycopg2"


class DatabaseTestSuite(unittest.TestCase):
Expand Down Expand Up @@ -371,6 +379,12 @@ def test_mssqlms_parsing(self):
assert url["OPTIONS"]["driver"] == "ODBC Driver 13 for SQL Server"
assert "currentSchema" not in url["OPTIONS"]

def test_database_options(self):
url = "postgres://user:pass@host/db"
url = dj_database_url.parse(url, ATOMIC_REQUESTS=True, TEST={"NAME": "testdb"})
assert url["ATOMIC_REQUESTS"] is True
assert url["TEST"] == {"NAME": "testdb"}


if __name__ == "__main__":
unittest.main()