diff --git a/.vscode/settings.json b/.vscode/settings.json index e657cfe..010ab31 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -2,8 +2,7 @@ "python.pythonPath": "C:\\Python27\\python.exe", "python.formatting.provider": "autopep8", "python.linting.pylintEnabled": false, - "python.linting.flake8Enabled": true, - "python.linting.enabled": true, + "python.linting.enabled": false, "cSpell.words": [ "gocommandtext", "updatetmphk", diff --git a/houdini/python27/dlls/sqlite3.def b/houdini/python27/dlls/sqlite3.def new file mode 100644 index 0000000..cb082c7 --- /dev/null +++ b/houdini/python27/dlls/sqlite3.def @@ -0,0 +1,276 @@ +EXPORTS +sqlite3_aggregate_context +sqlite3_aggregate_count +sqlite3_auto_extension +sqlite3_backup_finish +sqlite3_backup_init +sqlite3_backup_pagecount +sqlite3_backup_remaining +sqlite3_backup_step +sqlite3_bind_blob +sqlite3_bind_blob64 +sqlite3_bind_double +sqlite3_bind_int +sqlite3_bind_int64 +sqlite3_bind_null +sqlite3_bind_parameter_count +sqlite3_bind_parameter_index +sqlite3_bind_parameter_name +sqlite3_bind_pointer +sqlite3_bind_text +sqlite3_bind_text16 +sqlite3_bind_text64 +sqlite3_bind_value +sqlite3_bind_zeroblob +sqlite3_bind_zeroblob64 +sqlite3_blob_bytes +sqlite3_blob_close +sqlite3_blob_open +sqlite3_blob_read +sqlite3_blob_reopen +sqlite3_blob_write +sqlite3_busy_handler +sqlite3_busy_timeout +sqlite3_cancel_auto_extension +sqlite3_changes +sqlite3_clear_bindings +sqlite3_close +sqlite3_close_v2 +sqlite3_collation_needed +sqlite3_collation_needed16 +sqlite3_column_blob +sqlite3_column_bytes +sqlite3_column_bytes16 +sqlite3_column_count +sqlite3_column_database_name +sqlite3_column_database_name16 +sqlite3_column_decltype +sqlite3_column_decltype16 +sqlite3_column_double +sqlite3_column_int +sqlite3_column_int64 +sqlite3_column_name +sqlite3_column_name16 +sqlite3_column_origin_name +sqlite3_column_origin_name16 +sqlite3_column_table_name +sqlite3_column_table_name16 +sqlite3_column_text 
+sqlite3_column_text16 +sqlite3_column_type +sqlite3_column_value +sqlite3_commit_hook +sqlite3_compileoption_get +sqlite3_compileoption_used +sqlite3_complete +sqlite3_complete16 +sqlite3_config +sqlite3_context_db_handle +sqlite3_create_collation +sqlite3_create_collation_v2 +sqlite3_create_collation16 +sqlite3_create_function +sqlite3_create_function_v2 +sqlite3_create_function16 +sqlite3_create_module +sqlite3_create_module_v2 +sqlite3_create_window_function +sqlite3_data_count +sqlite3_data_directory +sqlite3_db_cacheflush +sqlite3_db_config +sqlite3_db_filename +sqlite3_db_handle +sqlite3_db_mutex +sqlite3_db_readonly +sqlite3_db_release_memory +sqlite3_db_status +sqlite3_declare_vtab +sqlite3_deserialize +sqlite3_drop_modules +sqlite3_enable_load_extension +sqlite3_enable_shared_cache +sqlite3_errcode +sqlite3_errmsg +sqlite3_errmsg16 +sqlite3_errstr +sqlite3_exec +sqlite3_expanded_sql +sqlite3_expired +sqlite3_extended_errcode +sqlite3_extended_result_codes +sqlite3_file_control +sqlite3_filename_database +sqlite3_filename_journal +sqlite3_filename_wal +sqlite3_finalize +sqlite3_free +sqlite3_free_table +sqlite3_fts3_may_be_corrupt +sqlite3_fts5_may_be_corrupt +sqlite3_get_autocommit +sqlite3_get_auxdata +sqlite3_get_table +sqlite3_global_recover +sqlite3_hard_heap_limit64 +sqlite3_initialize +sqlite3_interrupt +sqlite3_keyword_check +sqlite3_keyword_count +sqlite3_keyword_name +sqlite3_last_insert_rowid +sqlite3_libversion +sqlite3_libversion_number +sqlite3_limit +sqlite3_load_extension +sqlite3_log +sqlite3_malloc +sqlite3_malloc64 +sqlite3_memory_alarm +sqlite3_memory_highwater +sqlite3_memory_used +sqlite3_mprintf +sqlite3_msize +sqlite3_mutex_alloc +sqlite3_mutex_enter +sqlite3_mutex_free +sqlite3_mutex_leave +sqlite3_mutex_try +sqlite3_next_stmt +sqlite3_open +sqlite3_open_v2 +sqlite3_open16 +sqlite3_os_end +sqlite3_os_init +sqlite3_overload_function +sqlite3_prepare +sqlite3_prepare_v2 +sqlite3_prepare_v3 +sqlite3_prepare16 +sqlite3_prepare16_v2 
+sqlite3_prepare16_v3 +sqlite3_profile +sqlite3_progress_handler +sqlite3_randomness +sqlite3_realloc +sqlite3_realloc64 +sqlite3_release_memory +sqlite3_reset +sqlite3_reset_auto_extension +sqlite3_result_blob +sqlite3_result_blob64 +sqlite3_result_double +sqlite3_result_error +sqlite3_result_error_code +sqlite3_result_error_nomem +sqlite3_result_error_toobig +sqlite3_result_error16 +sqlite3_result_int +sqlite3_result_int64 +sqlite3_result_null +sqlite3_result_pointer +sqlite3_result_subtype +sqlite3_result_text +sqlite3_result_text16 +sqlite3_result_text16be +sqlite3_result_text16le +sqlite3_result_text64 +sqlite3_result_value +sqlite3_result_zeroblob +sqlite3_result_zeroblob64 +sqlite3_rollback_hook +sqlite3_rtree_geometry_callback +sqlite3_rtree_query_callback +sqlite3_serialize +sqlite3_set_authorizer +sqlite3_set_auxdata +sqlite3_set_last_insert_rowid +sqlite3_shutdown +sqlite3_sleep +sqlite3_snprintf +sqlite3_soft_heap_limit +sqlite3_soft_heap_limit64 +sqlite3_sourceid +sqlite3_sql +sqlite3_status +sqlite3_status64 +sqlite3_step +sqlite3_stmt_busy +sqlite3_stmt_isexplain +sqlite3_stmt_readonly +sqlite3_stmt_status +sqlite3_str_append +sqlite3_str_appendall +sqlite3_str_appendchar +sqlite3_str_appendf +sqlite3_str_errcode +sqlite3_str_finish +sqlite3_str_length +sqlite3_str_new +sqlite3_str_reset +sqlite3_str_value +sqlite3_str_vappendf +sqlite3_strglob +sqlite3_stricmp +sqlite3_strlike +sqlite3_strnicmp +sqlite3_system_errno +sqlite3_table_column_metadata +sqlite3_temp_directory +sqlite3_test_control +sqlite3_thread_cleanup +sqlite3_threadsafe +sqlite3_total_changes +sqlite3_trace +sqlite3_trace_v2 +sqlite3_transfer_bindings +sqlite3_update_hook +sqlite3_uri_boolean +sqlite3_uri_int64 +sqlite3_uri_key +sqlite3_uri_parameter +sqlite3_user_data +sqlite3_value_blob +sqlite3_value_bytes +sqlite3_value_bytes16 +sqlite3_value_double +sqlite3_value_dup +sqlite3_value_free +sqlite3_value_frombind +sqlite3_value_int +sqlite3_value_int64 +sqlite3_value_nochange 
+sqlite3_value_numeric_type +sqlite3_value_pointer +sqlite3_value_subtype +sqlite3_value_text +sqlite3_value_text16 +sqlite3_value_text16be +sqlite3_value_text16le +sqlite3_value_type +sqlite3_version +sqlite3_vfs_find +sqlite3_vfs_register +sqlite3_vfs_unregister +sqlite3_vmprintf +sqlite3_vsnprintf +sqlite3_vtab_collation +sqlite3_vtab_config +sqlite3_vtab_nochange +sqlite3_vtab_on_conflict +sqlite3_wal_autocheckpoint +sqlite3_wal_checkpoint +sqlite3_wal_checkpoint_v2 +sqlite3_wal_hook +sqlite3_win32_is_nt +sqlite3_win32_mbcs_to_utf8 +sqlite3_win32_mbcs_to_utf8_v2 +sqlite3_win32_set_directory +sqlite3_win32_set_directory16 +sqlite3_win32_set_directory8 +sqlite3_win32_sleep +sqlite3_win32_unicode_to_utf8 +sqlite3_win32_utf8_to_mbcs +sqlite3_win32_utf8_to_mbcs_v2 +sqlite3_win32_utf8_to_unicode +sqlite3_win32_write_debug diff --git a/python2.7libs/peewee-3.13.1.dist-info/INSTALLER b/python2.7libs/peewee-3.13.1.dist-info/INSTALLER new file mode 100644 index 0000000..a1b589e --- /dev/null +++ b/python2.7libs/peewee-3.13.1.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/python2.7libs/peewee-3.13.1.dist-info/LICENSE b/python2.7libs/peewee-3.13.1.dist-info/LICENSE new file mode 100644 index 0000000..c752ab3 --- /dev/null +++ b/python2.7libs/peewee-3.13.1.dist-info/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2010 Charles Leifer + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/python2.7libs/peewee-3.13.1.dist-info/METADATA b/python2.7libs/peewee-3.13.1.dist-info/METADATA new file mode 100644 index 0000000..689495a --- /dev/null +++ b/python2.7libs/peewee-3.13.1.dist-info/METADATA @@ -0,0 +1,160 @@ +Metadata-Version: 2.1 +Name: peewee +Version: 3.13.1 +Summary: a little orm +Home-page: https://github.com/coleifer/peewee/ +Author: Charles Leifer +Author-email: coleifer@gmail.com +License: MIT License +Platform: any +Classifier: Development Status :: 5 - Production/Stable +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: MIT License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 2 +Classifier: Programming Language :: Python :: 2.7 +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.4 +Classifier: Programming Language :: Python :: 3.5 +Classifier: Programming Language :: Python :: 3.6 +Classifier: Programming Language :: Python :: 3.7 +Classifier: Topic :: Software Development :: Libraries :: Python Modules + +.. image:: http://media.charlesleifer.com/blog/photos/peewee3-logo.png + +peewee +====== + +Peewee is a simple and small ORM. It has few (but expressive) concepts, making it easy to learn and intuitive to use. + +* a small, expressive ORM +* python 2.7+ and 3.4+ (developed with 3.6) +* supports sqlite, mysql, postgresql and cockroachdb +* tons of `extensions `_ + +.. 
image:: https://travis-ci.org/coleifer/peewee.svg?branch=master + :target: https://travis-ci.org/coleifer/peewee + +New to peewee? These may help: + +* `Quickstart `_ +* `Example twitter app `_ +* `Using peewee interactively `_ +* `Models and fields `_ +* `Querying `_ +* `Relationships and joins `_ + +Examples +-------- + +Defining models is similar to Django or SQLAlchemy: + +.. code-block:: python + + from peewee import * + import datetime + + + db = SqliteDatabase('my_database.db') + + class BaseModel(Model): + class Meta: + database = db + + class User(BaseModel): + username = CharField(unique=True) + + class Tweet(BaseModel): + user = ForeignKeyField(User, backref='tweets') + message = TextField() + created_date = DateTimeField(default=datetime.datetime.now) + is_published = BooleanField(default=True) + +Connect to the database and create tables: + +.. code-block:: python + + db.connect() + db.create_tables([User, Tweet]) + +Create a few rows: + +.. code-block:: python + + charlie = User.create(username='charlie') + huey = User(username='huey') + huey.save() + + # No need to set `is_published` or `created_date` since they + # will just use the default values we specified. + Tweet.create(user=charlie, message='My first tweet') + +Queries are expressive and composable: + +.. code-block:: python + + # A simple query selecting a user. + User.get(User.username == 'charlie') + + # Get tweets created by one of several users. + usernames = ['charlie', 'huey', 'mickey'] + users = User.select().where(User.username.in_(usernames)) + tweets = Tweet.select().where(Tweet.user.in_(users)) + + # We could accomplish the same using a JOIN: + tweets = (Tweet + .select() + .join(User) + .where(User.username.in_(usernames))) + + # How many tweets were published today? + tweets_today = (Tweet + .select() + .where( + (Tweet.created_date >= datetime.date.today()) & + (Tweet.is_published == True)) + .count()) + + # Paginate the user table and show me page 3 (users 41-60). 
+ User.select().order_by(User.username).paginate(3, 20) + + # Order users by the number of tweets they've created: + tweet_ct = fn.Count(Tweet.id) + users = (User + .select(User, tweet_ct.alias('ct')) + .join(Tweet, JOIN.LEFT_OUTER) + .group_by(User) + .order_by(tweet_ct.desc())) + + # Do an atomic update + Counter.update(count=Counter.count + 1).where(Counter.url == request.url) + +Check out the `example twitter app `_. + +Learning more +------------- + +Check the `documentation `_ for more examples. + +Specific question? Come hang out in the #peewee channel on irc.freenode.net, or post to the mailing list, http://groups.google.com/group/peewee-orm . If you would like to report a bug, `create a new issue `_ on GitHub. + +Still want more info? +--------------------- + +.. image:: http://media.charlesleifer.com/blog/photos/wat.jpg + +I've written a number of blog posts about building applications and web-services with peewee (and usually Flask). If you'd like to see some real-life applications that use peewee, the following resources may be useful: + +* `Building a note-taking app with Flask and Peewee `_ as well as `Part 2 `_ and `Part 3 `_. +* `Analytics web service built with Flask and Peewee `_. +* `Personalized news digest (with a boolean query parser!) `_. +* `Structuring Flask apps with Peewee `_. +* `Creating a lastpass clone with Flask and Peewee `_. +* `Creating a bookmarking web-service that takes screenshots of your bookmarks `_. +* `Building a pastebin, wiki and a bookmarking service using Flask and Peewee `_. +* `Encrypted databases with Python and SQLCipher `_. +* `Dear Diary: An Encrypted, Command-Line Diary with Peewee `_. +* `Query Tree Structures in SQLite using Peewee and the Transitive Closure Extension `_. 
+ + diff --git a/python2.7libs/peewee-3.13.1.dist-info/RECORD b/python2.7libs/peewee-3.13.1.dist-info/RECORD new file mode 100644 index 0000000..52fa977 --- /dev/null +++ b/python2.7libs/peewee-3.13.1.dist-info/RECORD @@ -0,0 +1,56 @@ +../../Scripts/pwiz.py,sha256=_QtosyH6-ZxVSBLJavgpx7YOMu7ZZ4EfkOBVRtJH2i8,8299 +../../Scripts/pwiz.pyc,, +peewee-3.13.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +peewee-3.13.1.dist-info/LICENSE,sha256=N0AJYSWwhzWiR7jdCM2C4LqYTTvr2SIdN4V2Y35SQNo,1058 +peewee-3.13.1.dist-info/METADATA,sha256=U-mXW9ZParUVeUSG30BHq8JSXk0xDZmeGTlo2jv4G4M,7001 +peewee-3.13.1.dist-info/RECORD,, +peewee-3.13.1.dist-info/WHEEL,sha256=dg7kGZASN2bfhrOe7EgPYNMV2qp9RJDAR5jrN3Aoyjg,98 +peewee-3.13.1.dist-info/top_level.txt,sha256=uV7RZ61bWm9zDrPVGNrGay4E4WDonEqtU2NPe5GGUWs,22 +peewee.py,sha256=DeOhmYattxIBlZsPOW4YIAMgDeUPJ7EL03IYQTYU1wM,258110 +peewee.pyc,, +playhouse/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +playhouse/__init__.pyc,, +playhouse/apsw_ext.py,sha256=09eLg7E6vRoBojJcwy6HDSiB5P9QCijJxofSJmL2emE,4555 +playhouse/apsw_ext.pyc,, +playhouse/cockroachdb.py,sha256=WpXgzskUe2Z6wudh2RfuCrXSVEbMRA5crBuNSETXVrs,8215 +playhouse/cockroachdb.pyc,, +playhouse/dataset.py,sha256=_1Ii9MNL_rElkWNC6SWWGn77tQWKRhlmzaeJIzfm_AE,13934 +playhouse/dataset.pyc,, +playhouse/db_url.py,sha256=JFhMZN268SbumQeQWpYZwCiKxZeXI76cbwXCCi9SAZw,4246 +playhouse/db_url.pyc,, +playhouse/fields.py,sha256=m1IkkO1FUUuaDzSIcdtg49NobhAJSJF3NlFNge0eFxY,1743 +playhouse/fields.pyc,, +playhouse/flask_utils.py,sha256=N0stRAzFJXUKu7gWtE46zUeAInReHjJ-F2N5lJGPAM8,6034 +playhouse/flask_utils.pyc,, +playhouse/hybrid.py,sha256=rRAPBImP2x61DoYh53mLm4JMPoNCLPFTd7_WIrq6_gU,1528 +playhouse/hybrid.pyc,, +playhouse/kv.py,sha256=S7AzM1v-G-BORNU1k_sJQm6MJettDZVBSWcRnqPoAoE,5375 +playhouse/kv.pyc,, +playhouse/migrate.py,sha256=TUlkT2_1yBudidLdL7m5Oj4zv4Rrsh8CuuZgIeuRUAY,30659 +playhouse/migrate.pyc,, 
+playhouse/mysql_ext.py,sha256=bCng_G-tk1qlWx_3UKha-8M_uwC6dhz9KbvDSrNm5P4,1436 +playhouse/mysql_ext.pyc,, +playhouse/pool.py,sha256=nrp-zLRmzDQsbIVvT8r4GI6NwIP53Are2Sj8jm0uC3c,11476 +playhouse/pool.pyc,, +playhouse/postgres_ext.py,sha256=wHXZiWeOfHPmbZnj5Y5nzynwmVhfmYMYZVyxofXf634,14059 +playhouse/postgres_ext.pyc,, +playhouse/reflection.py,sha256=tmLl4osVMamn89YVvv3UMoR1RN4ZKcGyYuwgBkDeQfk,30076 +playhouse/reflection.pyc,, +playhouse/shortcuts.py,sha256=rdXYwWHsIMmQ-3rHe7zyIfZcHgFoc2zi0InDhWy3di4,8442 +playhouse/shortcuts.pyc,, +playhouse/signals.py,sha256=9jFNDCpbMmBOfwCNIKlB3PlL7Hc1j45Rn9RObELZc6M,2523 +playhouse/signals.pyc,, +playhouse/sqlcipher_ext.py,sha256=iQnwkexhirMhmnxBKm30a8R7mtaqhHW-gBDp0SpVjPo,3534 +playhouse/sqlcipher_ext.pyc,, +playhouse/sqlite_changelog.py,sha256=JrTcwlZzcFzoTdorr2Xtt_oAkv6MyLU7sRdTFADG950,4567 +playhouse/sqlite_changelog.pyc,, +playhouse/sqlite_ext.py,sha256=UOrzConjJPt6pP4g9LQT0uBeZBVidgIVCyCtAwYjnYA,44303 +playhouse/sqlite_ext.pyc,, +playhouse/sqlite_udf.py,sha256=hehQVYY9pnLPQMPvYqT9fygTnhfWLjP0JI2EH6iHdWk,13291 +playhouse/sqlite_udf.pyc,, +playhouse/sqliteq.py,sha256=8rO_mJDjZhcobVPwMdlQFJfDOeB0UOr6fFVEgW1cuCU,10075 +playhouse/sqliteq.pyc,, +playhouse/test_utils.py,sha256=sJwKeBq2ebR_VFVBGjHDmfzlAqR04GPXKnD3dfA-GAA,1737 +playhouse/test_utils.pyc,, +pwiz.py,sha256=lLB_VkKehqLQSAGqseqPe-3FD23n4jO_E0vNY4ZZRhI,8089 +pwiz.pyc,, diff --git a/python2.7libs/peewee-3.13.1.dist-info/WHEEL b/python2.7libs/peewee-3.13.1.dist-info/WHEEL new file mode 100644 index 0000000..6118170 --- /dev/null +++ b/python2.7libs/peewee-3.13.1.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.33.6) +Root-Is-Purelib: true +Tag: cp27-none-any + diff --git a/python2.7libs/peewee-3.13.1.dist-info/top_level.txt b/python2.7libs/peewee-3.13.1.dist-info/top_level.txt new file mode 100644 index 0000000..1d507be --- /dev/null +++ b/python2.7libs/peewee-3.13.1.dist-info/top_level.txt @@ -0,0 +1,3 @@ +peewee +playhouse +pwiz diff --git 
a/python2.7libs/peewee.py b/python2.7libs/peewee.py new file mode 100644 index 0000000..392079e --- /dev/null +++ b/python2.7libs/peewee.py @@ -0,0 +1,7590 @@ +from bisect import bisect_left +from bisect import bisect_right +from contextlib import contextmanager +from copy import deepcopy +from functools import wraps +from inspect import isclass +import calendar +import collections +import datetime +import decimal +import hashlib +import itertools +import logging +import operator +import re +import socket +import struct +import sys +import threading +import time +import uuid +import warnings +try: + from collections.abc import Mapping +except ImportError: + from collections import Mapping + +try: + from pysqlite3 import dbapi2 as pysq3 +except ImportError: + try: + from pysqlite2 import dbapi2 as pysq3 + except ImportError: + pysq3 = None +try: + import sqlite3 +except ImportError: + sqlite3 = pysq3 +else: + if pysq3 and pysq3.sqlite_version_info >= sqlite3.sqlite_version_info: + sqlite3 = pysq3 +try: + from psycopg2cffi import compat + compat.register() +except ImportError: + pass +try: + import psycopg2 + from psycopg2 import extensions as pg_extensions + try: + from psycopg2 import errors as pg_errors + except ImportError: + pg_errors = None +except ImportError: + psycopg2 = pg_errors = None + +mysql_passwd = False +try: + import pymysql as mysql +except ImportError: + try: + import MySQLdb as mysql + mysql_passwd = True + except ImportError: + mysql = None + + +__version__ = '3.13.1' +__all__ = [ + 'AsIs', + 'AutoField', + 'BareField', + 'BigAutoField', + 'BigBitField', + 'BigIntegerField', + 'BinaryUUIDField', + 'BitField', + 'BlobField', + 'BooleanField', + 'Case', + 'Cast', + 'CharField', + 'Check', + 'chunked', + 'Column', + 'CompositeKey', + 'Context', + 'Database', + 'DatabaseError', + 'DatabaseProxy', + 'DataError', + 'DateField', + 'DateTimeField', + 'DecimalField', + 'DeferredForeignKey', + 'DeferredThroughModel', + 'DJANGO_MAP', + 'DoesNotExist', + 
'DoubleField', + 'DQ', + 'EXCLUDED', + 'Field', + 'FixedCharField', + 'FloatField', + 'fn', + 'ForeignKeyField', + 'IdentityField', + 'ImproperlyConfigured', + 'Index', + 'IntegerField', + 'IntegrityError', + 'InterfaceError', + 'InternalError', + 'IPField', + 'JOIN', + 'ManyToManyField', + 'Model', + 'ModelIndex', + 'MySQLDatabase', + 'NotSupportedError', + 'OP', + 'OperationalError', + 'PostgresqlDatabase', + 'PrimaryKeyField', # XXX: Deprecated, change to AutoField. + 'prefetch', + 'ProgrammingError', + 'Proxy', + 'QualifiedNames', + 'SchemaManager', + 'SmallIntegerField', + 'Select', + 'SQL', + 'SqliteDatabase', + 'Table', + 'TextField', + 'TimeField', + 'TimestampField', + 'Tuple', + 'UUIDField', + 'Value', + 'ValuesList', + 'Window', +] + +try: # Python 2.7+ + from logging import NullHandler +except ImportError: + class NullHandler(logging.Handler): + def emit(self, record): + pass + +logger = logging.getLogger('peewee') +logger.addHandler(NullHandler()) + + +if sys.version_info[0] == 2: + text_type = unicode + bytes_type = str + buffer_type = buffer + izip_longest = itertools.izip_longest + callable_ = callable + exec('def reraise(tp, value, tb=None): raise tp, value, tb') + def print_(s): + sys.stdout.write(s) + sys.stdout.write('\n') +else: + import builtins + try: + from collections.abc import Callable + except ImportError: + from collections import Callable + from functools import reduce + callable_ = lambda c: isinstance(c, Callable) + text_type = str + bytes_type = bytes + buffer_type = memoryview + basestring = str + long = int + print_ = getattr(builtins, 'print') + izip_longest = itertools.zip_longest + def reraise(tp, value, tb=None): + if value.__traceback__ is not tb: + raise value.with_traceback(tb) + raise value + + +if sqlite3: + sqlite3.register_adapter(decimal.Decimal, str) + sqlite3.register_adapter(datetime.date, str) + sqlite3.register_adapter(datetime.time, str) + __sqlite_version__ = sqlite3.sqlite_version_info +else: + 
__sqlite_version__ = (0, 0, 0) + + +__date_parts__ = set(('year', 'month', 'day', 'hour', 'minute', 'second')) + +# Sqlite does not support the `date_part` SQL function, so we will define an +# implementation in python. +__sqlite_datetime_formats__ = ( + '%Y-%m-%d %H:%M:%S', + '%Y-%m-%d %H:%M:%S.%f', + '%Y-%m-%d', + '%H:%M:%S', + '%H:%M:%S.%f', + '%H:%M') + +__sqlite_date_trunc__ = { + 'year': '%Y-01-01 00:00:00', + 'month': '%Y-%m-01 00:00:00', + 'day': '%Y-%m-%d 00:00:00', + 'hour': '%Y-%m-%d %H:00:00', + 'minute': '%Y-%m-%d %H:%M:00', + 'second': '%Y-%m-%d %H:%M:%S'} + +__mysql_date_trunc__ = __sqlite_date_trunc__.copy() +__mysql_date_trunc__['minute'] = '%Y-%m-%d %H:%i:00' +__mysql_date_trunc__['second'] = '%Y-%m-%d %H:%i:%S' + +def _sqlite_date_part(lookup_type, datetime_string): + assert lookup_type in __date_parts__ + if not datetime_string: + return + dt = format_date_time(datetime_string, __sqlite_datetime_formats__) + return getattr(dt, lookup_type) + +def _sqlite_date_trunc(lookup_type, datetime_string): + assert lookup_type in __sqlite_date_trunc__ + if not datetime_string: + return + dt = format_date_time(datetime_string, __sqlite_datetime_formats__) + return dt.strftime(__sqlite_date_trunc__[lookup_type]) + + +def __deprecated__(s): + warnings.warn(s, DeprecationWarning) + + +class attrdict(dict): + def __getattr__(self, attr): + try: + return self[attr] + except KeyError: + raise AttributeError(attr) + def __setattr__(self, attr, value): self[attr] = value + def __iadd__(self, rhs): self.update(rhs); return self + def __add__(self, rhs): d = attrdict(self); d.update(rhs); return d + +SENTINEL = object() + +#: Operations for use in SQL expressions. 
+OP = attrdict( + AND='AND', + OR='OR', + ADD='+', + SUB='-', + MUL='*', + DIV='/', + BIN_AND='&', + BIN_OR='|', + XOR='#', + MOD='%', + EQ='=', + LT='<', + LTE='<=', + GT='>', + GTE='>=', + NE='!=', + IN='IN', + NOT_IN='NOT IN', + IS='IS', + IS_NOT='IS NOT', + LIKE='LIKE', + ILIKE='ILIKE', + BETWEEN='BETWEEN', + REGEXP='REGEXP', + IREGEXP='IREGEXP', + CONCAT='||', + BITWISE_NEGATION='~') + +# To support "django-style" double-underscore filters, create a mapping between +# operation name and operation code, e.g. "__eq" == OP.EQ. +DJANGO_MAP = attrdict({ + 'eq': operator.eq, + 'lt': operator.lt, + 'lte': operator.le, + 'gt': operator.gt, + 'gte': operator.ge, + 'ne': operator.ne, + 'in': operator.lshift, + 'is': lambda l, r: Expression(l, OP.IS, r), + 'like': lambda l, r: Expression(l, OP.LIKE, r), + 'ilike': lambda l, r: Expression(l, OP.ILIKE, r), + 'regexp': lambda l, r: Expression(l, OP.REGEXP, r), +}) + +#: Mapping of field type to the data-type supported by the database. Databases +#: may override or add to this list. +FIELD = attrdict( + AUTO='INTEGER', + BIGAUTO='BIGINT', + BIGINT='BIGINT', + BLOB='BLOB', + BOOL='SMALLINT', + CHAR='CHAR', + DATE='DATE', + DATETIME='DATETIME', + DECIMAL='DECIMAL', + DEFAULT='', + DOUBLE='REAL', + FLOAT='REAL', + INT='INTEGER', + SMALLINT='SMALLINT', + TEXT='TEXT', + TIME='TIME', + UUID='TEXT', + UUIDB='BLOB', + VARCHAR='VARCHAR') + +#: Join helpers (for convenience) -- all join types are supported, this object +#: is just to help avoid introducing errors by using strings everywhere. +JOIN = attrdict( + INNER='INNER JOIN', + LEFT_OUTER='LEFT OUTER JOIN', + RIGHT_OUTER='RIGHT OUTER JOIN', + FULL='FULL JOIN', + FULL_OUTER='FULL OUTER JOIN', + CROSS='CROSS JOIN', + NATURAL='NATURAL JOIN', + LATERAL='LATERAL', + LEFT_LATERAL='LEFT JOIN LATERAL') + +# Row representations. 
+ROW = attrdict( + TUPLE=1, + DICT=2, + NAMED_TUPLE=3, + CONSTRUCTOR=4, + MODEL=5) + +SCOPE_NORMAL = 1 +SCOPE_SOURCE = 2 +SCOPE_VALUES = 4 +SCOPE_CTE = 8 +SCOPE_COLUMN = 16 + +# Rules for parentheses around subqueries in compound select. +CSQ_PARENTHESES_NEVER = 0 +CSQ_PARENTHESES_ALWAYS = 1 +CSQ_PARENTHESES_UNNESTED = 2 + +# Regular expressions used to convert class names to snake-case table names. +# First regex handles acronym followed by word or initial lower-word followed +# by a capitalized word. e.g. APIResponse -> API_Response / fooBar -> foo_Bar. +# Second regex handles the normal case of two title-cased words. +SNAKE_CASE_STEP1 = re.compile('(.)_*([A-Z][a-z]+)') +SNAKE_CASE_STEP2 = re.compile('([a-z0-9])_*([A-Z])') + +# Helper functions that are used in various parts of the codebase. +MODEL_BASE = '_metaclass_helper_' + +def with_metaclass(meta, base=object): + return meta(MODEL_BASE, (base,), {}) + +def merge_dict(source, overrides): + merged = source.copy() + if overrides: + merged.update(overrides) + return merged + +def quote(path, quote_chars): + if len(path) == 1: + return path[0].join(quote_chars) + return '.'.join([part.join(quote_chars) for part in path]) + +is_model = lambda o: isclass(o) and issubclass(o, Model) + +def ensure_tuple(value): + if value is not None: + return value if isinstance(value, (list, tuple)) else (value,) + +def ensure_entity(value): + if value is not None: + return value if isinstance(value, Node) else Entity(value) + +def make_snake_case(s): + first = SNAKE_CASE_STEP1.sub(r'\1_\2', s) + return SNAKE_CASE_STEP2.sub(r'\1_\2', first).lower() + +def chunked(it, n): + marker = object() + for group in (list(g) for g in izip_longest(*[iter(it)] * n, + fillvalue=marker)): + if group[-1] is marker: + del group[group.index(marker):] + yield group + + +class _callable_context_manager(object): + def __call__(self, fn): + @wraps(fn) + def inner(*args, **kwargs): + with self: + return fn(*args, **kwargs) + return inner + + +class 
Proxy(object): + """ + Create a proxy or placeholder for another object. + """ + __slots__ = ('obj', '_callbacks') + + def __init__(self): + self._callbacks = [] + self.initialize(None) + + def initialize(self, obj): + self.obj = obj + for callback in self._callbacks: + callback(obj) + + def attach_callback(self, callback): + self._callbacks.append(callback) + return callback + + def passthrough(method): + def inner(self, *args, **kwargs): + if self.obj is None: + raise AttributeError('Cannot use uninitialized Proxy.') + return getattr(self.obj, method)(*args, **kwargs) + return inner + + # Allow proxy to be used as a context-manager. + __enter__ = passthrough('__enter__') + __exit__ = passthrough('__exit__') + + def __getattr__(self, attr): + if self.obj is None: + raise AttributeError('Cannot use uninitialized Proxy.') + return getattr(self.obj, attr) + + def __setattr__(self, attr, value): + if attr not in self.__slots__: + raise AttributeError('Cannot set attribute on proxy.') + return super(Proxy, self).__setattr__(attr, value) + + +class DatabaseProxy(Proxy): + """ + Proxy implementation specifically for proxying `Database` objects. + """ + def connection_context(self): + return ConnectionContext(self) + def atomic(self, *args, **kwargs): + return _atomic(self, *args, **kwargs) + def manual_commit(self): + return _manual(self) + def transaction(self, *args, **kwargs): + return _transaction(self, *args, **kwargs) + def savepoint(self): + return _savepoint(self) + + +class ModelDescriptor(object): pass + + +# SQL Generation. + + +class AliasManager(object): + __slots__ = ('_counter', '_current_index', '_mapping') + + def __init__(self): + # A list of dictionaries containing mappings at various depths. 
+ self._counter = 0 + self._current_index = 0 + self._mapping = [] + self.push() + + @property + def mapping(self): + return self._mapping[self._current_index - 1] + + def add(self, source): + if source not in self.mapping: + self._counter += 1 + self[source] = 't%d' % self._counter + return self.mapping[source] + + def get(self, source, any_depth=False): + if any_depth: + for idx in reversed(range(self._current_index)): + if source in self._mapping[idx]: + return self._mapping[idx][source] + return self.add(source) + + def __getitem__(self, source): + return self.get(source) + + def __setitem__(self, source, alias): + self.mapping[source] = alias + + def push(self): + self._current_index += 1 + if self._current_index > len(self._mapping): + self._mapping.append({}) + + def pop(self): + if self._current_index == 1: + raise ValueError('Cannot pop() from empty alias manager.') + self._current_index -= 1 + + +class State(collections.namedtuple('_State', ('scope', 'parentheses', + 'settings'))): + def __new__(cls, scope=SCOPE_NORMAL, parentheses=False, **kwargs): + return super(State, cls).__new__(cls, scope, parentheses, kwargs) + + def __call__(self, scope=None, parentheses=None, **kwargs): + # Scope and settings are "inherited" (parentheses is not, however). + scope = self.scope if scope is None else scope + + # Try to avoid unnecessary dict copying. + if kwargs and self.settings: + settings = self.settings.copy() # Copy original settings dict. + settings.update(kwargs) # Update copy with overrides. 
+ elif kwargs: + settings = kwargs + else: + settings = self.settings + return State(scope, parentheses, **settings) + + def __getattr__(self, attr_name): + return self.settings.get(attr_name) + + +def __scope_context__(scope): + @contextmanager + def inner(self, **kwargs): + with self(scope=scope, **kwargs): + yield self + return inner + + +class Context(object): + __slots__ = ('stack', '_sql', '_values', 'alias_manager', 'state') + + def __init__(self, **settings): + self.stack = [] + self._sql = [] + self._values = [] + self.alias_manager = AliasManager() + self.state = State(**settings) + + def as_new(self): + return Context(**self.state.settings) + + def column_sort_key(self, item): + return item[0].get_sort_key(self) + + @property + def scope(self): + return self.state.scope + + @property + def parentheses(self): + return self.state.parentheses + + @property + def subquery(self): + return self.state.subquery + + def __call__(self, **overrides): + if overrides and overrides.get('scope') == self.scope: + del overrides['scope'] + + self.stack.append(self.state) + self.state = self.state(**overrides) + return self + + scope_normal = __scope_context__(SCOPE_NORMAL) + scope_source = __scope_context__(SCOPE_SOURCE) + scope_values = __scope_context__(SCOPE_VALUES) + scope_cte = __scope_context__(SCOPE_CTE) + scope_column = __scope_context__(SCOPE_COLUMN) + + def __enter__(self): + if self.parentheses: + self.literal('(') + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.parentheses: + self.literal(')') + self.state = self.stack.pop() + + @contextmanager + def push_alias(self): + self.alias_manager.push() + yield + self.alias_manager.pop() + + def sql(self, obj): + if isinstance(obj, (Node, Context)): + return obj.__sql__(self) + elif is_model(obj): + return obj._meta.table.__sql__(self) + else: + return self.sql(Value(obj)) + + def literal(self, keyword): + self._sql.append(keyword) + return self + + def value(self, value, converter=None, 
              add_param=True):
        # (Tail of Context.value(); the signature begins on the previous
        # collapsed line.)  Adds a parameterized value to the query being
        # built, applying an explicit or state-level converter first.
        if converter:
            value = converter(value)
        if isinstance(value, Node):
            return self.sql(value)
        elif converter is None and self.state.converter:
            # Explicitly check for None so that "False" can be used to signify
            # that no conversion should be applied.
            value = self.state.converter(value)

        if isinstance(value, Node):
            # The converter produced a Node -- render it with conversion
            # disabled so we do not convert twice.
            with self(converter=None):
                return self.sql(value)

        self._values.append(value)
        # Emit the parameter placeholder (state.param, e.g. "%s") unless the
        # caller only wants the value recorded.
        return self.literal(self.state.param or '?') if add_param else self

    def __sql__(self, ctx):
        # Splice this context's accumulated SQL fragments and parameters
        # into another context (used when a Context is embedded as a node).
        ctx._sql.extend(self._sql)
        ctx._values.extend(self._values)
        return ctx

    def parse(self, node):
        # Convenience: render ``node`` and return the (sql, params) 2-tuple.
        return self.sql(node).query()

    def query(self):
        # Return the final (sql-string, parameter-list) for this context.
        return ''.join(self._sql), self._values


def query_to_string(query):
    # NOTE: this function is not exported by default as it might be misused --
    # and this misuse could lead to sql injection vulnerabilities. This
    # function is intended for debugging or logging purposes ONLY.
    db = getattr(query, '_database', None)
    if db is not None:
        ctx = db.get_sql_context()
    else:
        ctx = Context()

    sql, params = ctx.sql(query).query()
    if not params:
        return sql

    # Normalize qmark-style placeholders to %s so that plain string
    # interpolation below works regardless of the database's param style.
    param = ctx.state.param or '?'
    if param == '?':
        sql = sql.replace('?', '%s')

    return sql % tuple(map(_query_val_transform, params))


def _query_val_transform(v):
    # Interpolate parameters.  Debug-only literal rendering; values are NOT
    # escaped beyond simple quoting (see query_to_string() warning above).
    if isinstance(v, (text_type, datetime.datetime, datetime.date,
                      datetime.time)):
        v = "'%s'" % v
    elif isinstance(v, bytes_type):
        try:
            v = v.decode('utf8')
        except UnicodeDecodeError:
            # Fall back to a lossless escape for non-UTF8 byte-strings.
            v = v.decode('raw_unicode_escape')
        v = "'%s'" % v
    elif isinstance(v, int):
        v = '%s' % int(v)  # Also handles booleans -> 1 or 0.
    elif v is None:
        v = 'NULL'
    else:
        v = str(v)
    return v


# AST.


class Node(object):
    """Base class for every component of the SQL abstract syntax tree."""

    # When True, values compared against this node may be coerced using the
    # node's converter (consulted by Expression.__sql__).
    _coerce = True

    def clone(self):
        # Shallow copy of the instance dict; subclasses with mutable state
        # that must not be shared override this (e.g. Table.clone()).
        obj = self.__class__.__new__(self.__class__)
        obj.__dict__ = self.__dict__.copy()
        return obj

    def __sql__(self, ctx):
        raise NotImplementedError

    @staticmethod
    def copy(method):
        # Decorator giving builder methods copy-on-write semantics: the
        # wrapped mutator runs against a clone, which is then returned.
        def inner(self, *args, **kwargs):
            clone = self.clone()
            method(clone, *args, **kwargs)
            return clone
        return inner

    def coerce(self, _coerce=True):
        # Return a node with the given coercion setting, cloning only when
        # the setting actually changes.
        if _coerce != self._coerce:
            clone = self.clone()
            clone._coerce = _coerce
            return clone
        return self

    def is_alias(self):
        # Overridden by Alias (and WrappedNode delegates to it).
        return False

    def unwrap(self):
        # Overridden by WrappedNode to return the innermost wrapped node.
        return self


class ColumnFactory(object):
    """Returned by Source.c -- attribute access yields Column instances."""

    __slots__ = ('node',)

    def __init__(self, node):
        self.node = node

    def __getattr__(self, attr):
        return Column(self.node, attr)


class _DynamicColumn(object):
    """Descriptor implementing dynamic "table.c.column_name" lookups."""

    __slots__ = ()

    def __get__(self, instance, instance_type=None):
        if instance is not None:
            return ColumnFactory(instance)  # Implements __getattr__().
        return self


class _ExplicitColumn(object):
    """Descriptor installed when a Table declares its columns explicitly;
    any access through ".c" is then a programming error."""

    __slots__ = ()

    def __get__(self, instance, instance_type=None):
        if instance is not None:
            raise AttributeError(
                '%s specifies columns explicitly, and does not support '
                'dynamic column lookups.'
% instance) + return self + + +class Source(Node): + c = _DynamicColumn() + + def __init__(self, alias=None): + super(Source, self).__init__() + self._alias = alias + + @Node.copy + def alias(self, name): + self._alias = name + + def select(self, *columns): + if not columns: + columns = (SQL('*'),) + return Select((self,), columns) + + def join(self, dest, join_type=JOIN.INNER, on=None): + return Join(self, dest, join_type, on) + + def left_outer_join(self, dest, on=None): + return Join(self, dest, JOIN.LEFT_OUTER, on) + + def cte(self, name, recursive=False, columns=None, materialized=None): + return CTE(name, self, recursive=recursive, columns=columns, + materialized=materialized) + + def get_sort_key(self, ctx): + if self._alias: + return (self._alias,) + return (ctx.alias_manager[self],) + + def apply_alias(self, ctx): + # If we are defining the source, include the "AS alias" declaration. An + # alias is created for the source if one is not already defined. + if ctx.scope == SCOPE_SOURCE: + if self._alias: + ctx.alias_manager[self] = self._alias + ctx.literal(' AS ').sql(Entity(ctx.alias_manager[self])) + return ctx + + def apply_column(self, ctx): + if self._alias: + ctx.alias_manager[self] = self._alias + return ctx.sql(Entity(ctx.alias_manager[self])) + + +class _HashableSource(object): + def __init__(self, *args, **kwargs): + super(_HashableSource, self).__init__(*args, **kwargs) + self._update_hash() + + @Node.copy + def alias(self, name): + self._alias = name + self._update_hash() + + def _update_hash(self): + self._hash = self._get_hash() + + def _get_hash(self): + return hash((self.__class__, self._path, self._alias)) + + def __hash__(self): + return self._hash + + def __eq__(self, other): + return self._hash == other._hash + + def __ne__(self, other): + return not (self == other) + + +def __bind_database__(meth): + @wraps(meth) + def inner(self, *args, **kwargs): + result = meth(self, *args, **kwargs) + if self._database: + return 
result.bind(self._database) + return result + return inner + + +def __join__(join_type=JOIN.INNER, inverted=False): + def method(self, other): + if inverted: + self, other = other, self + return Join(self, other, join_type=join_type) + return method + + +class BaseTable(Source): + __and__ = __join__(JOIN.INNER) + __add__ = __join__(JOIN.LEFT_OUTER) + __sub__ = __join__(JOIN.RIGHT_OUTER) + __or__ = __join__(JOIN.FULL_OUTER) + __mul__ = __join__(JOIN.CROSS) + __rand__ = __join__(JOIN.INNER, inverted=True) + __radd__ = __join__(JOIN.LEFT_OUTER, inverted=True) + __rsub__ = __join__(JOIN.RIGHT_OUTER, inverted=True) + __ror__ = __join__(JOIN.FULL_OUTER, inverted=True) + __rmul__ = __join__(JOIN.CROSS, inverted=True) + + +class _BoundTableContext(_callable_context_manager): + def __init__(self, table, database): + self.table = table + self.database = database + + def __enter__(self): + self._orig_database = self.table._database + self.table.bind(self.database) + if self.table._model is not None: + self.table._model.bind(self.database) + return self.table + + def __exit__(self, exc_type, exc_val, exc_tb): + self.table.bind(self._orig_database) + if self.table._model is not None: + self.table._model.bind(self._orig_database) + + +class Table(_HashableSource, BaseTable): + def __init__(self, name, columns=None, primary_key=None, schema=None, + alias=None, _model=None, _database=None): + self.__name__ = name + self._columns = columns + self._primary_key = primary_key + self._schema = schema + self._path = (schema, name) if schema else (name,) + self._model = _model + self._database = _database + super(Table, self).__init__(alias=alias) + + # Allow tables to restrict what columns are available. 
+ if columns is not None: + self.c = _ExplicitColumn() + for column in columns: + setattr(self, column, Column(self, column)) + + if primary_key: + col_src = self if self._columns else self.c + self.primary_key = getattr(col_src, primary_key) + else: + self.primary_key = None + + def clone(self): + # Ensure a deep copy of the column instances. + return Table( + self.__name__, + columns=self._columns, + primary_key=self._primary_key, + schema=self._schema, + alias=self._alias, + _model=self._model, + _database=self._database) + + def bind(self, database=None): + self._database = database + return self + + def bind_ctx(self, database=None): + return _BoundTableContext(self, database) + + def _get_hash(self): + return hash((self.__class__, self._path, self._alias, self._model)) + + @__bind_database__ + def select(self, *columns): + if not columns and self._columns: + columns = [Column(self, column) for column in self._columns] + return Select((self,), columns) + + @__bind_database__ + def insert(self, insert=None, columns=None, **kwargs): + if kwargs: + insert = {} if insert is None else insert + src = self if self._columns else self.c + for key, value in kwargs.items(): + insert[getattr(src, key)] = value + return Insert(self, insert=insert, columns=columns) + + @__bind_database__ + def replace(self, insert=None, columns=None, **kwargs): + return (self + .insert(insert=insert, columns=columns) + .on_conflict('REPLACE')) + + @__bind_database__ + def update(self, update=None, **kwargs): + if kwargs: + update = {} if update is None else update + for key, value in kwargs.items(): + src = self if self._columns else self.c + update[getattr(src, key)] = value + return Update(self, update=update) + + @__bind_database__ + def delete(self): + return Delete(self) + + def __sql__(self, ctx): + if ctx.scope == SCOPE_VALUES: + # Return the quoted table name. 
+ return ctx.sql(Entity(*self._path)) + + if self._alias: + ctx.alias_manager[self] = self._alias + + if ctx.scope == SCOPE_SOURCE: + # Define the table and its alias. + return self.apply_alias(ctx.sql(Entity(*self._path))) + else: + # Refer to the table using the alias. + return self.apply_column(ctx) + + +class Join(BaseTable): + def __init__(self, lhs, rhs, join_type=JOIN.INNER, on=None, alias=None): + super(Join, self).__init__(alias=alias) + self.lhs = lhs + self.rhs = rhs + self.join_type = join_type + self._on = on + + def on(self, predicate): + self._on = predicate + return self + + def __sql__(self, ctx): + (ctx + .sql(self.lhs) + .literal(' %s ' % self.join_type) + .sql(self.rhs)) + if self._on is not None: + ctx.literal(' ON ').sql(self._on) + return ctx + + +class ValuesList(_HashableSource, BaseTable): + def __init__(self, values, columns=None, alias=None): + self._values = values + self._columns = columns + super(ValuesList, self).__init__(alias=alias) + + def _get_hash(self): + return hash((self.__class__, id(self._values), self._alias)) + + @Node.copy + def columns(self, *names): + self._columns = names + + def __sql__(self, ctx): + if self._alias: + ctx.alias_manager[self] = self._alias + + if ctx.scope == SCOPE_SOURCE or ctx.scope == SCOPE_NORMAL: + with ctx(parentheses=not ctx.parentheses): + ctx = (ctx + .literal('VALUES ') + .sql(CommaNodeList([ + EnclosedNodeList(row) for row in self._values]))) + + if ctx.scope == SCOPE_SOURCE: + ctx.literal(' AS ').sql(Entity(ctx.alias_manager[self])) + if self._columns: + entities = [Entity(c) for c in self._columns] + ctx.sql(EnclosedNodeList(entities)) + else: + ctx.sql(Entity(ctx.alias_manager[self])) + + return ctx + + +class CTE(_HashableSource, Source): + def __init__(self, name, query, recursive=False, columns=None, + materialized=None): + self._alias = name + self._query = query + self._recursive = recursive + self._materialized = materialized + if columns is not None: + columns = [Entity(c) if 
isinstance(c, basestring) else c + for c in columns] + self._columns = columns + query._cte_list = () + super(CTE, self).__init__(alias=name) + + def select_from(self, *columns): + if not columns: + raise ValueError('select_from() must specify one or more columns ' + 'from the CTE to select.') + + query = (Select((self,), columns) + .with_cte(self) + .bind(self._query._database)) + try: + query = query.objects(self._query.model) + except AttributeError: + pass + return query + + def _get_hash(self): + return hash((self.__class__, self._alias, id(self._query))) + + def union_all(self, rhs): + clone = self._query.clone() + return CTE(self._alias, clone + rhs, self._recursive, self._columns) + __add__ = union_all + + def __sql__(self, ctx): + if ctx.scope != SCOPE_CTE: + return ctx.sql(Entity(self._alias)) + + with ctx.push_alias(): + ctx.alias_manager[self] = self._alias + ctx.sql(Entity(self._alias)) + + if self._columns: + ctx.literal(' ').sql(EnclosedNodeList(self._columns)) + ctx.literal(' AS ') + + if self._materialized: + ctx.literal('MATERIALIZED ') + elif self._materialized is False: + ctx.literal('NOT MATERIALIZED ') + + with ctx.scope_normal(parentheses=True): + ctx.sql(self._query) + return ctx + + +class ColumnBase(Node): + def alias(self, alias): + if alias: + return Alias(self, alias) + return self + + def unalias(self): + return self + + def cast(self, as_type): + return Cast(self, as_type) + + def asc(self, collation=None, nulls=None): + return Asc(self, collation=collation, nulls=nulls) + __pos__ = asc + + def desc(self, collation=None, nulls=None): + return Desc(self, collation=collation, nulls=nulls) + __neg__ = desc + + def __invert__(self): + return Negated(self) + + def _e(op, inv=False): + """ + Lightweight factory which returns a method that builds an Expression + consisting of the left-hand and right-hand operands, using `op`. 
+ """ + def inner(self, rhs): + if inv: + return Expression(rhs, op, self) + return Expression(self, op, rhs) + return inner + __and__ = _e(OP.AND) + __or__ = _e(OP.OR) + + __add__ = _e(OP.ADD) + __sub__ = _e(OP.SUB) + __mul__ = _e(OP.MUL) + __div__ = __truediv__ = _e(OP.DIV) + __xor__ = _e(OP.XOR) + __radd__ = _e(OP.ADD, inv=True) + __rsub__ = _e(OP.SUB, inv=True) + __rmul__ = _e(OP.MUL, inv=True) + __rdiv__ = __rtruediv__ = _e(OP.DIV, inv=True) + __rand__ = _e(OP.AND, inv=True) + __ror__ = _e(OP.OR, inv=True) + __rxor__ = _e(OP.XOR, inv=True) + + def __eq__(self, rhs): + op = OP.IS if rhs is None else OP.EQ + return Expression(self, op, rhs) + def __ne__(self, rhs): + op = OP.IS_NOT if rhs is None else OP.NE + return Expression(self, op, rhs) + + __lt__ = _e(OP.LT) + __le__ = _e(OP.LTE) + __gt__ = _e(OP.GT) + __ge__ = _e(OP.GTE) + __lshift__ = _e(OP.IN) + __rshift__ = _e(OP.IS) + __mod__ = _e(OP.LIKE) + __pow__ = _e(OP.ILIKE) + + bin_and = _e(OP.BIN_AND) + bin_or = _e(OP.BIN_OR) + in_ = _e(OP.IN) + not_in = _e(OP.NOT_IN) + regexp = _e(OP.REGEXP) + + # Special expressions. 
+ def is_null(self, is_null=True): + op = OP.IS if is_null else OP.IS_NOT + return Expression(self, op, None) + def contains(self, rhs): + if isinstance(rhs, Node): + rhs = Expression('%', OP.CONCAT, + Expression(rhs, OP.CONCAT, '%')) + else: + rhs = '%%%s%%' % rhs + return Expression(self, OP.ILIKE, rhs) + def startswith(self, rhs): + if isinstance(rhs, Node): + rhs = Expression(rhs, OP.CONCAT, '%') + else: + rhs = '%s%%' % rhs + return Expression(self, OP.ILIKE, rhs) + def endswith(self, rhs): + if isinstance(rhs, Node): + rhs = Expression('%', OP.CONCAT, rhs) + else: + rhs = '%%%s' % rhs + return Expression(self, OP.ILIKE, rhs) + def between(self, lo, hi): + return Expression(self, OP.BETWEEN, NodeList((lo, SQL('AND'), hi))) + def concat(self, rhs): + return StringExpression(self, OP.CONCAT, rhs) + def regexp(self, rhs): + return Expression(self, OP.REGEXP, rhs) + def iregexp(self, rhs): + return Expression(self, OP.IREGEXP, rhs) + def __getitem__(self, item): + if isinstance(item, slice): + if item.start is None or item.stop is None: + raise ValueError('BETWEEN range must have both a start- and ' + 'end-point.') + return self.between(item.start, item.stop) + return self == item + + def distinct(self): + return NodeList((SQL('DISTINCT'), self)) + + def collate(self, collation): + return NodeList((self, SQL('COLLATE %s' % collation))) + + def get_sort_key(self, ctx): + return () + + +class Column(ColumnBase): + def __init__(self, source, name): + self.source = source + self.name = name + + def get_sort_key(self, ctx): + if ctx.scope == SCOPE_VALUES: + return (self.name,) + else: + return self.source.get_sort_key(ctx) + (self.name,) + + def __hash__(self): + return hash((self.source, self.name)) + + def __sql__(self, ctx): + if ctx.scope == SCOPE_VALUES: + return ctx.sql(Entity(self.name)) + else: + with ctx.scope_column(): + return ctx.sql(self.source).literal('.').sql(Entity(self.name)) + + +class WrappedNode(ColumnBase): + def __init__(self, node): + self.node 
= node + self._coerce = getattr(node, '_coerce', True) + + def is_alias(self): + return self.node.is_alias() + + def unwrap(self): + return self.node.unwrap() + + +class EntityFactory(object): + __slots__ = ('node',) + def __init__(self, node): + self.node = node + def __getattr__(self, attr): + return Entity(self.node, attr) + + +class _DynamicEntity(object): + __slots__ = () + def __get__(self, instance, instance_type=None): + if instance is not None: + return EntityFactory(instance._alias) # Implements __getattr__(). + return self + + +class Alias(WrappedNode): + c = _DynamicEntity() + + def __init__(self, node, alias): + super(Alias, self).__init__(node) + self._alias = alias + + def __hash__(self): + return hash(self._alias) + + def alias(self, alias=None): + if alias is None: + return self.node + else: + return Alias(self.node, alias) + + def unalias(self): + return self.node + + def is_alias(self): + return True + + def __sql__(self, ctx): + if ctx.scope == SCOPE_SOURCE: + return (ctx + .sql(self.node) + .literal(' AS ') + .sql(Entity(self._alias))) + else: + return ctx.sql(Entity(self._alias)) + + +class Negated(WrappedNode): + def __invert__(self): + return self.node + + def __sql__(self, ctx): + return ctx.literal('NOT ').sql(self.node) + + +class BitwiseMixin(object): + def __and__(self, other): + return self.bin_and(other) + + def __or__(self, other): + return self.bin_or(other) + + def __sub__(self, other): + return self.bin_and(other.bin_negated()) + + def __invert__(self): + return BitwiseNegated(self) + + +class BitwiseNegated(BitwiseMixin, WrappedNode): + def __invert__(self): + return self.node + + def __sql__(self, ctx): + if ctx.state.operations: + op_sql = ctx.state.operations.get(self.op, self.op) + else: + op_sql = self.op + return ctx.literal(op_sql).sql(self.node) + + +class Value(ColumnBase): + _multi_types = (list, tuple, frozenset, set) + + def __init__(self, value, converter=None, unpack=True): + self.value = value + self.converter = 
converter + self.multi = isinstance(self.value, self._multi_types) and unpack + if self.multi: + self.values = [] + for item in self.value: + if isinstance(item, Node): + self.values.append(item) + else: + self.values.append(Value(item, self.converter)) + + def __sql__(self, ctx): + if self.multi: + # For multi-part values (e.g. lists of IDs). + return ctx.sql(EnclosedNodeList(self.values)) + + return ctx.value(self.value, self.converter) + + +def AsIs(value): + return Value(value, unpack=False) + + +class Cast(WrappedNode): + def __init__(self, node, cast): + super(Cast, self).__init__(node) + self._cast = cast + self._coerce = False + + def __sql__(self, ctx): + return (ctx + .literal('CAST(') + .sql(self.node) + .literal(' AS %s)' % self._cast)) + + +class Ordering(WrappedNode): + def __init__(self, node, direction, collation=None, nulls=None): + super(Ordering, self).__init__(node) + self.direction = direction + self.collation = collation + self.nulls = nulls + if nulls and nulls.lower() not in ('first', 'last'): + raise ValueError('Ordering nulls= parameter must be "first" or ' + '"last", got: %s' % nulls) + + def collate(self, collation=None): + return Ordering(self.node, self.direction, collation) + + def _null_ordering_case(self, nulls): + if nulls.lower() == 'last': + ifnull, notnull = 1, 0 + elif nulls.lower() == 'first': + ifnull, notnull = 0, 1 + else: + raise ValueError('unsupported value for nulls= ordering.') + return Case(None, ((self.node.is_null(), ifnull),), notnull) + + def __sql__(self, ctx): + if self.nulls and not ctx.state.nulls_ordering: + ctx.sql(self._null_ordering_case(self.nulls)).literal(', ') + + ctx.sql(self.node).literal(' %s' % self.direction) + if self.collation: + ctx.literal(' COLLATE %s' % self.collation) + if self.nulls and ctx.state.nulls_ordering: + ctx.literal(' NULLS %s' % self.nulls) + return ctx + + +def Asc(node, collation=None, nulls=None): + return Ordering(node, 'ASC', collation, nulls) + + +def Desc(node, 
collation=None, nulls=None): + return Ordering(node, 'DESC', collation, nulls) + + +class Expression(ColumnBase): + def __init__(self, lhs, op, rhs, flat=False): + self.lhs = lhs + self.op = op + self.rhs = rhs + self.flat = flat + + def __sql__(self, ctx): + overrides = {'parentheses': not self.flat, 'in_expr': True} + + # First attempt to unwrap the node on the left-hand-side, so that we + # can get at the underlying Field if one is present. + node = raw_node = self.lhs + if isinstance(raw_node, WrappedNode): + node = raw_node.unwrap() + + # Set up the appropriate converter if we have a field on the left side. + if isinstance(node, Field) and raw_node._coerce: + overrides['converter'] = node.db_value + else: + overrides['converter'] = None + + if ctx.state.operations: + op_sql = ctx.state.operations.get(self.op, self.op) + else: + op_sql = self.op + + with ctx(**overrides): + # Postgresql reports an error for IN/NOT IN (), so convert to + # the equivalent boolean expression. + op_in = self.op == OP.IN or self.op == OP.NOT_IN + if op_in and ctx.as_new().parse(self.rhs)[0] == '()': + return ctx.literal('0 = 1' if self.op == OP.IN else '1 = 1') + + return (ctx + .sql(self.lhs) + .literal(' %s ' % op_sql) + .sql(self.rhs)) + + +class StringExpression(Expression): + def __add__(self, rhs): + return self.concat(rhs) + def __radd__(self, lhs): + return StringExpression(lhs, OP.CONCAT, self) + + +class Entity(ColumnBase): + def __init__(self, *path): + self._path = [part.replace('"', '""') for part in path if part] + + def __getattr__(self, attr): + return Entity(*self._path + [attr]) + + def get_sort_key(self, ctx): + return tuple(self._path) + + def __hash__(self): + return hash((self.__class__.__name__, tuple(self._path))) + + def __sql__(self, ctx): + return ctx.literal(quote(self._path, ctx.state.quote or '""')) + + +class SQL(ColumnBase): + def __init__(self, sql, params=None): + self.sql = sql + self.params = params + + def __sql__(self, ctx): + 
ctx.literal(self.sql) + if self.params: + for param in self.params: + ctx.value(param, False, add_param=False) + return ctx + + +def Check(constraint): + return SQL('CHECK (%s)' % constraint) + + +class Function(ColumnBase): + def __init__(self, name, arguments, coerce=True, python_value=None): + self.name = name + self.arguments = arguments + self._filter = None + self._python_value = python_value + if name and name.lower() in ('sum', 'count', 'cast'): + self._coerce = False + else: + self._coerce = coerce + + def __getattr__(self, attr): + def decorator(*args, **kwargs): + return Function(attr, args, **kwargs) + return decorator + + @Node.copy + def filter(self, where=None): + self._filter = where + + @Node.copy + def python_value(self, func=None): + self._python_value = func + + def over(self, partition_by=None, order_by=None, start=None, end=None, + frame_type=None, window=None, exclude=None): + if isinstance(partition_by, Window) and window is None: + window = partition_by + + if window is not None: + node = WindowAlias(window) + else: + node = Window(partition_by=partition_by, order_by=order_by, + start=start, end=end, frame_type=frame_type, + exclude=exclude, _inline=True) + return NodeList((self, SQL('OVER'), node)) + + def __sql__(self, ctx): + ctx.literal(self.name) + if not len(self.arguments): + ctx.literal('()') + else: + with ctx(in_function=True, function_arg_count=len(self.arguments)): + ctx.sql(EnclosedNodeList([ + (argument if isinstance(argument, Node) + else Value(argument, False)) + for argument in self.arguments])) + + if self._filter: + ctx.literal(' FILTER (WHERE ').sql(self._filter).literal(')') + return ctx + + +fn = Function(None, None) + + +class Window(Node): + # Frame start/end and frame exclusion. + CURRENT_ROW = SQL('CURRENT ROW') + GROUP = SQL('GROUP') + TIES = SQL('TIES') + NO_OTHERS = SQL('NO OTHERS') + + # Frame types. 
+ GROUPS = 'GROUPS' + RANGE = 'RANGE' + ROWS = 'ROWS' + + def __init__(self, partition_by=None, order_by=None, start=None, end=None, + frame_type=None, extends=None, exclude=None, alias=None, + _inline=False): + super(Window, self).__init__() + if start is not None and not isinstance(start, SQL): + start = SQL(start) + if end is not None and not isinstance(end, SQL): + end = SQL(end) + + self.partition_by = ensure_tuple(partition_by) + self.order_by = ensure_tuple(order_by) + self.start = start + self.end = end + if self.start is None and self.end is not None: + raise ValueError('Cannot specify WINDOW end without start.') + self._alias = alias or 'w' + self._inline = _inline + self.frame_type = frame_type + self._extends = extends + self._exclude = exclude + + def alias(self, alias=None): + self._alias = alias or 'w' + return self + + @Node.copy + def as_range(self): + self.frame_type = Window.RANGE + + @Node.copy + def as_rows(self): + self.frame_type = Window.ROWS + + @Node.copy + def as_groups(self): + self.frame_type = Window.GROUPS + + @Node.copy + def extends(self, window=None): + self._extends = window + + @Node.copy + def exclude(self, frame_exclusion=None): + if isinstance(frame_exclusion, basestring): + frame_exclusion = SQL(frame_exclusion) + self._exclude = frame_exclusion + + @staticmethod + def following(value=None): + if value is None: + return SQL('UNBOUNDED FOLLOWING') + return SQL('%d FOLLOWING' % value) + + @staticmethod + def preceding(value=None): + if value is None: + return SQL('UNBOUNDED PRECEDING') + return SQL('%d PRECEDING' % value) + + def __sql__(self, ctx): + if ctx.scope != SCOPE_SOURCE and not self._inline: + ctx.literal(self._alias) + ctx.literal(' AS ') + + with ctx(parentheses=True): + parts = [] + if self._extends is not None: + ext = self._extends + if isinstance(ext, Window): + ext = SQL(ext._alias) + elif isinstance(ext, basestring): + ext = SQL(ext) + parts.append(ext) + if self.partition_by: + parts.extend(( + SQL('PARTITION 
BY'), + CommaNodeList(self.partition_by))) + if self.order_by: + parts.extend(( + SQL('ORDER BY'), + CommaNodeList(self.order_by))) + if self.start is not None and self.end is not None: + frame = self.frame_type or 'ROWS' + parts.extend(( + SQL('%s BETWEEN' % frame), + self.start, + SQL('AND'), + self.end)) + elif self.start is not None: + parts.extend((SQL(self.frame_type or 'ROWS'), self.start)) + elif self.frame_type is not None: + parts.append(SQL('%s UNBOUNDED PRECEDING' % self.frame_type)) + if self._exclude is not None: + parts.extend((SQL('EXCLUDE'), self._exclude)) + ctx.sql(NodeList(parts)) + return ctx + + +class WindowAlias(Node): + def __init__(self, window): + self.window = window + + def alias(self, window_alias): + self.window._alias = window_alias + return self + + def __sql__(self, ctx): + return ctx.literal(self.window._alias or 'w') + + +class ForUpdate(Node): + def __init__(self, expr, of=None, nowait=None): + expr = 'FOR UPDATE' if expr is True else expr + if expr.lower().endswith('nowait'): + expr = expr[:-7] # Strip off the "nowait" bit. 
+ nowait = True + + self._expr = expr + if of is not None and not isinstance(of, (list, set, tuple)): + of = (of,) + self._of = of + self._nowait = nowait + + def __sql__(self, ctx): + ctx.literal(self._expr) + if self._of is not None: + ctx.literal(' OF ').sql(CommaNodeList(self._of)) + if self._nowait: + ctx.literal(' NOWAIT') + return ctx + + +def Case(predicate, expression_tuples, default=None): + clauses = [SQL('CASE')] + if predicate is not None: + clauses.append(predicate) + for expr, value in expression_tuples: + clauses.extend((SQL('WHEN'), expr, SQL('THEN'), value)) + if default is not None: + clauses.extend((SQL('ELSE'), default)) + clauses.append(SQL('END')) + return NodeList(clauses) + + +class NodeList(ColumnBase): + def __init__(self, nodes, glue=' ', parens=False): + self.nodes = nodes + self.glue = glue + self.parens = parens + if parens and len(self.nodes) == 1: + if isinstance(self.nodes[0], Expression): + # Hack to avoid double-parentheses. + self.nodes[0].flat = True + + def __sql__(self, ctx): + n_nodes = len(self.nodes) + if n_nodes == 0: + return ctx.literal('()') if self.parens else ctx + with ctx(parentheses=self.parens): + for i in range(n_nodes - 1): + ctx.sql(self.nodes[i]) + ctx.literal(self.glue) + ctx.sql(self.nodes[n_nodes - 1]) + return ctx + + +def CommaNodeList(nodes): + return NodeList(nodes, ', ') + + +def EnclosedNodeList(nodes): + return NodeList(nodes, ', ', True) + + +class _Namespace(Node): + __slots__ = ('_name',) + def __init__(self, name): + self._name = name + def __getattr__(self, attr): + return NamespaceAttribute(self, attr) + __getitem__ = __getattr__ + +class NamespaceAttribute(ColumnBase): + def __init__(self, namespace, attribute): + self._namespace = namespace + self._attribute = attribute + + def __sql__(self, ctx): + return (ctx + .literal(self._namespace._name + '.') + .sql(Entity(self._attribute))) + +EXCLUDED = _Namespace('EXCLUDED') + + +class DQ(ColumnBase): + def __init__(self, **query): + super(DQ, 
                 self).__init__()
        # (Tail of DQ.__init__; the "def __init__(self, **query): super(DQ,"
        # header begins on the previous collapsed line.)
        self.query = query
        self._negated = False

    @Node.copy
    def __invert__(self):
        self._negated = not self._negated

    def clone(self):
        # DQ holds a plain kwargs dict, so clone by reconstruction rather
        # than the default shallow __dict__ copy.
        node = DQ(**self.query)
        node._negated = self._negated
        return node


#: Represent a row tuple.
Tuple = lambda *a: EnclosedNodeList(a)


class QualifiedNames(WrappedNode):
    """Force column references inside the wrapped node to render fully
    qualified ("table"."column") by entering column scope."""

    def __sql__(self, ctx):
        with ctx.scope_column():
            return ctx.sql(self.node)


def qualify_names(node):
    # Search a node hierarchy to ensure that any column-like objects are
    # referenced using fully-qualified names.
    if isinstance(node, Expression):
        return node.__class__(qualify_names(node.lhs), node.op,
                              qualify_names(node.rhs), node.flat)
    elif isinstance(node, ColumnBase):
        return QualifiedNames(node)
    return node


class OnConflict(Node):
    """Describe how conflicting INSERTs are resolved (the upsert clause).

    The rendering itself is delegated to the database-specific state via
    get_conflict_statement() / get_conflict_update().
    """

    def __init__(self, action=None, update=None, preserve=None, where=None,
                 conflict_target=None, conflict_where=None,
                 conflict_constraint=None):
        self._action = action
        self._update = update
        self._preserve = ensure_tuple(preserve)
        self._where = where
        # A conflict target (column list) and a named constraint are
        # mutually exclusive ways of identifying the conflict.
        if conflict_target is not None and conflict_constraint is not None:
            raise ValueError('only one of "conflict_target" and '
                             '"conflict_constraint" may be specified.')
        self._conflict_target = ensure_tuple(conflict_target)
        self._conflict_where = conflict_where
        self._conflict_constraint = conflict_constraint

    def get_conflict_statement(self, ctx, query):
        # Database-specific, e.g. "INSERT OR REPLACE" vs "ON CONFLICT".
        return ctx.state.conflict_statement(self, query)

    def get_conflict_update(self, ctx, query):
        return ctx.state.conflict_update(self, query)

    @Node.copy
    def preserve(self, *columns):
        self._preserve = columns

    @Node.copy
    def update(self, _data=None, **kwargs):
        if _data and kwargs and not isinstance(_data, dict):
            raise ValueError('Cannot mix data with keyword arguments in the '
                             'OnConflict update method.')
        _data = _data or {}
        if kwargs:
            _data.update(kwargs)
        self._update = _data

    @Node.copy
    def where(self, *expressions):
        if
self._where is not None:
            expressions = (self._where,) + expressions
        self._where = reduce(operator.and_, expressions)

    @Node.copy
    def conflict_target(self, *constraints):
        # An explicit conflict target supersedes any named constraint.
        self._conflict_constraint = None
        self._conflict_target = constraints

    @Node.copy
    def conflict_where(self, *expressions):
        # Successive calls AND the new expressions onto the existing clause.
        if self._conflict_where is not None:
            expressions = (self._conflict_where,) + expressions
        self._conflict_where = reduce(operator.and_, expressions)

    @Node.copy
    def conflict_constraint(self, constraint):
        # A named constraint supersedes any explicit conflict target.
        self._conflict_constraint = constraint
        self._conflict_target = None


def database_required(method):
    """Decorator for query methods that require a database to run.

    The wrapped method accepts an optional leading ``database`` argument;
    when omitted, the query's bound database (``self._database``) is used.
    Raises ``InterfaceError`` if neither is available.
    """
    @wraps(method)
    def inner(self, database=None, *args, **kwargs):
        database = self._database if database is None else database
        if not database:
            raise InterfaceError('Query must be bound to a database in order '
                                 'to call "%s".' % method.__name__)
        return method(self, database, *args, **kwargs)
    return inner

# BASE QUERY INTERFACE.

class BaseQuery(Node):
    """Abstract base for all queries: row-type selection, execution and
    result iteration/caching. Subclasses implement ``__sql__`` and
    ``_execute``."""

    # Rows are returned as dicts unless a different row type is requested.
    default_row_type = ROW.DICT

    def __init__(self, _database=None, **kwargs):
        self._database = _database
        self._cursor_wrapper = None
        self._row_type = None
        self._constructor = None
        super(BaseQuery, self).__init__(**kwargs)

    def bind(self, database=None):
        """Bind the query to the given database and return it (fluent)."""
        self._database = database
        return self

    def clone(self):
        # A cloned query must not share the original's cached cursor.
        query = super(BaseQuery, self).clone()
        query._cursor_wrapper = None
        return query

    @Node.copy
    def dicts(self, as_dict=True):
        """Return rows as dictionaries."""
        self._row_type = ROW.DICT if as_dict else None
        return self

    @Node.copy
    def tuples(self, as_tuple=True):
        """Return rows as plain tuples."""
        self._row_type = ROW.TUPLE if as_tuple else None
        return self

    @Node.copy
    def namedtuples(self, as_namedtuple=True):
        """Return rows as namedtuples."""
        self._row_type = ROW.NAMED_TUPLE if as_namedtuple else None
        return self

    @Node.copy
    def objects(self, constructor=None):
        """Return rows constructed by calling ``constructor`` with row data."""
        self._row_type = ROW.CONSTRUCTOR if constructor else None
        self._constructor = constructor
        return self

    def _get_cursor_wrapper(self, cursor):
        # Map the requested row type onto the appropriate cursor wrapper.
        row_type = self._row_type or self.default_row_type

        if row_type == ROW.DICT:
            return DictCursorWrapper(cursor)
        elif row_type == ROW.TUPLE:
            return CursorWrapper(cursor)
        elif row_type == ROW.NAMED_TUPLE:
            return NamedTupleCursorWrapper(cursor)
        elif row_type == ROW.CONSTRUCTOR:
            return ObjectCursorWrapper(cursor, self._constructor)
        else:
            raise ValueError('Unrecognized row type: "%s".' % row_type)

    def __sql__(self, ctx):
        raise NotImplementedError

    def sql(self):
        """Return a 2-tuple of (sql, parameters) for this query."""
        if self._database:
            context = self._database.get_sql_context()
        else:
            context = Context()
        return context.parse(self)

    @database_required
    def execute(self, database):
        """Execute the query against ``database`` (or the bound database)."""
        return self._execute(database)

    def _execute(self, database):
        raise NotImplementedError

    def iterator(self, database=None):
        """Iterate results without caching rows (low-memory iteration)."""
        return iter(self.execute(database).iterator())

    def _ensure_execution(self):
        # Lazily execute on first result access; fail if unbound.
        if not self._cursor_wrapper:
            if not self._database:
                raise ValueError('Query has not been executed.')
            self.execute()

    def __iter__(self):
        self._ensure_execution()
        return iter(self._cursor_wrapper)

    def __getitem__(self, value):
        self._ensure_execution()
        if isinstance(value, slice):
            index = value.stop
        else:
            index = value
        if index is not None:
            # fill_cache() needs one row past the requested index; a negative
            # index requires the full result set, signalled by 0.
            index = index + 1 if index >= 0 else 0
        self._cursor_wrapper.fill_cache(index)
        return self._cursor_wrapper.row_cache[value]

    def __len__(self):
        self._ensure_execution()
        return len(self._cursor_wrapper)

    def __str__(self):
        return query_to_string(self)


class RawQuery(BaseQuery):
    """A query expressed as a raw SQL string plus parameters."""

    def __init__(self, sql=None, params=None, **kwargs):
        super(RawQuery, self).__init__(**kwargs)
        self._sql = sql
        self._params = params

    def __sql__(self, ctx):
        # Emit the SQL verbatim; parameters are attached without generating
        # additional placeholder text (the string already contains them).
        ctx.literal(self._sql)
        if self._params:
            for param in self._params:
                ctx.value(param, add_param=False)
        return ctx

    def _execute(self, database):
        if self._cursor_wrapper is None:
            cursor = database.execute(self)
            self._cursor_wrapper = self._get_cursor_wrapper(cursor)
        return self._cursor_wrapper


class Query(BaseQuery):
    """Base for structured queries: WHERE / ORDER BY / LIMIT / OFFSET and
    common table expressions (CTEs)."""

    def __init__(self, where=None, order_by=None, limit=None, offset=None,
                 **kwargs):
        super(Query, self).__init__(**kwargs)
        self._where = where
        self._order_by = order_by
        self._limit = limit
        self._offset = offset

        self._cte_list = None

    @Node.copy
    def with_cte(self, *cte_list):
        """Attach common table expressions to the query."""
        self._cte_list = cte_list

    @Node.copy
    def where(self, *expressions):
        """AND the given expressions onto the WHERE clause."""
        if self._where is not None:
            expressions = (self._where,) + expressions
        self._where = reduce(operator.and_, expressions)

    @Node.copy
    def orwhere(self, *expressions):
        """OR the given expressions onto the WHERE clause."""
        if self._where is not None:
            expressions = (self._where,) + expressions
        self._where = reduce(operator.or_, expressions)

    @Node.copy
    def order_by(self, *values):
        self._order_by = values

    @Node.copy
    def order_by_extend(self, *values):
        # Append to existing ordering; normalize an empty result to None.
        self._order_by = ((self._order_by or ()) + values) or None

    @Node.copy
    def limit(self, value=None):
        self._limit = value

    @Node.copy
    def offset(self, value=None):
        self._offset = value

    @Node.copy
    def paginate(self, page, paginate_by=20):
        """Set LIMIT/OFFSET from a 1-based page number and page size."""
        if page > 0:
            page -= 1
        self._limit = paginate_by
        self._offset = page * paginate_by

    def _apply_ordering(self, ctx):
        # Shared by SELECT and write-queries: ORDER BY, then LIMIT/OFFSET.
        if self._order_by:
            (ctx
             .literal(' ORDER BY ')
             .sql(CommaNodeList(self._order_by)))
        # Some databases (e.g. older MySQL/SQLite) require a LIMIT whenever
        # OFFSET is used -- ctx.state.limit_max supplies the sentinel value.
        if self._limit is not None or (self._offset is not None and
                                       ctx.state.limit_max):
            ctx.literal(' LIMIT ').sql(self._limit or ctx.state.limit_max)
        if self._offset is not None:
            ctx.literal(' OFFSET ').sql(self._offset)
        return ctx

    def __sql__(self, ctx):
        if self._cte_list:
            # The CTE scope is only used at the very beginning of the query,
            # when we are describing the various CTEs we will be using.
            recursive = any(cte._recursive for cte in self._cte_list)

            # Explicitly disable the "subquery" flag here, so as to avoid
            # unnecessary parentheses around subsequent selects.
            with ctx.scope_cte(subquery=False):
                (ctx
                 .literal('WITH RECURSIVE ' if recursive else 'WITH ')
                 .sql(CommaNodeList(self._cte_list))
                 .literal(' '))
        return ctx


def __compound_select__(operation, inverted=False):
    """Build a method implementing a compound set operation (UNION, etc.).

    ``inverted`` swaps operands to support the reflected operator forms."""
    def method(self, other):
        if inverted:
            self, other = other, self
        return CompoundSelectQuery(self, operation, other)
    return method


class SelectQuery(Query):
    """SELECT query supporting compound set operations via operators."""

    union_all = __add__ = __compound_select__('UNION ALL')
    union = __or__ = __compound_select__('UNION')
    intersect = __and__ = __compound_select__('INTERSECT')
    except_ = __sub__ = __compound_select__('EXCEPT')
    __radd__ = __compound_select__('UNION ALL', inverted=True)
    __ror__ = __compound_select__('UNION', inverted=True)
    __rand__ = __compound_select__('INTERSECT', inverted=True)
    __rsub__ = __compound_select__('EXCEPT', inverted=True)

    def select_from(self, *columns):
        """Wrap this query as a sub-select, selecting ``columns`` from it."""
        if not columns:
            raise ValueError('select_from() must specify one or more columns.')

        query = (Select((self,), columns)
                 .bind(self._database))
        if getattr(self, 'model', None) is not None:
            # Bind to the sub-select's model type, if defined.
            query = query.objects(self.model)
        return query


class SelectBase(_HashableSource, Source, SelectQuery):
    """Mixin combining SELECT behavior with source/hashability semantics;
    provides execution plus peek/first/scalar/count/exists/get helpers."""

    def _get_hash(self):
        return hash((self.__class__, self._alias or id(self)))

    def _execute(self, database):
        if self._cursor_wrapper is None:
            cursor = database.execute(self)
            self._cursor_wrapper = self._get_cursor_wrapper(cursor)
        return self._cursor_wrapper

    @database_required
    def peek(self, database, n=1):
        """Return up to ``n`` rows without altering the query's LIMIT."""
        rows = self.execute(database)[:n]
        if rows:
            return rows[0] if n == 1 else rows

    @database_required
    def first(self, database, n=1):
        """Like peek(), but applies LIMIT n (re-executing if it changed)."""
        if self._limit != n:
            self._limit = n
            self._cursor_wrapper = None
        return self.peek(database, n=n)

    @database_required
    def scalar(self, database, as_tuple=False):
        """Return the first column of the first row (or the whole row tuple)."""
        row = self.tuples().peek(database)
        return row[0] if row and not as_tuple else row

    @database_required
    def count(self, database, clear_limit=False):
        """Return the number of rows via SELECT COUNT(1) over a wrapped copy."""
        clone = self.order_by().alias('_wrapped')
        if clear_limit:
            clone._limit = clone._offset = None
        try:
            # When the query has no grouping/having/distinct/windows, the
            # selection can be collapsed to SELECT 1 for efficiency.
            if clone._having is None and clone._group_by is None and \
               clone._windows is None and clone._distinct is None and \
               clone._simple_distinct is not True:
                clone = clone.select(SQL('1'))
        except AttributeError:
            pass
        return Select([clone], [fn.COUNT(SQL('1'))]).scalar(database)

    @database_required
    def exists(self, database):
        """Return True if the query matches at least one row."""
        clone = self.columns(SQL('1'))
        clone._limit = 1
        clone._offset = None
        return bool(clone.scalar())

    @database_required
    def get(self, database):
        """Return the first row, or None if the result set is empty."""
        self._cursor_wrapper = None
        try:
            return self.execute(database)[0]
        except IndexError:
            pass


# QUERY IMPLEMENTATIONS.


class CompoundSelectQuery(SelectBase):
    """A compound SELECT: two queries joined by UNION / INTERSECT / EXCEPT."""

    def __init__(self, lhs, op, rhs):
        super(CompoundSelectQuery, self).__init__()
        self.lhs = lhs
        self.op = op
        self.rhs = rhs

    @property
    def _returning(self):
        # The compound query returns whatever the left-hand query returns.
        return self.lhs._returning

    @database_required
    def exists(self, database):
        query = Select((self.limit(1),), (SQL('1'),)).bind(database)
        return bool(query.scalar())

    def _get_query_key(self):
        return (self.lhs.get_query_key(), self.rhs.get_query_key())

    def _wrap_parens(self, ctx, subq):
        # Decide whether ``subq`` (one side of the compound) needs parens,
        # based on the database's compound-select parenthesization rules.
        csq_setting = ctx.state.compound_select_parentheses

        if not csq_setting or csq_setting == CSQ_PARENTHESES_NEVER:
            return False
        elif csq_setting == CSQ_PARENTHESES_ALWAYS:
            return True
        elif csq_setting == CSQ_PARENTHESES_UNNESTED:
            if ctx.state.in_expr or ctx.state.in_function:
                # If this compound select query is being used inside an
                # expression, e.g., an IN or EXISTS().
                return False

            # If the query on the left or right is itself a compound select
            # query, then we do not apply parentheses. However, if it is a
            # regular SELECT query, we will apply parentheses.
            return not isinstance(subq, CompoundSelectQuery)

    def __sql__(self, ctx):
        if ctx.scope == SCOPE_COLUMN:
            return self.apply_column(ctx)

        outer_parens = ctx.subquery or (ctx.scope == SCOPE_SOURCE)
        with ctx(parentheses=outer_parens):
            # Should the left-hand query be wrapped in parentheses?
            lhs_parens = self._wrap_parens(ctx, self.lhs)
            with ctx.scope_normal(parentheses=lhs_parens, subquery=False):
                ctx.sql(self.lhs)
            ctx.literal(' %s ' % self.op)
            with ctx.push_alias():
                # Should the right-hand query be wrapped in parentheses?
                rhs_parens = self._wrap_parens(ctx, self.rhs)
                with ctx.scope_normal(parentheses=rhs_parens, subquery=False):
                    ctx.sql(self.rhs)

            # Apply ORDER BY, LIMIT, OFFSET. We use the "values" scope so that
            # entity names are not fully-qualified. This is a bit of a hack, as
            # we're relying on the logic in Column.__sql__() to not fully
            # qualify column names.
            with ctx.scope_values():
                self._apply_ordering(ctx)

        return self.apply_alias(ctx)


class Select(SelectBase):
    """A SELECT query: sources, projection, grouping, windows, locking."""

    def __init__(self, from_list=None, columns=None, group_by=None,
                 having=None, distinct=None, windows=None, for_update=None,
                 for_update_of=None, nowait=None, **kwargs):
        super(Select, self).__init__(**kwargs)
        self._from_list = (list(from_list) if isinstance(from_list, tuple)
                           else from_list) or []
        self._returning = columns
        self._group_by = group_by
        self._having = having
        # NOTE(review): the ``windows`` ctor argument is not applied here --
        # window definitions are attached via .window(). Confirm intended.
        self._windows = None
        self._for_update = for_update  # XXX: consider reorganizing.
        self._for_update_of = for_update_of
        self._for_update_nowait = nowait

        # ``distinct`` may be a bool (SELECT DISTINCT) or a column list
        # (SELECT DISTINCT ON (...), for databases that support it).
        self._distinct = self._simple_distinct = None
        if distinct:
            if isinstance(distinct, bool):
                self._simple_distinct = distinct
            else:
                self._distinct = distinct

        self._cursor_wrapper = None

    def clone(self):
        clone = super(Select, self).clone()
        if clone._from_list:
            # Copy the list so joins on the clone don't mutate the original.
            clone._from_list = list(clone._from_list)
        return clone

    @Node.copy
    def columns(self, *columns, **kwargs):
        """Replace the list of selected columns."""
        self._returning = columns
    select = columns  # Alias: query.select(...) == query.columns(...).

    @Node.copy
    def select_extend(self, *columns):
        """Append columns to the current selection."""
        self._returning = tuple(self._returning) + columns

    @Node.copy
    def from_(self, *sources):
        """Replace the FROM clause sources."""
        self._from_list = list(sources)

    @Node.copy
    def join(self, dest, join_type=JOIN.INNER, on=None):
        """Join ``dest`` onto the most recently added source."""
        if not self._from_list:
            raise ValueError('No sources to join on.')
        item = self._from_list.pop()
        self._from_list.append(Join(item, dest, join_type, on))

    @Node.copy
    def group_by(self, *columns):
        """Set GROUP BY; passing a Table expands to all declared columns."""
        grouping = []
        for column in columns:
            if isinstance(column, Table):
                if not column._columns:
                    raise ValueError('Cannot pass a table to group_by() that '
                                     'does not have columns explicitly '
                                     'declared.')
                grouping.extend([getattr(column, col_name)
                                 for col_name in column._columns])
            else:
                grouping.append(column)
        self._group_by = grouping

    def group_by_extend(self, *values):
        """@Node.copy used from group_by() call"""
        group_by = tuple(self._group_by or ()) + values
        return self.group_by(*group_by)

    @Node.copy
    def having(self, *expressions):
        """AND the given expressions onto the HAVING clause."""
        if self._having is not None:
            expressions = (self._having,) + expressions
        self._having = reduce(operator.and_, expressions)

    @Node.copy
    def distinct(self, *columns):
        """Set simple DISTINCT (bool) or DISTINCT ON (columns)."""
        if len(columns) == 1 and (columns[0] is True or columns[0] is False):
            self._simple_distinct = columns[0]
        else:
            self._simple_distinct = False
            self._distinct = columns

    @Node.copy
    def window(self, *windows):
        """Attach named window definitions (WINDOW clause)."""
        self._windows = windows if windows else None

    @Node.copy
    def for_update(self, for_update=True, of=None, nowait=None):
        """Configure row locking (FOR UPDATE [OF ...] [NOWAIT])."""
        if not for_update and (of is not None or nowait):
            for_update = True
        self._for_update = for_update
        self._for_update_of = of
        self._for_update_nowait = nowait

    def _get_query_key(self):
        return self._alias

    def __sql_selection__(self, ctx, is_subquery=False):
        return ctx.sql(CommaNodeList(self._returning))

    def __sql__(self, ctx):
        if ctx.scope == SCOPE_COLUMN:
            return self.apply_column(ctx)

        is_subquery = ctx.subquery
        state = {
            'converter': None,
            'in_function': False,
            'parentheses': is_subquery or (ctx.scope == SCOPE_SOURCE),
            'subquery': True,
        }
        # A subquery that is the sole argument of a function (e.g. EXISTS())
        # is already parenthesized by the function call itself.
        if ctx.state.in_function and ctx.state.function_arg_count == 1:
            state['parentheses'] = False

        with ctx.scope_normal(**state):
            # Defer calling parent SQL until here. This ensures that any CTEs
            # for this query will be properly nested if this query is a
            # sub-select or is used in an expression. See GH#1809 for example.
            super(Select, self).__sql__(ctx)

            ctx.literal('SELECT ')
            if self._simple_distinct or self._distinct is not None:
                ctx.literal('DISTINCT ')
                if self._distinct:
                    (ctx
                     .literal('ON ')
                     .sql(EnclosedNodeList(self._distinct))
                     .literal(' '))

            with ctx.scope_source():
                ctx = self.__sql_selection__(ctx, is_subquery)

            if self._from_list:
                with ctx.scope_source(parentheses=False):
                    ctx.literal(' FROM ').sql(CommaNodeList(self._from_list))

            if self._where is not None:
                ctx.literal(' WHERE ').sql(self._where)

            if self._group_by:
                ctx.literal(' GROUP BY ').sql(CommaNodeList(self._group_by))

            if self._having is not None:
                ctx.literal(' HAVING ').sql(self._having)

            if self._windows is not None:
                ctx.literal(' WINDOW ')
                ctx.sql(CommaNodeList(self._windows))

            # Apply ORDER BY, LIMIT, OFFSET.
            self._apply_ordering(ctx)

            if self._for_update:
                if not ctx.state.for_update:
                    raise ValueError('FOR UPDATE specified but not supported '
                                     'by database.')
                ctx.literal(' ')
                ctx.sql(ForUpdate(self._for_update, self._for_update_of,
                                  self._for_update_nowait))

        # If the subquery is inside a function -or- we are evaluating a
        # subquery on either side of an expression w/o an explicit alias, do
        # not generate an alias + AS clause.
        if ctx.state.in_function or (ctx.state.in_expr and
                                     self._alias is None):
            return ctx

        return self.apply_alias(ctx)


class _WriteQuery(Query):
    """Base for INSERT / UPDATE / DELETE; handles RETURNING and results."""

    def __init__(self, table, returning=None, **kwargs):
        self.table = table
        self._returning = returning
        self._return_cursor = True if returning else False
        super(_WriteQuery, self).__init__(**kwargs)

    @Node.copy
    def returning(self, *returning):
        """Set the RETURNING clause; enables cursor-style results."""
        self._returning = returning
        self._return_cursor = True if returning else False

    def apply_returning(self, ctx):
        if self._returning:
            with ctx.scope_source():
                ctx.literal(' RETURNING ').sql(CommaNodeList(self._returning))
        return ctx

    def _execute(self, database):
        if self._returning:
            cursor = self.execute_returning(database)
        else:
            cursor = database.execute(self)
        return self.handle_result(database, cursor)

    def execute_returning(self, database):
        if self._cursor_wrapper is None:
            cursor = database.execute(self)
            self._cursor_wrapper = self._get_cursor_wrapper(cursor)
        return self._cursor_wrapper

    def handle_result(self, database, cursor):
        # Return rows when RETURNING was requested, else the affected count.
        if self._return_cursor:
            return cursor
        return database.rows_affected(cursor)

    def _set_table_alias(self, ctx):
        ctx.alias_manager[self.table] = self.table.__name__

    def __sql__(self, ctx):
        super(_WriteQuery, self).__sql__(ctx)
        # We explicitly set the table alias to the table's name, which ensures
        # that if a sub-select references a column on the outer table, we won't
        # assign it a new alias (e.g. t2) but will refer to it as table.column.
        self._set_table_alias(ctx)
        return ctx


class Update(_WriteQuery):
    """UPDATE query: SET assignments, optional FROM and WHERE clauses."""

    def __init__(self, table, update=None, **kwargs):
        super(Update, self).__init__(table, **kwargs)
        self._update = update
        self._from = None

    @Node.copy
    def from_(self, *sources):
        """Set additional FROM sources (UPDATE ... FROM, where supported)."""
        self._from = sources

    def __sql__(self, ctx):
        super(Update, self).__sql__(ctx)

        with ctx.scope_values(subquery=True):
            ctx.literal('UPDATE ')

            expressions = []
            # Sort assignments for deterministic SQL output.
            for k, v in sorted(self._update.items(), key=ctx.column_sort_key):
                if not isinstance(v, Node):
                    # Apply the field's type conversion to plain values.
                    if isinstance(k, Field):
                        v = k.to_value(v)
                    else:
                        v = Value(v, unpack=False)
                if not isinstance(v, Value):
                    v = qualify_names(v)
                expressions.append(NodeList((k, SQL('='), v)))

            (ctx
             .sql(self.table)
             .literal(' SET ')
             .sql(CommaNodeList(expressions)))

            if self._from:
                with ctx.scope_source(parentheses=False):
                    ctx.literal(' FROM ').sql(CommaNodeList(self._from))

            if self._where:
                with ctx.scope_normal():
                    ctx.literal(' WHERE ').sql(self._where)
            self._apply_ordering(ctx)
            return self.apply_returning(ctx)


class Insert(_WriteQuery):
    """INSERT query: single row, bulk rows, or INSERT ... SELECT, with
    optional ON CONFLICT handling."""

    # Query-type markers, used to decide how to interpret the cursor result.
    SIMPLE = 0
    QUERY = 1
    MULTI = 2
    class DefaultValuesException(Exception): pass

    def __init__(self, table, insert=None, columns=None, on_conflict=None,
                 **kwargs):
        super(Insert, self).__init__(table, **kwargs)
        self._insert = insert
        self._columns = columns
        self._on_conflict = on_conflict
        self._query_type = None

    def where(self, *expressions):
        raise NotImplementedError('INSERT queries cannot have a WHERE clause.')

    @Node.copy
    def on_conflict_ignore(self, ignore=True):
        """Shorthand for ON CONFLICT ... IGNORE."""
        self._on_conflict = OnConflict('IGNORE') if ignore else None

    @Node.copy
    def on_conflict_replace(self, replace=True):
        """Shorthand for ON CONFLICT ... REPLACE."""
        self._on_conflict = OnConflict('REPLACE') if replace else None

    @Node.copy
    def on_conflict(self, *args, **kwargs):
        """Attach (or clear) a fully-specified OnConflict resolution."""
        self._on_conflict = (OnConflict(*args, **kwargs) if (args or kwargs)
                             else None)

    def _simple_insert(self, ctx):
        if not self._insert:
            raise self.DefaultValuesException('Error: no data to insert.')
        return self._generate_insert((self._insert,), ctx)

    def get_default_data(self):
        # Hook for subclasses (e.g. ModelInsert) to supply column defaults.
        return {}

    def get_default_columns(self):
        # All declared columns except the primary key.
        if self.table._columns:
            return [getattr(self.table, col) for col in self.table._columns
                    if col != self.table._primary_key]

    def _generate_insert(self, insert, ctx):
        rows_iter = iter(insert)
        columns = self._columns

        # Load and organize column defaults (if provided).
        defaults = self.get_default_data()

        # First figure out what columns are being inserted (if they weren't
        # specified explicitly). Resulting columns are normalized and ordered.
        if not columns:
            try:
                row = next(rows_iter)
            except StopIteration:
                raise self.DefaultValuesException('Error: no rows to insert.')

            if not isinstance(row, Mapping):
                columns = self.get_default_columns()
                if columns is None:
                    raise ValueError('Bulk insert must specify columns.')
            else:
                # Infer column names from the dict of data being inserted.
                accum = []
                for column in row:
                    if isinstance(column, basestring):
                        column = getattr(self.table, column)
                    accum.append(column)

                # Add any columns present in the default data that are not
                # accounted for by the dictionary of row data.
                column_set = set(accum)
                for col in (set(defaults) - column_set):
                    accum.append(col)

                columns = sorted(accum, key=lambda obj: obj.get_sort_key(ctx))
                # Put the consumed first row back in front of the iterator.
                rows_iter = itertools.chain(iter((row,)), rows_iter)
        else:
            # Normalize string column names to column objects, de-duplicated.
            clean_columns = []
            seen = set()
            for column in columns:
                if isinstance(column, basestring):
                    column_obj = getattr(self.table, column)
                else:
                    column_obj = column
                clean_columns.append(column_obj)
                seen.add(column_obj)

            columns = clean_columns
            for col in sorted(defaults, key=lambda obj: obj.get_sort_key(ctx)):
                if col not in seen:
                    columns.append(col)

        # Accept row-dicts keyed by column object, field name, or the
        # underlying database column name.
        value_lookups = {}
        for column in columns:
            lookups = [column, column.name]
            if isinstance(column, Field) and column.name != column.column_name:
                lookups.append(column.column_name)
            value_lookups[column] = lookups

        ctx.sql(EnclosedNodeList(columns)).literal(' VALUES ')
        columns_converters = [
            (column, column.db_value if isinstance(column, Field) else None)
            for column in columns]

        all_values = []
        for row in rows_iter:
            values = []
            is_dict = isinstance(row, Mapping)
            for i, (column, converter) in enumerate(columns_converters):
                try:
                    if is_dict:
                        # The logic is a bit convoluted, but in order to be
                        # flexible in what we accept (dict keyed by
                        # column/field, field name, or underlying column name),
                        # we try accessing the row data dict using each
                        # possible key. If no match is found, throw an error.
                        for lookup in value_lookups[column]:
                            try:
                                val = row[lookup]
                            except KeyError: pass
                            else: break
                        else:
                            raise KeyError
                    else:
                        val = row[i]
                except (KeyError, IndexError):
                    if column in defaults:
                        val = defaults[column]
                        if callable_(val):
                            val = val()
                    else:
                        raise ValueError('Missing value for %s.' % column.name)

                if not isinstance(val, Node):
                    val = Value(val, converter=converter, unpack=False)
                values.append(val)

            all_values.append(EnclosedNodeList(values))

        if not all_values:
            raise self.DefaultValuesException('Error: no data to insert.')

        with ctx.scope_values(subquery=True):
            return ctx.sql(CommaNodeList(all_values))

    def _query_insert(self, ctx):
        # INSERT INTO tbl (cols) <select-query>.
        return (ctx
                .sql(EnclosedNodeList(self._columns))
                .literal(' ')
                .sql(self._insert))

    def _default_values(self, ctx):
        if not self._database:
            return ctx.literal('DEFAULT VALUES')
        return self._database.default_values_insert(ctx)

    def __sql__(self, ctx):
        super(Insert, self).__sql__(ctx)
        with ctx.scope_values():
            stmt = None
            if self._on_conflict is not None:
                stmt = self._on_conflict.get_conflict_statement(ctx, self)

            (ctx
             .sql(stmt or SQL('INSERT'))
             .literal(' INTO ')
             .sql(self.table)
             .literal(' '))

            if isinstance(self._insert, Mapping) and not self._columns:
                try:
                    self._simple_insert(ctx)
                except self.DefaultValuesException:
                    self._default_values(ctx)
                self._query_type = Insert.SIMPLE
            elif isinstance(self._insert, (SelectQuery, SQL)):
                self._query_insert(ctx)
                self._query_type = Insert.QUERY
            else:
                self._generate_insert(self._insert, ctx)
                self._query_type = Insert.MULTI

            if self._on_conflict is not None:
                update = self._on_conflict.get_conflict_update(ctx, self)
                if update is not None:
                    ctx.literal(' ').sql(update)

            return self.apply_returning(ctx)

    def _execute(self, database):
        # Databases with RETURNING implicitly return the new primary key.
        if self._returning is None and database.returning_clause \
           and self.table._primary_key:
            self._returning = (self.table._primary_key,)
        try:
            return super(Insert, self)._execute(database)
        except self.DefaultValuesException:
            pass

    def handle_result(self, database, cursor):
        if self._return_cursor:
            return cursor
        if self._query_type != Insert.SIMPLE and not self._returning:
            return database.rows_affected(cursor)
        return database.last_insert_id(cursor, self._query_type)


class Delete(_WriteQuery):
    """DELETE query with optional WHERE, ordering and RETURNING."""

    def __sql__(self, ctx):
        super(Delete, self).__sql__(ctx)

        with ctx.scope_values(subquery=True):
            ctx.literal('DELETE FROM ').sql(self.table)
            if self._where is not None:
                with ctx.scope_normal():
                    ctx.literal(' WHERE ').sql(self._where)

            self._apply_ordering(ctx)
            return self.apply_returning(ctx)


class Index(Node):
    """A CREATE INDEX statement over a table and list of expressions."""

    def __init__(self, name, table, expressions, unique=False, safe=False,
                 where=None, using=None):
        self._name = name
        self._table = Entity(table) if not isinstance(table, Table) else table
        self._expressions = expressions
        self._where = where
        self._unique = unique
        self._safe = safe
        self._using = using

    @Node.copy
    def safe(self, _safe=True):
        """Toggle IF NOT EXISTS."""
        self._safe = _safe

    @Node.copy
    def where(self, *expressions):
        """AND the given expressions onto the partial-index WHERE clause."""
        if self._where is not None:
            expressions = (self._where,) + expressions
        self._where = reduce(operator.and_, expressions)

    @Node.copy
    def using(self, _using=None):
        """Set the index method (e.g. USING GIN)."""
        self._using = _using

    def __sql__(self, ctx):
        statement = 'CREATE UNIQUE INDEX ' if self._unique else 'CREATE INDEX '
        with ctx.scope_values(subquery=True):
            ctx.literal(statement)
            if self._safe:
                ctx.literal('IF NOT EXISTS ')

            # Sqlite uses CREATE INDEX <schema>.<name> ON <table>, whereas
            # most others use: CREATE INDEX <name> ON <schema>.<table>.
            if ctx.state.index_schema_prefix and \
               isinstance(self._table, Table) and self._table._schema:
                index_name = Entity(self._table._schema, self._name)
                table_name = Entity(self._table.__name__)
            else:
                index_name = Entity(self._name)
                table_name = self._table

            (ctx
             .sql(index_name)
             .literal(' ON ')
             .sql(table_name)
             .literal(' '))
            if self._using is not None:
                ctx.literal('USING %s ' % self._using)

            ctx.sql(EnclosedNodeList([
                SQL(expr) if isinstance(expr, basestring) else expr
                for expr in self._expressions]))
            if self._where is not None:
                ctx.literal(' WHERE ').sql(self._where)

        return ctx


class ModelIndex(Index):
    """An Index derived from a model class and a list of its fields."""

    def __init__(self, model, fields, unique=False, safe=True, where=None,
                 using=None, name=None):
        self._model = model
        if name is None:
            name = self._generate_name_from_fields(model, fields)
        if using is None:
            # Allow special fields (e.g. full-text/JSON) to dictate the
            # index method.
            for field in fields:
                if isinstance(field, Field) and hasattr(field, 'index_type'):
                    using = field.index_type
        super(ModelIndex, self).__init__(
            name=name,
            table=model._meta.table,
            expressions=fields,
            unique=unique,
            safe=safe,
            where=where,
            using=using)

    def _generate_name_from_fields(self, model, fields):
        """Derive a deterministic index name from the model + field names."""
        accum = []
        for field in fields:
            if isinstance(field, basestring):
                # Expression strings: use the first token as the name part.
                accum.append(field.split()[0])
            else:
                if isinstance(field, Node) and not isinstance(field, Field):
                    field = field.unwrap()
                if isinstance(field, Field):
                    accum.append(field.column_name)

        if not accum:
            raise ValueError('Unable to generate a name for the index, please '
                             'explicitly specify a name.')

        # Fix: raw-string pattern -- '\w' in a plain string literal is an
        # invalid escape sequence (SyntaxWarning/error on newer CPython).
        clean_field_names = re.sub(r'[^\w]+', '', '_'.join(accum))
        meta = model._meta
        prefix = meta.name if meta.legacy_table_names else meta.table_name
        return _truncate_constraint_name('_'.join((prefix, clean_field_names)))


def _truncate_constraint_name(constraint, maxlen=64):
    """Truncate an over-long constraint name, keeping it unique by
    appending a short hash of the full original name."""
    if len(constraint) > maxlen:
        name_hash = hashlib.md5(constraint.encode('utf-8')).hexdigest()
        constraint = '%s_%s' % (constraint[:(maxlen - 8)], name_hash[:7])
    return constraint


# DB-API 2.0 EXCEPTIONS.


class PeeweeException(Exception):
    """Base exception; optionally wraps the original driver exception
    (stored on ``self.orig``)."""
    def __init__(self, *args):
        if args and isinstance(args[0], Exception):
            self.orig, args = args[0], args[1:]
        super(PeeweeException, self).__init__(*args)
class ImproperlyConfigured(PeeweeException): pass
class DatabaseError(PeeweeException): pass
class DataError(DatabaseError): pass
class IntegrityError(DatabaseError): pass
class InterfaceError(PeeweeException): pass
class InternalError(DatabaseError): pass
class NotSupportedError(DatabaseError): pass
class OperationalError(DatabaseError): pass
class ProgrammingError(DatabaseError): pass


class ExceptionWrapper(object):
    """Context manager translating driver exceptions to peewee's
    DB-API-style hierarchy, keyed by exception class name."""
    __slots__ = ('exceptions',)
    def __init__(self, exceptions):
        self.exceptions = exceptions
    def __enter__(self): pass
    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is None:
            return
        # psycopg2 >= 2.8 raises many fine-grained error subclasses; map any
        # unrecognized subclass of pg_errors.Error to its base class name.
        if pg_errors is not None and exc_type.__name__ not in self.exceptions \
           and issubclass(exc_type, pg_errors.Error):
            exc_type = exc_type.__bases__[0]
        if exc_type.__name__ in self.exceptions:
            new_type = self.exceptions[exc_type.__name__]
            exc_args = exc_value.args
            reraise(new_type, new_type(exc_value, *exc_args), traceback)


EXCEPTIONS = {
    'ConstraintError': IntegrityError,
    'DatabaseError': DatabaseError,
    'DataError': DataError,
    'IntegrityError': IntegrityError,
    'InterfaceError': InterfaceError,
    'InternalError': InternalError,
    'NotSupportedError': NotSupportedError,
    'OperationalError': OperationalError,
    'ProgrammingError': ProgrammingError,
    'TransactionRollbackError': OperationalError}

__exception_wrapper__ = ExceptionWrapper(EXCEPTIONS)


# DATABASE INTERFACE AND CONNECTION MANAGEMENT.


# Lightweight records used by database introspection methods.
IndexMetadata = collections.namedtuple(
    'IndexMetadata',
    ('name', 'sql', 'columns', 'unique', 'table'))
ColumnMetadata = collections.namedtuple(
    'ColumnMetadata',
    ('name', 'data_type', 'null', 'primary_key', 'table', 'default'))
ForeignKeyMetadata = collections.namedtuple(
    'ForeignKeyMetadata',
    ('column', 'dest_table', 'dest_column', 'table'))
ViewMetadata = collections.namedtuple('ViewMetadata', ('name', 'sql'))


class _ConnectionState(object):
    """Holds the per-database connection handle, open/closed flag, active
    context-manager stack and transaction stack."""

    def __init__(self, **kwargs):
        super(_ConnectionState, self).__init__(**kwargs)
        self.reset()

    def reset(self):
        # Return to the pristine "no connection" state.
        self.closed = True
        self.conn = None
        self.ctx = []
        self.transactions = []

    def set_connection(self, conn):
        # Adopt a freshly-opened connection, clearing any stale state.
        self.conn = conn
        self.closed = False
        self.ctx = []
        self.transactions = []


# Thread-local variant: each thread gets its own connection state.
class _ConnectionLocal(_ConnectionState, threading.local): pass
class _NoopLock(object):
    # Lock stand-in used when thread-safety is disabled.
    __slots__ = ()
    def __enter__(self): return self
    def __exit__(self, exc_type, exc_val, exc_tb): pass


class ConnectionContext(_callable_context_manager):
    """Context manager that opens the database connection on entry (if
    closed) and closes it on exit."""
    __slots__ = ('db',)
    def __init__(self, db): self.db = db
    def __enter__(self):
        if self.db.is_closed():
            self.db.connect()
    def __exit__(self, exc_type, exc_val, exc_tb): self.db.close()


class Database(_callable_context_manager):
    # SQL-generation defaults; subclasses override per back-end.
    context_class = Context
    field_types = {}
    operations = {}
    param = '?'
    quote = '""'
    server_version = None

    # Feature toggles.
+ commit_select = False + compound_select_parentheses = CSQ_PARENTHESES_NEVER + for_update = False + index_schema_prefix = False + limit_max = None + nulls_ordering = False + returning_clause = False + safe_create_index = True + safe_drop_index = True + sequences = False + truncate_table = True + + def __init__(self, database, thread_safe=True, autorollback=False, + field_types=None, operations=None, autocommit=None, + autoconnect=True, **kwargs): + self._field_types = merge_dict(FIELD, self.field_types) + self._operations = merge_dict(OP, self.operations) + if field_types: + self._field_types.update(field_types) + if operations: + self._operations.update(operations) + + self.autoconnect = autoconnect + self.autorollback = autorollback + self.thread_safe = thread_safe + if thread_safe: + self._state = _ConnectionLocal() + self._lock = threading.Lock() + else: + self._state = _ConnectionState() + self._lock = _NoopLock() + + if autocommit is not None: + __deprecated__('Peewee no longer uses the "autocommit" option, as ' + 'the semantics now require it to always be True. 
' + 'Because some database-drivers also use the ' + '"autocommit" parameter, you are receiving a ' + 'warning so you may update your code and remove ' + 'the parameter, as in the future, specifying ' + 'autocommit could impact the behavior of the ' + 'database driver you are using.') + + self.connect_params = {} + self.init(database, **kwargs) + + def init(self, database, **kwargs): + if not self.is_closed(): + self.close() + self.database = database + self.connect_params.update(kwargs) + self.deferred = not bool(database) + + def __enter__(self): + if self.is_closed(): + self.connect() + ctx = self.atomic() + self._state.ctx.append(ctx) + ctx.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + ctx = self._state.ctx.pop() + try: + ctx.__exit__(exc_type, exc_val, exc_tb) + finally: + if not self._state.ctx: + self.close() + + def connection_context(self): + return ConnectionContext(self) + + def _connect(self): + raise NotImplementedError + + def connect(self, reuse_if_open=False): + with self._lock: + if self.deferred: + raise InterfaceError('Error, database must be initialized ' + 'before opening a connection.') + if not self._state.closed: + if reuse_if_open: + return False + raise OperationalError('Connection already opened.') + + self._state.reset() + with __exception_wrapper__: + self._state.set_connection(self._connect()) + if self.server_version is None: + self._set_server_version(self._state.conn) + self._initialize_connection(self._state.conn) + return True + + def _initialize_connection(self, conn): + pass + + def _set_server_version(self, conn): + self.server_version = 0 + + def close(self): + with self._lock: + if self.deferred: + raise InterfaceError('Error, database must be initialized ' + 'before opening a connection.') + if self.in_transaction(): + raise OperationalError('Attempting to close database while ' + 'transaction is open.') + is_open = not self._state.closed + try: + if is_open: + with __exception_wrapper__: + 
self._close(self._state.conn) + finally: + self._state.reset() + return is_open + + def _close(self, conn): + conn.close() + + def is_closed(self): + return self._state.closed + + def is_connection_usable(self): + return not self._state.closed + + def connection(self): + if self.is_closed(): + self.connect() + return self._state.conn + + def cursor(self, commit=None): + if self.is_closed(): + if self.autoconnect: + self.connect() + else: + raise InterfaceError('Error, database connection not opened.') + return self._state.conn.cursor() + + def execute_sql(self, sql, params=None, commit=SENTINEL): + logger.debug((sql, params)) + if commit is SENTINEL: + if self.in_transaction(): + commit = False + elif self.commit_select: + commit = True + else: + commit = not sql[:6].lower().startswith('select') + + with __exception_wrapper__: + cursor = self.cursor(commit) + try: + cursor.execute(sql, params or ()) + except Exception: + if self.autorollback and not self.in_transaction(): + self.rollback() + raise + else: + if commit and not self.in_transaction(): + self.commit() + return cursor + + def execute(self, query, commit=SENTINEL, **context_options): + ctx = self.get_sql_context(**context_options) + sql, params = ctx.sql(query).query() + return self.execute_sql(sql, params, commit=commit) + + def get_context_options(self): + return { + 'field_types': self._field_types, + 'operations': self._operations, + 'param': self.param, + 'quote': self.quote, + 'compound_select_parentheses': self.compound_select_parentheses, + 'conflict_statement': self.conflict_statement, + 'conflict_update': self.conflict_update, + 'for_update': self.for_update, + 'index_schema_prefix': self.index_schema_prefix, + 'limit_max': self.limit_max, + 'nulls_ordering': self.nulls_ordering, + } + + def get_sql_context(self, **context_options): + context = self.get_context_options() + if context_options: + context.update(context_options) + return self.context_class(**context) + + def 
conflict_statement(self, on_conflict, query): + raise NotImplementedError + + def conflict_update(self, on_conflict, query): + raise NotImplementedError + + def _build_on_conflict_update(self, on_conflict, query): + if on_conflict._conflict_target: + stmt = SQL('ON CONFLICT') + target = EnclosedNodeList([ + Entity(col) if isinstance(col, basestring) else col + for col in on_conflict._conflict_target]) + if on_conflict._conflict_where is not None: + target = NodeList([target, SQL('WHERE'), + on_conflict._conflict_where]) + else: + stmt = SQL('ON CONFLICT ON CONSTRAINT') + target = on_conflict._conflict_constraint + if isinstance(target, basestring): + target = Entity(target) + + updates = [] + if on_conflict._preserve: + for column in on_conflict._preserve: + excluded = NodeList((SQL('EXCLUDED'), ensure_entity(column)), + glue='.') + expression = NodeList((ensure_entity(column), SQL('='), + excluded)) + updates.append(expression) + + if on_conflict._update: + for k, v in on_conflict._update.items(): + if not isinstance(v, Node): + # Attempt to resolve string field-names to their respective + # field object, to apply data-type conversions. 
+ if isinstance(k, basestring): + k = getattr(query.table, k) + if isinstance(k, Field): + v = k.to_value(v) + else: + v = Value(v, unpack=False) + else: + v = QualifiedNames(v) + updates.append(NodeList((ensure_entity(k), SQL('='), v))) + + parts = [stmt, target, SQL('DO UPDATE SET'), CommaNodeList(updates)] + if on_conflict._where: + parts.extend((SQL('WHERE'), QualifiedNames(on_conflict._where))) + + return NodeList(parts) + + def last_insert_id(self, cursor, query_type=None): + return cursor.lastrowid + + def rows_affected(self, cursor): + return cursor.rowcount + + def default_values_insert(self, ctx): + return ctx.literal('DEFAULT VALUES') + + def session_start(self): + with self._lock: + return self.transaction().__enter__() + + def session_commit(self): + with self._lock: + try: + txn = self.pop_transaction() + except IndexError: + return False + txn.commit(begin=self.in_transaction()) + return True + + def session_rollback(self): + with self._lock: + try: + txn = self.pop_transaction() + except IndexError: + return False + txn.rollback(begin=self.in_transaction()) + return True + + def in_transaction(self): + return bool(self._state.transactions) + + def push_transaction(self, transaction): + self._state.transactions.append(transaction) + + def pop_transaction(self): + return self._state.transactions.pop() + + def transaction_depth(self): + return len(self._state.transactions) + + def top_transaction(self): + if self._state.transactions: + return self._state.transactions[-1] + + def atomic(self, *args, **kwargs): + return _atomic(self, *args, **kwargs) + + def manual_commit(self): + return _manual(self) + + def transaction(self, *args, **kwargs): + return _transaction(self, *args, **kwargs) + + def savepoint(self): + return _savepoint(self) + + def begin(self): + if self.is_closed(): + self.connect() + + def commit(self): + with __exception_wrapper__: + return self._state.conn.commit() + + def rollback(self): + with __exception_wrapper__: + return 
self._state.conn.rollback() + + def batch_commit(self, it, n): + for group in chunked(it, n): + with self.atomic(): + for obj in group: + yield obj + + def table_exists(self, table_name, schema=None): + return table_name in self.get_tables(schema=schema) + + def get_tables(self, schema=None): + raise NotImplementedError + + def get_indexes(self, table, schema=None): + raise NotImplementedError + + def get_columns(self, table, schema=None): + raise NotImplementedError + + def get_primary_keys(self, table, schema=None): + raise NotImplementedError + + def get_foreign_keys(self, table, schema=None): + raise NotImplementedError + + def sequence_exists(self, seq): + raise NotImplementedError + + def create_tables(self, models, **options): + for model in sort_models(models): + model.create_table(**options) + + def drop_tables(self, models, **kwargs): + for model in reversed(sort_models(models)): + model.drop_table(**kwargs) + + def extract_date(self, date_part, date_field): + raise NotImplementedError + + def truncate_date(self, date_part, date_field): + raise NotImplementedError + + def to_timestamp(self, date_field): + raise NotImplementedError + + def from_timestamp(self, date_field): + raise NotImplementedError + + def random(self): + return fn.random() + + def bind(self, models, bind_refs=True, bind_backrefs=True): + for model in models: + model.bind(self, bind_refs=bind_refs, bind_backrefs=bind_backrefs) + + def bind_ctx(self, models, bind_refs=True, bind_backrefs=True): + return _BoundModelsContext(models, self, bind_refs, bind_backrefs) + + def get_noop_select(self, ctx): + return ctx.sql(Select().columns(SQL('0')).where(SQL('0'))) + + +def __pragma__(name): + def __get__(self): + return self.pragma(name) + def __set__(self, value): + return self.pragma(name, value) + return property(__get__, __set__) + + +class SqliteDatabase(Database): + field_types = { + 'BIGAUTO': FIELD.AUTO, + 'BIGINT': FIELD.INT, + 'BOOL': FIELD.INT, + 'DOUBLE': FIELD.FLOAT, + 'SMALLINT': 
FIELD.INT, + 'UUID': FIELD.TEXT} + operations = { + 'LIKE': 'GLOB', + 'ILIKE': 'LIKE'} + index_schema_prefix = True + limit_max = -1 + server_version = __sqlite_version__ + truncate_table = False + + def __init__(self, database, *args, **kwargs): + self._pragmas = kwargs.pop('pragmas', ()) + super(SqliteDatabase, self).__init__(database, *args, **kwargs) + self._aggregates = {} + self._collations = {} + self._functions = {} + self._window_functions = {} + self._table_functions = [] + self._extensions = set() + self._attached = {} + self.register_function(_sqlite_date_part, 'date_part', 2) + self.register_function(_sqlite_date_trunc, 'date_trunc', 2) + self.nulls_ordering = self.server_version >= (3, 30, 0) + + def init(self, database, pragmas=None, timeout=5, **kwargs): + if pragmas is not None: + self._pragmas = pragmas + if isinstance(self._pragmas, dict): + self._pragmas = list(self._pragmas.items()) + self._timeout = timeout + super(SqliteDatabase, self).init(database, **kwargs) + + def _set_server_version(self, conn): + pass + + def _connect(self): + if sqlite3 is None: + raise ImproperlyConfigured('SQLite driver not installed!') + conn = sqlite3.connect(self.database, timeout=self._timeout, + isolation_level=None, **self.connect_params) + try: + self._add_conn_hooks(conn) + except: + conn.close() + raise + return conn + + def _add_conn_hooks(self, conn): + if self._attached: + self._attach_databases(conn) + if self._pragmas: + self._set_pragmas(conn) + self._load_aggregates(conn) + self._load_collations(conn) + self._load_functions(conn) + if self.server_version >= (3, 25, 0): + self._load_window_functions(conn) + if self._table_functions: + for table_function in self._table_functions: + table_function.register(conn) + if self._extensions: + self._load_extensions(conn) + + def _set_pragmas(self, conn): + cursor = conn.cursor() + for pragma, value in self._pragmas: + cursor.execute('PRAGMA %s = %s;' % (pragma, value)) + cursor.close() + + def 
_attach_databases(self, conn): + cursor = conn.cursor() + for name, db in self._attached.items(): + cursor.execute('ATTACH DATABASE "%s" AS "%s"' % (db, name)) + cursor.close() + + def pragma(self, key, value=SENTINEL, permanent=False, schema=None): + if schema is not None: + key = '"%s".%s' % (schema, key) + sql = 'PRAGMA %s' % key + if value is not SENTINEL: + sql += ' = %s' % (value or 0) + if permanent: + pragmas = dict(self._pragmas or ()) + pragmas[key] = value + self._pragmas = list(pragmas.items()) + elif permanent: + raise ValueError('Cannot specify a permanent pragma without value') + row = self.execute_sql(sql).fetchone() + if row: + return row[0] + + cache_size = __pragma__('cache_size') + foreign_keys = __pragma__('foreign_keys') + journal_mode = __pragma__('journal_mode') + journal_size_limit = __pragma__('journal_size_limit') + mmap_size = __pragma__('mmap_size') + page_size = __pragma__('page_size') + read_uncommitted = __pragma__('read_uncommitted') + synchronous = __pragma__('synchronous') + wal_autocheckpoint = __pragma__('wal_autocheckpoint') + + @property + def timeout(self): + return self._timeout + + @timeout.setter + def timeout(self, seconds): + if self._timeout == seconds: + return + + self._timeout = seconds + if not self.is_closed(): + # PySQLite multiplies user timeout by 1000, but the unit of the + # timeout PRAGMA is actually milliseconds. 
+ self.execute_sql('PRAGMA busy_timeout=%d;' % (seconds * 1000)) + + def _load_aggregates(self, conn): + for name, (klass, num_params) in self._aggregates.items(): + conn.create_aggregate(name, num_params, klass) + + def _load_collations(self, conn): + for name, fn in self._collations.items(): + conn.create_collation(name, fn) + + def _load_functions(self, conn): + for name, (fn, num_params) in self._functions.items(): + conn.create_function(name, num_params, fn) + + def _load_window_functions(self, conn): + for name, (klass, num_params) in self._window_functions.items(): + conn.create_window_function(name, num_params, klass) + + def register_aggregate(self, klass, name=None, num_params=-1): + self._aggregates[name or klass.__name__.lower()] = (klass, num_params) + if not self.is_closed(): + self._load_aggregates(self.connection()) + + def aggregate(self, name=None, num_params=-1): + def decorator(klass): + self.register_aggregate(klass, name, num_params) + return klass + return decorator + + def register_collation(self, fn, name=None): + name = name or fn.__name__ + def _collation(*args): + expressions = args + (SQL('collate %s' % name),) + return NodeList(expressions) + fn.collation = _collation + self._collations[name] = fn + if not self.is_closed(): + self._load_collations(self.connection()) + + def collation(self, name=None): + def decorator(fn): + self.register_collation(fn, name) + return fn + return decorator + + def register_function(self, fn, name=None, num_params=-1): + self._functions[name or fn.__name__] = (fn, num_params) + if not self.is_closed(): + self._load_functions(self.connection()) + + def func(self, name=None, num_params=-1): + def decorator(fn): + self.register_function(fn, name, num_params) + return fn + return decorator + + def register_window_function(self, klass, name=None, num_params=-1): + name = name or klass.__name__.lower() + self._window_functions[name] = (klass, num_params) + if not self.is_closed(): + 
self._load_window_functions(self.connection()) + + def window_function(self, name=None, num_params=-1): + def decorator(klass): + self.register_window_function(klass, name, num_params) + return klass + return decorator + + def register_table_function(self, klass, name=None): + if name is not None: + klass.name = name + self._table_functions.append(klass) + if not self.is_closed(): + klass.register(self.connection()) + + def table_function(self, name=None): + def decorator(klass): + self.register_table_function(klass, name) + return klass + return decorator + + def unregister_aggregate(self, name): + del(self._aggregates[name]) + + def unregister_collation(self, name): + del(self._collations[name]) + + def unregister_function(self, name): + del(self._functions[name]) + + def unregister_window_function(self, name): + del(self._window_functions[name]) + + def unregister_table_function(self, name): + for idx, klass in enumerate(self._table_functions): + if klass.name == name: + break + else: + return False + self._table_functions.pop(idx) + return True + + def _load_extensions(self, conn): + conn.enable_load_extension(True) + for extension in self._extensions: + conn.load_extension(extension) + + def load_extension(self, extension): + self._extensions.add(extension) + if not self.is_closed(): + conn = self.connection() + conn.enable_load_extension(True) + conn.load_extension(extension) + + def unload_extension(self, extension): + self._extensions.remove(extension) + + def attach(self, filename, name): + if name in self._attached: + if self._attached[name] == filename: + return False + raise OperationalError('schema "%s" already attached.' 
% name) + + self._attached[name] = filename + if not self.is_closed(): + self.execute_sql('ATTACH DATABASE "%s" AS "%s"' % (filename, name)) + return True + + def detach(self, name): + if name not in self._attached: + return False + + del self._attached[name] + if not self.is_closed(): + self.execute_sql('DETACH DATABASE "%s"' % name) + return True + + def begin(self, lock_type=None): + statement = 'BEGIN %s' % lock_type if lock_type else 'BEGIN' + self.execute_sql(statement, commit=False) + + def get_tables(self, schema=None): + schema = schema or 'main' + cursor = self.execute_sql('SELECT name FROM "%s".sqlite_master WHERE ' + 'type=? ORDER BY name' % schema, ('table',)) + return [row for row, in cursor.fetchall()] + + def get_views(self, schema=None): + sql = ('SELECT name, sql FROM "%s".sqlite_master WHERE type=? ' + 'ORDER BY name') % (schema or 'main') + return [ViewMetadata(*row) for row in self.execute_sql(sql, ('view',))] + + def get_indexes(self, table, schema=None): + schema = schema or 'main' + query = ('SELECT name, sql FROM "%s".sqlite_master ' + 'WHERE tbl_name = ? AND type = ? ORDER BY name') % schema + cursor = self.execute_sql(query, (table, 'index')) + index_to_sql = dict(cursor.fetchall()) + + # Determine which indexes have a unique constraint. + unique_indexes = set() + cursor = self.execute_sql('PRAGMA "%s".index_list("%s")' % + (schema, table)) + for row in cursor.fetchall(): + name = row[1] + is_unique = int(row[2]) == 1 + if is_unique: + unique_indexes.add(name) + + # Retrieve the indexed columns. 
+ index_columns = {} + for index_name in sorted(index_to_sql): + cursor = self.execute_sql('PRAGMA "%s".index_info("%s")' % + (schema, index_name)) + index_columns[index_name] = [row[2] for row in cursor.fetchall()] + + return [ + IndexMetadata( + name, + index_to_sql[name], + index_columns[name], + name in unique_indexes, + table) + for name in sorted(index_to_sql)] + + def get_columns(self, table, schema=None): + cursor = self.execute_sql('PRAGMA "%s".table_info("%s")' % + (schema or 'main', table)) + return [ColumnMetadata(r[1], r[2], not r[3], bool(r[5]), table, r[4]) + for r in cursor.fetchall()] + + def get_primary_keys(self, table, schema=None): + cursor = self.execute_sql('PRAGMA "%s".table_info("%s")' % + (schema or 'main', table)) + return [row[1] for row in filter(lambda r: r[-1], cursor.fetchall())] + + def get_foreign_keys(self, table, schema=None): + cursor = self.execute_sql('PRAGMA "%s".foreign_key_list("%s")' % + (schema or 'main', table)) + return [ForeignKeyMetadata(row[3], row[2], row[4], table) + for row in cursor.fetchall()] + + def get_binary_type(self): + return sqlite3.Binary + + def conflict_statement(self, on_conflict, query): + action = on_conflict._action.lower() if on_conflict._action else '' + if action and action not in ('nothing', 'update'): + return SQL('INSERT OR %s' % on_conflict._action.upper()) + + def conflict_update(self, oc, query): + # Sqlite prior to 3.24.0 does not support Postgres-style upsert. 
+ if self.server_version < (3, 24, 0) and \ + any((oc._preserve, oc._update, oc._where, oc._conflict_target, + oc._conflict_constraint)): + raise ValueError('SQLite does not support specifying which values ' + 'to preserve or update.') + + action = oc._action.lower() if oc._action else '' + if action and action not in ('nothing', 'update', ''): + return + + if action == 'nothing': + return SQL('ON CONFLICT DO NOTHING') + elif not oc._update and not oc._preserve: + raise ValueError('If you are not performing any updates (or ' + 'preserving any INSERTed values), then the ' + 'conflict resolution action should be set to ' + '"NOTHING".') + elif oc._conflict_constraint: + raise ValueError('SQLite does not support specifying named ' + 'constraints for conflict resolution.') + elif not oc._conflict_target: + raise ValueError('SQLite requires that a conflict target be ' + 'specified when doing an upsert.') + + return self._build_on_conflict_update(oc, query) + + def extract_date(self, date_part, date_field): + return fn.date_part(date_part, date_field, python_value=int) + + def truncate_date(self, date_part, date_field): + return fn.date_trunc(date_part, date_field, + python_value=simple_date_time) + + def to_timestamp(self, date_field): + return fn.strftime('%s', date_field).cast('integer') + + def from_timestamp(self, date_field): + return fn.datetime(date_field, 'unixepoch') + + +class PostgresqlDatabase(Database): + field_types = { + 'AUTO': 'SERIAL', + 'BIGAUTO': 'BIGSERIAL', + 'BLOB': 'BYTEA', + 'BOOL': 'BOOLEAN', + 'DATETIME': 'TIMESTAMP', + 'DECIMAL': 'NUMERIC', + 'DOUBLE': 'DOUBLE PRECISION', + 'UUID': 'UUID', + 'UUIDB': 'BYTEA'} + operations = {'REGEXP': '~', 'IREGEXP': '~*'} + param = '%s' + + commit_select = True + compound_select_parentheses = CSQ_PARENTHESES_ALWAYS + for_update = True + nulls_ordering = True + returning_clause = True + safe_create_index = False + sequences = True + + def init(self, database, register_unicode=True, encoding=None, + 
isolation_level=None, **kwargs): + self._register_unicode = register_unicode + self._encoding = encoding + self._isolation_level = isolation_level + super(PostgresqlDatabase, self).init(database, **kwargs) + + def _connect(self): + if psycopg2 is None: + raise ImproperlyConfigured('Postgres driver not installed!') + conn = psycopg2.connect(database=self.database, **self.connect_params) + if self._register_unicode: + pg_extensions.register_type(pg_extensions.UNICODE, conn) + pg_extensions.register_type(pg_extensions.UNICODEARRAY, conn) + if self._encoding: + conn.set_client_encoding(self._encoding) + if self._isolation_level: + conn.set_isolation_level(self._isolation_level) + return conn + + def _set_server_version(self, conn): + self.server_version = conn.server_version + if self.server_version >= 90600: + self.safe_create_index = True + + def is_connection_usable(self): + if self._state.closed: + return False + + # Returns True if we are idle, running a command, or in an active + # connection. If the connection is in an error state or the connection + # is otherwise unusable, return False. 
+ txn_status = self._state.conn.get_transaction_status() + return txn_status < pg_extensions.TRANSACTION_STATUS_INERROR + + def last_insert_id(self, cursor, query_type=None): + try: + return cursor if query_type != Insert.SIMPLE else cursor[0][0] + except (IndexError, KeyError, TypeError): + pass + + def get_tables(self, schema=None): + query = ('SELECT tablename FROM pg_catalog.pg_tables ' + 'WHERE schemaname = %s ORDER BY tablename') + cursor = self.execute_sql(query, (schema or 'public',)) + return [table for table, in cursor.fetchall()] + + def get_views(self, schema=None): + query = ('SELECT viewname, definition FROM pg_catalog.pg_views ' + 'WHERE schemaname = %s ORDER BY viewname') + cursor = self.execute_sql(query, (schema or 'public',)) + return [ViewMetadata(view_name, sql.strip(' \t;')) + for (view_name, sql) in cursor.fetchall()] + + def get_indexes(self, table, schema=None): + query = """ + SELECT + i.relname, idxs.indexdef, idx.indisunique, + array_to_string(array_agg(cols.attname), ',') + FROM pg_catalog.pg_class AS t + INNER JOIN pg_catalog.pg_index AS idx ON t.oid = idx.indrelid + INNER JOIN pg_catalog.pg_class AS i ON idx.indexrelid = i.oid + INNER JOIN pg_catalog.pg_indexes AS idxs ON + (idxs.tablename = t.relname AND idxs.indexname = i.relname) + LEFT OUTER JOIN pg_catalog.pg_attribute AS cols ON + (cols.attrelid = t.oid AND cols.attnum = ANY(idx.indkey)) + WHERE t.relname = %s AND t.relkind = %s AND idxs.schemaname = %s + GROUP BY i.relname, idxs.indexdef, idx.indisunique + ORDER BY idx.indisunique DESC, i.relname;""" + cursor = self.execute_sql(query, (table, 'r', schema or 'public')) + return [IndexMetadata(name, sql.rstrip(' ;'), columns.split(','), + is_unique, table) + for name, sql, is_unique, columns in cursor.fetchall()] + + def get_columns(self, table, schema=None): + query = """ + SELECT column_name, is_nullable, data_type, column_default + FROM information_schema.columns + WHERE table_name = %s AND table_schema = %s + ORDER BY 
ordinal_position""" + cursor = self.execute_sql(query, (table, schema or 'public')) + pks = set(self.get_primary_keys(table, schema)) + return [ColumnMetadata(name, dt, null == 'YES', name in pks, table, df) + for name, null, dt, df in cursor.fetchall()] + + def get_primary_keys(self, table, schema=None): + query = """ + SELECT kc.column_name + FROM information_schema.table_constraints AS tc + INNER JOIN information_schema.key_column_usage AS kc ON ( + tc.table_name = kc.table_name AND + tc.table_schema = kc.table_schema AND + tc.constraint_name = kc.constraint_name) + WHERE + tc.constraint_type = %s AND + tc.table_name = %s AND + tc.table_schema = %s""" + ctype = 'PRIMARY KEY' + cursor = self.execute_sql(query, (ctype, table, schema or 'public')) + return [pk for pk, in cursor.fetchall()] + + def get_foreign_keys(self, table, schema=None): + sql = """ + SELECT DISTINCT + kcu.column_name, ccu.table_name, ccu.column_name + FROM information_schema.table_constraints AS tc + JOIN information_schema.key_column_usage AS kcu + ON (tc.constraint_name = kcu.constraint_name AND + tc.constraint_schema = kcu.constraint_schema) + JOIN information_schema.constraint_column_usage AS ccu + ON (ccu.constraint_name = tc.constraint_name AND + ccu.constraint_schema = tc.constraint_schema) + WHERE + tc.constraint_type = 'FOREIGN KEY' AND + tc.table_name = %s AND + tc.table_schema = %s""" + cursor = self.execute_sql(sql, (table, schema or 'public')) + return [ForeignKeyMetadata(row[0], row[1], row[2], table) + for row in cursor.fetchall()] + + def sequence_exists(self, sequence): + res = self.execute_sql(""" + SELECT COUNT(*) FROM pg_class, pg_namespace + WHERE relkind='S' + AND pg_class.relnamespace = pg_namespace.oid + AND relname=%s""", (sequence,)) + return bool(res.fetchone()[0]) + + def get_binary_type(self): + return psycopg2.Binary + + def conflict_statement(self, on_conflict, query): + return + + def conflict_update(self, oc, query): + action = oc._action.lower() if oc._action 
else '' + if action in ('ignore', 'nothing'): + return SQL('ON CONFLICT DO NOTHING') + elif action and action != 'update': + raise ValueError('The only supported actions for conflict ' + 'resolution with Postgresql are "ignore" or ' + '"update".') + elif not oc._update and not oc._preserve: + raise ValueError('If you are not performing any updates (or ' + 'preserving any INSERTed values), then the ' + 'conflict resolution action should be set to ' + '"IGNORE".') + elif not (oc._conflict_target or oc._conflict_constraint): + raise ValueError('Postgres requires that a conflict target be ' + 'specified when doing an upsert.') + + return self._build_on_conflict_update(oc, query) + + def extract_date(self, date_part, date_field): + return fn.EXTRACT(NodeList((date_part, SQL('FROM'), date_field))) + + def truncate_date(self, date_part, date_field): + return fn.DATE_TRUNC(date_part, date_field) + + def to_timestamp(self, date_field): + return self.extract_date('EPOCH', date_field) + + def from_timestamp(self, date_field): + # Ironically, here, Postgres means "to the Postgresql timestamp type". 
+ return fn.to_timestamp(date_field) + + def get_noop_select(self, ctx): + return ctx.sql(Select().columns(SQL('0')).where(SQL('false'))) + + def set_time_zone(self, timezone): + self.execute_sql('set time zone "%s";' % timezone) + + +class MySQLDatabase(Database): + field_types = { + 'AUTO': 'INTEGER AUTO_INCREMENT', + 'BIGAUTO': 'BIGINT AUTO_INCREMENT', + 'BOOL': 'BOOL', + 'DECIMAL': 'NUMERIC', + 'DOUBLE': 'DOUBLE PRECISION', + 'FLOAT': 'FLOAT', + 'UUID': 'VARCHAR(40)', + 'UUIDB': 'VARBINARY(16)'} + operations = { + 'LIKE': 'LIKE BINARY', + 'ILIKE': 'LIKE', + 'REGEXP': 'REGEXP BINARY', + 'IREGEXP': 'REGEXP', + 'XOR': 'XOR'} + param = '%s' + quote = '``' + + commit_select = True + compound_select_parentheses = CSQ_PARENTHESES_UNNESTED + for_update = True + limit_max = 2 ** 64 - 1 + safe_create_index = False + safe_drop_index = False + sql_mode = 'PIPES_AS_CONCAT' + + def init(self, database, **kwargs): + params = { + 'charset': 'utf8', + 'sql_mode': self.sql_mode, + 'use_unicode': True} + params.update(kwargs) + if 'password' in params and mysql_passwd: + params['passwd'] = params.pop('password') + super(MySQLDatabase, self).init(database, **params) + + def _connect(self): + if mysql is None: + raise ImproperlyConfigured('MySQL driver not installed!') + conn = mysql.connect(db=self.database, **self.connect_params) + return conn + + def _set_server_version(self, conn): + try: + version_raw = conn.server_version + except AttributeError: + version_raw = conn.get_server_info() + self.server_version = self._extract_server_version(version_raw) + + def _extract_server_version(self, version): + version = version.lower() + if 'maria' in version: + match_obj = re.search(r'(1\d\.\d+\.\d+)', version) + else: + match_obj = re.search(r'(\d\.\d+\.\d+)', version) + if match_obj is not None: + return tuple(int(num) for num in match_obj.groups()[0].split('.')) + + warnings.warn('Unable to determine MySQL version: "%s"' % version) + return (0, 0, 0) # Unable to determine version! 
+ + def default_values_insert(self, ctx): + return ctx.literal('() VALUES ()') + + def get_tables(self, schema=None): + query = ('SELECT table_name FROM information_schema.tables ' + 'WHERE table_schema = DATABASE() AND table_type != %s ' + 'ORDER BY table_name') + return [table for table, in self.execute_sql(query, ('VIEW',))] + + def get_views(self, schema=None): + query = ('SELECT table_name, view_definition ' + 'FROM information_schema.views ' + 'WHERE table_schema = DATABASE() ORDER BY table_name') + cursor = self.execute_sql(query) + return [ViewMetadata(*row) for row in cursor.fetchall()] + + def get_indexes(self, table, schema=None): + cursor = self.execute_sql('SHOW INDEX FROM `%s`' % table) + unique = set() + indexes = {} + for row in cursor.fetchall(): + if not row[1]: + unique.add(row[2]) + indexes.setdefault(row[2], []) + indexes[row[2]].append(row[4]) + return [IndexMetadata(name, None, indexes[name], name in unique, table) + for name in indexes] + + def get_columns(self, table, schema=None): + sql = """ + SELECT column_name, is_nullable, data_type, column_default + FROM information_schema.columns + WHERE table_name = %s AND table_schema = DATABASE()""" + cursor = self.execute_sql(sql, (table,)) + pks = set(self.get_primary_keys(table)) + return [ColumnMetadata(name, dt, null == 'YES', name in pks, table, df) + for name, null, dt, df in cursor.fetchall()] + + def get_primary_keys(self, table, schema=None): + cursor = self.execute_sql('SHOW INDEX FROM `%s`' % table) + return [row[4] for row in + filter(lambda row: row[2] == 'PRIMARY', cursor.fetchall())] + + def get_foreign_keys(self, table, schema=None): + query = """ + SELECT column_name, referenced_table_name, referenced_column_name + FROM information_schema.key_column_usage + WHERE table_name = %s + AND table_schema = DATABASE() + AND referenced_table_name IS NOT NULL + AND referenced_column_name IS NOT NULL""" + cursor = self.execute_sql(query, (table,)) + return [ + ForeignKeyMetadata(column, 
dest_table, dest_column, table) + for column, dest_table, dest_column in cursor.fetchall()] + + def get_binary_type(self): + return mysql.Binary + + def conflict_statement(self, on_conflict, query): + if not on_conflict._action: return + + action = on_conflict._action.lower() + if action == 'replace': + return SQL('REPLACE') + elif action == 'ignore': + return SQL('INSERT IGNORE') + elif action != 'update': + raise ValueError('Un-supported action for conflict resolution. ' + 'MySQL supports REPLACE, IGNORE and UPDATE.') + + def conflict_update(self, on_conflict, query): + if on_conflict._where or on_conflict._conflict_target or \ + on_conflict._conflict_constraint: + raise ValueError('MySQL does not support the specification of ' + 'where clauses or conflict targets for conflict ' + 'resolution.') + + updates = [] + if on_conflict._preserve: + # Here we need to determine which function to use, which varies + # depending on the MySQL server version. MySQL and MariaDB prior to + # 10.3.3 use "VALUES", while MariaDB 10.3.3+ use "VALUE". + version = self.server_version or (0,) + if version[0] == 10 and version >= (10, 3, 3): + VALUE_FN = fn.VALUE + else: + VALUE_FN = fn.VALUES + + for column in on_conflict._preserve: + entity = ensure_entity(column) + expression = NodeList(( + ensure_entity(column), + SQL('='), + VALUE_FN(entity))) + updates.append(expression) + + if on_conflict._update: + for k, v in on_conflict._update.items(): + if not isinstance(v, Node): + # Attempt to resolve string field-names to their respective + # field object, to apply data-type conversions. 
+ if isinstance(k, basestring): + k = getattr(query.table, k) + if isinstance(k, Field): + v = k.to_value(v) + else: + v = Value(v, unpack=False) + updates.append(NodeList((ensure_entity(k), SQL('='), v))) + + if updates: + return NodeList((SQL('ON DUPLICATE KEY UPDATE'), + CommaNodeList(updates))) + + def extract_date(self, date_part, date_field): + return fn.EXTRACT(NodeList((SQL(date_part), SQL('FROM'), date_field))) + + def truncate_date(self, date_part, date_field): + return fn.DATE_FORMAT(date_field, __mysql_date_trunc__[date_part], + python_value=simple_date_time) + + def to_timestamp(self, date_field): + return fn.UNIX_TIMESTAMP(date_field) + + def from_timestamp(self, date_field): + return fn.FROM_UNIXTIME(date_field) + + def random(self): + return fn.rand() + + def get_noop_select(self, ctx): + return ctx.literal('DO 0') + + +# TRANSACTION CONTROL. + + +class _manual(_callable_context_manager): + def __init__(self, db): + self.db = db + + def __enter__(self): + top = self.db.top_transaction() + if top is not None and not isinstance(top, _manual): + raise ValueError('Cannot enter manual commit block while a ' + 'transaction is active.') + self.db.push_transaction(self) + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.db.pop_transaction() is not self: + raise ValueError('Transaction stack corrupted while exiting ' + 'manual commit block.') + + +class _atomic(_callable_context_manager): + def __init__(self, db, *args, **kwargs): + self.db = db + self._transaction_args = (args, kwargs) + + def __enter__(self): + if self.db.transaction_depth() == 0: + args, kwargs = self._transaction_args + self._helper = self.db.transaction(*args, **kwargs) + elif isinstance(self.db.top_transaction(), _manual): + raise ValueError('Cannot enter atomic commit block while in ' + 'manual commit mode.') + else: + self._helper = self.db.savepoint() + return self._helper.__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + return 
self._helper.__exit__(exc_type, exc_val, exc_tb) + + +class _transaction(_callable_context_manager): + def __init__(self, db, *args, **kwargs): + self.db = db + self._begin_args = (args, kwargs) + + def _begin(self): + args, kwargs = self._begin_args + self.db.begin(*args, **kwargs) + + def commit(self, begin=True): + self.db.commit() + if begin: + self._begin() + + def rollback(self, begin=True): + self.db.rollback() + if begin: + self._begin() + + def __enter__(self): + if self.db.transaction_depth() == 0: + self._begin() + self.db.push_transaction(self) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + try: + if exc_type: + self.rollback(False) + elif self.db.transaction_depth() == 1: + try: + self.commit(False) + except: + self.rollback(False) + raise + finally: + self.db.pop_transaction() + + +class _savepoint(_callable_context_manager): + def __init__(self, db, sid=None): + self.db = db + self.sid = sid or 's' + uuid.uuid4().hex + self.quoted_sid = self.sid.join(self.db.quote) + + def _begin(self): + self.db.execute_sql('SAVEPOINT %s;' % self.quoted_sid) + + def commit(self, begin=True): + self.db.execute_sql('RELEASE SAVEPOINT %s;' % self.quoted_sid) + if begin: self._begin() + + def rollback(self): + self.db.execute_sql('ROLLBACK TO SAVEPOINT %s;' % self.quoted_sid) + + def __enter__(self): + self._begin() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type: + self.rollback() + else: + try: + self.commit(begin=False) + except: + self.rollback() + raise + + +# CURSOR REPRESENTATIONS. 
class CursorWrapper(object):
    """Wrap a DB-API cursor, lazily materializing rows into a cache.

    Rows are fetched on demand and (optionally) cached, so the wrapper can
    be iterated multiple times, indexed, sliced and len()-ed.
    """
    def __init__(self, cursor):
        self.cursor = cursor
        self.count = 0
        self.index = 0
        self.initialized = False
        self.populated = False
        self.row_cache = []

    def __iter__(self):
        # Once fully populated, serve rows straight from the cache.
        if self.populated:
            return iter(self.row_cache)
        return ResultIterator(self)

    def __getitem__(self, item):
        if isinstance(item, slice):
            # A bounded, non-negative stop lets us fetch only that many
            # rows; otherwise the whole result set must be pulled in.
            if item.stop is None or item.stop < 0:
                self.fill_cache()
            else:
                self.fill_cache(item.stop)
            return self.row_cache[item]
        if isinstance(item, int):
            self.fill_cache(item if item > 0 else 0)
            return self.row_cache[item]
        raise ValueError('CursorWrapper only supports integer and slice '
                         'indexes.')

    def __len__(self):
        self.fill_cache()
        return self.count

    def initialize(self):
        # Hook for subclasses; invoked once, upon the first fetched row.
        pass

    def iterate(self, cache=True):
        """Fetch, process and (optionally) cache a single row."""
        raw = self.cursor.fetchone()
        if raw is None:
            # Exhausted: remember that, release the cursor, and stop.
            self.populated = True
            self.cursor.close()
            raise StopIteration
        if not self.initialized:
            self.initialize()  # Lazy initialization.
            self.initialized = True
        self.count += 1
        processed = self.process_row(raw)
        if cache:
            self.row_cache.append(processed)
        return processed

    def process_row(self, row):
        # Identity transform; subclasses turn rows into dicts/tuples/objects.
        return row

    def iterator(self):
        """Efficient one-pass iteration over the result set."""
        while True:
            try:
                yield self.iterate(False)
            except StopIteration:
                return

    def fill_cache(self, n=0):
        # n == 0 (the default) means "fetch everything".
        n = n or float('Inf')
        if n < 0:
            raise ValueError('Negative values are not supported.')

        it = ResultIterator(self)
        it.index = self.count
        while not self.populated and (n > self.count):
            try:
                it.next()
            except StopIteration:
                break


class DictCursorWrapper(CursorWrapper):
    """Cursor wrapper yielding one dict per row."""
    def _initialize_columns(self):
        desc = self.cursor.description
        # Strip any "table." prefix and surrounding quotes from column names.
        self.columns = [item[0][item[0].find('.') + 1:].strip('"')
                        for item in desc]
        self.ncols = len(desc)

    initialize = _initialize_columns

    def _row_to_dict(self, row):
        mapping = {}
        for i in range(self.ncols):
            mapping.setdefault(self.columns[i], row[i])  # Do not overwrite.
        return mapping

    process_row = _row_to_dict


class NamedTupleCursorWrapper(CursorWrapper):
    """Cursor wrapper yielding one namedtuple per row."""
    def initialize(self):
        desc = self.cursor.description
        self.tuple_class = collections.namedtuple('Row', [
            item[0][item[0].find('.') + 1:].strip('"') for item in desc])

    def process_row(self, row):
        return self.tuple_class(*row)


class ObjectCursorWrapper(DictCursorWrapper):
    """Cursor wrapper passing each row, as keyword args, to a constructor."""
    def __init__(self, cursor, constructor):
        super(ObjectCursorWrapper, self).__init__(cursor)
        self.constructor = constructor

    def process_row(self, row):
        return self.constructor(**self._row_to_dict(row))


class ResultIterator(object):
    """Iterator over a CursorWrapper that serves cached rows first."""
    def __init__(self, cursor_wrapper):
        self.cursor_wrapper = cursor_wrapper
        self.index = 0

    def __iter__(self):
        return self

    def next(self):
        wrapper = self.cursor_wrapper
        if self.index < wrapper.count:
            row = wrapper.row_cache[self.index]
        elif not wrapper.populated:
            # Pull (and cache) the next row from the cursor.
            wrapper.iterate()
            row = wrapper.row_cache[self.index]
        else:
            raise StopIteration
        self.index += 1
        return row

    __next__ = next

# FIELDS

class FieldAccessor(object):
    """Descriptor mediating access to a field's value on model instances."""
    def __init__(self, model, field, name):
        self.model = model
        self.field = field
        self.name = name

    def __get__(self, instance, instance_type=None):
        if instance is None:
            return self.field  # Class-level access returns the Field itself.
        return instance.__data__.get(self.name)

    def __set__(self, instance, value):
        instance.__data__[self.name] = value
        instance._dirty.add(self.name)


class ForeignKeyAccessor(FieldAccessor):
    """Descriptor resolving a foreign-key value to the related instance."""
    def __init__(self, model, field, name):
        super(ForeignKeyAccessor, self).__init__(model, field, name)
        self.rel_model = field.rel_model

    def get_rel_instance(self, instance):
        value = instance.__data__.get(self.name)
        if value is not None or self.name in instance.__rel__:
            if self.name not in instance.__rel__:
                # Resolve and memoize the related model instance.
                obj = self.rel_model.get(self.field.rel_field == value)
                instance.__rel__[self.name] = obj
            return instance.__rel__[self.name]
        elif not self.field.null:
            raise self.rel_model.DoesNotExist
        return value
instance.__rel__[self.name] + elif not self.field.null: + raise self.rel_model.DoesNotExist + return value + + def __get__(self, instance, instance_type=None): + if instance is not None: + return self.get_rel_instance(instance) + return self.field + + def __set__(self, instance, obj): + if isinstance(obj, self.rel_model): + instance.__data__[self.name] = getattr(obj, self.field.rel_field.name) + instance.__rel__[self.name] = obj + else: + fk_value = instance.__data__.get(self.name) + instance.__data__[self.name] = obj + if obj != fk_value and self.name in instance.__rel__: + del instance.__rel__[self.name] + instance._dirty.add(self.name) + + +class NoQueryForeignKeyAccessor(ForeignKeyAccessor): + def get_rel_instance(self, instance): + value = instance.__data__.get(self.name) + if value is not None: + return instance.__rel__.get(self.name, value) + elif not self.field.null: + raise self.rel_model.DoesNotExist + + +class BackrefAccessor(object): + def __init__(self, field): + self.field = field + self.model = field.rel_model + self.rel_model = field.model + + def __get__(self, instance, instance_type=None): + if instance is not None: + dest = self.field.rel_field.name + return (self.rel_model + .select() + .where(self.field == getattr(instance, dest))) + return self + + +class ObjectIdAccessor(object): + """Gives direct access to the underlying id""" + def __init__(self, field): + self.field = field + + def __get__(self, instance, instance_type=None): + if instance is not None: + return instance.__data__.get(self.field.name) + return self.field + + def __set__(self, instance, value): + setattr(instance, self.field.name, value) + + +class Field(ColumnBase): + _field_counter = 0 + _order = 0 + accessor_class = FieldAccessor + auto_increment = False + default_index_type = None + field_type = 'DEFAULT' + unpack = True + + def __init__(self, null=False, index=False, unique=False, column_name=None, + default=None, primary_key=False, constraints=None, + sequence=None, 
collation=None, unindexed=False, choices=None, + help_text=None, verbose_name=None, index_type=None, + db_column=None, _hidden=False): + if db_column is not None: + __deprecated__('"db_column" has been deprecated in favor of ' + '"column_name" for Field objects.') + column_name = db_column + + self.null = null + self.index = index + self.unique = unique + self.column_name = column_name + self.default = default + self.primary_key = primary_key + self.constraints = constraints # List of column constraints. + self.sequence = sequence # Name of sequence, e.g. foo_id_seq. + self.collation = collation + self.unindexed = unindexed + self.choices = choices + self.help_text = help_text + self.verbose_name = verbose_name + self.index_type = index_type or self.default_index_type + self._hidden = _hidden + + # Used internally for recovering the order in which Fields were defined + # on the Model class. + Field._field_counter += 1 + self._order = Field._field_counter + self._sort_key = (self.primary_key and 1 or 2), self._order + + def __hash__(self): + return hash(self.name + '.' 
+ self.model.__name__) + + def __repr__(self): + if hasattr(self, 'model') and getattr(self, 'name', None): + return '<%s: %s.%s>' % (type(self).__name__, + self.model.__name__, + self.name) + return '<%s: (unbound)>' % type(self).__name__ + + def bind(self, model, name, set_attribute=True): + self.model = model + self.name = self.safe_name = name + self.column_name = self.column_name or name + if set_attribute: + setattr(model, name, self.accessor_class(model, self, name)) + + @property + def column(self): + return Column(self.model._meta.table, self.column_name) + + def adapt(self, value): + return value + + def db_value(self, value): + return value if value is None else self.adapt(value) + + def python_value(self, value): + return value if value is None else self.adapt(value) + + def to_value(self, value): + return Value(value, self.db_value, unpack=False) + + def get_sort_key(self, ctx): + return self._sort_key + + def __sql__(self, ctx): + return ctx.sql(self.column) + + def get_modifiers(self): + pass + + def ddl_datatype(self, ctx): + if ctx and ctx.state.field_types: + column_type = ctx.state.field_types.get(self.field_type, + self.field_type) + else: + column_type = self.field_type + + modifiers = self.get_modifiers() + if column_type and modifiers: + modifier_literal = ', '.join([str(m) for m in modifiers]) + return SQL('%s(%s)' % (column_type, modifier_literal)) + else: + return SQL(column_type) + + def ddl(self, ctx): + accum = [Entity(self.column_name)] + data_type = self.ddl_datatype(ctx) + if data_type: + accum.append(data_type) + if self.unindexed: + accum.append(SQL('UNINDEXED')) + if not self.null: + accum.append(SQL('NOT NULL')) + if self.primary_key: + accum.append(SQL('PRIMARY KEY')) + if self.sequence: + accum.append(SQL("DEFAULT NEXTVAL('%s')" % self.sequence)) + if self.constraints: + accum.extend(self.constraints) + if self.collation: + accum.append(SQL('COLLATE %s' % self.collation)) + return NodeList(accum) + + +class IntegerField(Field): 
+ field_type = 'INT' + + def adapt(self, value): + try: + return int(value) + except ValueError: + return value + + +class BigIntegerField(IntegerField): + field_type = 'BIGINT' + + +class SmallIntegerField(IntegerField): + field_type = 'SMALLINT' + + +class AutoField(IntegerField): + auto_increment = True + field_type = 'AUTO' + + def __init__(self, *args, **kwargs): + if kwargs.get('primary_key') is False: + raise ValueError('%s must always be a primary key.' % type(self)) + kwargs['primary_key'] = True + super(AutoField, self).__init__(*args, **kwargs) + + +class BigAutoField(AutoField): + field_type = 'BIGAUTO' + + +class IdentityField(AutoField): + field_type = 'INT GENERATED BY DEFAULT AS IDENTITY' + + def __init__(self, generate_always=False, **kwargs): + if generate_always: + self.field_type = 'INT GENERATED ALWAYS AS IDENTITY' + super(IdentityField, self).__init__(**kwargs) + + +class PrimaryKeyField(AutoField): + def __init__(self, *args, **kwargs): + __deprecated__('"PrimaryKeyField" has been renamed to "AutoField". 
' + 'Please update your code accordingly as this will be ' + 'completely removed in a subsequent release.') + super(PrimaryKeyField, self).__init__(*args, **kwargs) + + +class FloatField(Field): + field_type = 'FLOAT' + + def adapt(self, value): + try: + return float(value) + except ValueError: + return value + + +class DoubleField(FloatField): + field_type = 'DOUBLE' + + +class DecimalField(Field): + field_type = 'DECIMAL' + + def __init__(self, max_digits=10, decimal_places=5, auto_round=False, + rounding=None, *args, **kwargs): + self.max_digits = max_digits + self.decimal_places = decimal_places + self.auto_round = auto_round + self.rounding = rounding or decimal.DefaultContext.rounding + self._exp = decimal.Decimal(10) ** (-self.decimal_places) + super(DecimalField, self).__init__(*args, **kwargs) + + def get_modifiers(self): + return [self.max_digits, self.decimal_places] + + def db_value(self, value): + D = decimal.Decimal + if not value: + return value if value is None else D(0) + if self.auto_round: + decimal_value = D(text_type(value)) + return decimal_value.quantize(self._exp, rounding=self.rounding) + return value + + def python_value(self, value): + if value is not None: + if isinstance(value, decimal.Decimal): + return value + return decimal.Decimal(text_type(value)) + + +class _StringField(Field): + def adapt(self, value): + if isinstance(value, text_type): + return value + elif isinstance(value, bytes_type): + return value.decode('utf-8') + return text_type(value) + + def __add__(self, other): return StringExpression(self, OP.CONCAT, other) + def __radd__(self, other): return StringExpression(other, OP.CONCAT, self) + + +class CharField(_StringField): + field_type = 'VARCHAR' + + def __init__(self, max_length=255, *args, **kwargs): + self.max_length = max_length + super(CharField, self).__init__(*args, **kwargs) + + def get_modifiers(self): + return self.max_length and [self.max_length] or None + + +class FixedCharField(CharField): + field_type = 
'CHAR' + + def python_value(self, value): + value = super(FixedCharField, self).python_value(value) + if value: + value = value.strip() + return value + + +class TextField(_StringField): + field_type = 'TEXT' + + +class BlobField(Field): + field_type = 'BLOB' + + def _db_hook(self, database): + if database is None: + self._constructor = bytearray + else: + self._constructor = database.get_binary_type() + + def bind(self, model, name, set_attribute=True): + self._constructor = bytearray + if model._meta.database: + if isinstance(model._meta.database, Proxy): + model._meta.database.attach_callback(self._db_hook) + else: + self._db_hook(model._meta.database) + + # Attach a hook to the model metadata; in the event the database is + # changed or set at run-time, we will be sure to apply our callback and + # use the proper data-type for our database driver. + model._meta._db_hooks.append(self._db_hook) + return super(BlobField, self).bind(model, name, set_attribute) + + def db_value(self, value): + if isinstance(value, text_type): + value = value.encode('raw_unicode_escape') + if isinstance(value, bytes_type): + return self._constructor(value) + return value + + +class BitField(BitwiseMixin, BigIntegerField): + def __init__(self, *args, **kwargs): + kwargs.setdefault('default', 0) + super(BitField, self).__init__(*args, **kwargs) + self.__current_flag = 1 + + def flag(self, value=None): + if value is None: + value = self.__current_flag + self.__current_flag <<= 1 + else: + self.__current_flag = value << 1 + + class FlagDescriptor(object): + def __init__(self, field, value): + self._field = field + self._value = value + def __get__(self, instance, instance_type=None): + if instance is None: + return self._field.bin_and(self._value) != 0 + value = getattr(instance, self._field.name) or 0 + return (value & self._value) != 0 + def __set__(self, instance, is_set): + if is_set not in (True, False): + raise ValueError('Value must be either True or False') + value = 
getattr(instance, self._field.name) or 0 + if is_set: + value |= self._value + else: + value &= ~self._value + setattr(instance, self._field.name, value) + return FlagDescriptor(self, value) + + +class BigBitFieldData(object): + def __init__(self, instance, name): + self.instance = instance + self.name = name + value = self.instance.__data__.get(self.name) + if not value: + value = bytearray() + elif not isinstance(value, bytearray): + value = bytearray(value) + self._buffer = self.instance.__data__[self.name] = value + + def _ensure_length(self, idx): + byte_num, byte_offset = divmod(idx, 8) + cur_size = len(self._buffer) + if cur_size <= byte_num: + self._buffer.extend(b'\x00' * ((byte_num + 1) - cur_size)) + return byte_num, byte_offset + + def set_bit(self, idx): + byte_num, byte_offset = self._ensure_length(idx) + self._buffer[byte_num] |= (1 << byte_offset) + + def clear_bit(self, idx): + byte_num, byte_offset = self._ensure_length(idx) + self._buffer[byte_num] &= ~(1 << byte_offset) + + def toggle_bit(self, idx): + byte_num, byte_offset = self._ensure_length(idx) + self._buffer[byte_num] ^= (1 << byte_offset) + return bool(self._buffer[byte_num] & (1 << byte_offset)) + + def is_set(self, idx): + byte_num, byte_offset = self._ensure_length(idx) + return bool(self._buffer[byte_num] & (1 << byte_offset)) + + def __repr__(self): + return repr(self._buffer) + + +class BigBitFieldAccessor(FieldAccessor): + def __get__(self, instance, instance_type=None): + if instance is None: + return self.field + return BigBitFieldData(instance, self.name) + def __set__(self, instance, value): + if isinstance(value, memoryview): + value = value.tobytes() + elif isinstance(value, buffer_type): + value = bytes(value) + elif isinstance(value, bytearray): + value = bytes_type(value) + elif isinstance(value, BigBitFieldData): + value = bytes_type(value._buffer) + elif isinstance(value, text_type): + value = value.encode('utf-8') + elif not isinstance(value, bytes_type): + raise 
ValueError('Value must be either a bytes, memoryview or ' + 'BigBitFieldData instance.') + super(BigBitFieldAccessor, self).__set__(instance, value) + + +class BigBitField(BlobField): + accessor_class = BigBitFieldAccessor + + def __init__(self, *args, **kwargs): + kwargs.setdefault('default', bytes_type) + super(BigBitField, self).__init__(*args, **kwargs) + + def db_value(self, value): + return bytes_type(value) if value is not None else value + + +class UUIDField(Field): + field_type = 'UUID' + + def db_value(self, value): + if isinstance(value, basestring) and len(value) == 32: + # Hex string. No transformation is necessary. + return value + elif isinstance(value, bytes) and len(value) == 16: + # Allow raw binary representation. + value = uuid.UUID(bytes=value) + if isinstance(value, uuid.UUID): + return value.hex + try: + return uuid.UUID(value).hex + except: + return value + + def python_value(self, value): + if isinstance(value, uuid.UUID): + return value + return uuid.UUID(value) if value is not None else None + + +class BinaryUUIDField(BlobField): + field_type = 'UUIDB' + + def db_value(self, value): + if isinstance(value, bytes) and len(value) == 16: + # Raw binary value. No transformation is necessary. + return self._constructor(value) + elif isinstance(value, basestring) and len(value) == 32: + # Allow hex string representation. 
+ value = uuid.UUID(hex=value) + if isinstance(value, uuid.UUID): + return self._constructor(value.bytes) + elif value is not None: + raise ValueError('value for binary UUID field must be UUID(), ' + 'a hexadecimal string, or a bytes object.') + + def python_value(self, value): + if isinstance(value, uuid.UUID): + return value + elif isinstance(value, memoryview): + value = value.tobytes() + elif value and not isinstance(value, bytes): + value = bytes(value) + return uuid.UUID(bytes=value) if value is not None else None + + +def _date_part(date_part): + def dec(self): + return self.model._meta.database.extract_date(date_part, self) + return dec + +def format_date_time(value, formats, post_process=None): + post_process = post_process or (lambda x: x) + for fmt in formats: + try: + return post_process(datetime.datetime.strptime(value, fmt)) + except ValueError: + pass + return value + +def simple_date_time(value): + try: + return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S') + except (TypeError, ValueError): + return value + + +class _BaseFormattedField(Field): + formats = None + + def __init__(self, formats=None, *args, **kwargs): + if formats is not None: + self.formats = formats + super(_BaseFormattedField, self).__init__(*args, **kwargs) + + +class DateTimeField(_BaseFormattedField): + field_type = 'DATETIME' + formats = [ + '%Y-%m-%d %H:%M:%S.%f', + '%Y-%m-%d %H:%M:%S', + '%Y-%m-%d', + ] + + def adapt(self, value): + if value and isinstance(value, basestring): + return format_date_time(value, self.formats) + return value + + def to_timestamp(self): + return self.model._meta.database.to_timestamp(self) + + def truncate(self, part): + return self.model._meta.database.truncate_date(part, self) + + year = property(_date_part('year')) + month = property(_date_part('month')) + day = property(_date_part('day')) + hour = property(_date_part('hour')) + minute = property(_date_part('minute')) + second = property(_date_part('second')) + + +class 
class DateField(_BaseFormattedField):
    field_type = 'DATE'
    formats = [
        '%Y-%m-%d',
        '%Y-%m-%d %H:%M:%S',
        '%Y-%m-%d %H:%M:%S.%f',
    ]

    def adapt(self, value):
        """Parse date strings; truncate datetimes down to a date."""
        if value and isinstance(value, basestring):
            return format_date_time(value, self.formats, lambda x: x.date())
        elif value and isinstance(value, datetime.datetime):
            return value.date()
        return value

    def to_timestamp(self):
        return self.model._meta.database.to_timestamp(self)

    def truncate(self, part):
        return self.model._meta.database.truncate_date(part, self)

    year = property(_date_part('year'))
    month = property(_date_part('month'))
    day = property(_date_part('day'))


class TimeField(_BaseFormattedField):
    field_type = 'TIME'
    formats = [
        '%H:%M:%S.%f',
        '%H:%M:%S',
        '%H:%M',
        '%Y-%m-%d %H:%M:%S.%f',
        '%Y-%m-%d %H:%M:%S',
    ]

    def adapt(self, value):
        """Parse time strings; truncate datetimes; accept timedeltas."""
        if value:
            if isinstance(value, basestring):
                return format_date_time(value, self.formats,
                                        lambda x: x.time())
            elif isinstance(value, datetime.datetime):
                return value.time()
        if value is not None and isinstance(value, datetime.timedelta):
            # A timedelta is interpreted as an offset from midnight.
            return (datetime.datetime.min + value).time()
        return value

    hour = property(_date_part('hour'))
    minute = property(_date_part('minute'))
    second = property(_date_part('second'))


def _timestamp_date_part(date_part):
    """Build a getter that extracts date_part from a timestamp column."""
    def accessor(self):
        db = self.model._meta.database
        # Scale ticks back to whole seconds before converting in SQL.
        expr = ((self / Value(self.resolution, converter=False))
                if self.resolution > 1 else self)
        return db.extract_date(date_part, db.from_timestamp(expr))
    return accessor


class TimestampField(BigIntegerField):
    """Datetime stored as an integer timestamp.

    Supports second -> microsecond resolution via the resolution parameter.
    """
    valid_resolutions = [10**i for i in range(7)]

    def __init__(self, *args, **kwargs):
        self.resolution = kwargs.pop('resolution', None)

        if not self.resolution:
            self.resolution = 1
        elif self.resolution in range(2, 7):
            # Small integers are shorthand for powers of ten, e.g. 3 -> 1000.
            self.resolution = 10 ** self.resolution
        elif self.resolution not in self.valid_resolutions:
            raise ValueError('TimestampField resolution must be one of: %s' %
                             ', '.join(str(i) for i in self.valid_resolutions))
        self.ticks_to_microsecond = 1000000 // self.resolution

        self.utc = kwargs.pop('utc', False) or False
        default = (datetime.datetime.utcnow if self.utc
                   else datetime.datetime.now)
        kwargs.setdefault('default', default)
        super(TimestampField, self).__init__(*args, **kwargs)

    def local_to_utc(self, dt):
        # Convert naive local datetime into naive UTC, e.g.:
        # 2019-03-01T12:00:00 (local=US/Central) -> 2019-03-01T18:00:00.
        # 2019-05-01T12:00:00 (local=US/Central) -> 2019-05-01T17:00:00.
        # 2019-03-01T12:00:00 (local=UTC) -> 2019-03-01T12:00:00.
        return datetime.datetime(*time.gmtime(time.mktime(dt.timetuple()))[:6])

    def utc_to_local(self, dt):
        # Convert a naive UTC datetime into local time, e.g.:
        # 2019-03-01T18:00:00 (local=US/Central) -> 2019-03-01T12:00:00.
        # 2019-05-01T17:00:00 (local=US/Central) -> 2019-05-01T12:00:00.
        # 2019-03-01T12:00:00 (local=UTC) -> 2019-03-01T12:00:00.
        ts = calendar.timegm(dt.utctimetuple())
        return datetime.datetime.fromtimestamp(ts)

    def get_timestamp(self, value):
        if self.utc:
            # If utc-mode is on, then we assume all naive datetimes are
            # in UTC.
            return calendar.timegm(value.utctimetuple())
        return time.mktime(value.timetuple())

    def db_value(self, value):
        if value is None:
            return

        if isinstance(value, datetime.datetime):
            pass
        elif isinstance(value, datetime.date):
            value = datetime.datetime(value.year, value.month, value.day)
        else:
            # Numeric input is taken as seconds and scaled to ticks.
            return int(round(value * self.resolution))

        timestamp = self.get_timestamp(value)
        if self.resolution > 1:
            timestamp += (value.microsecond * .000001)
            timestamp *= self.resolution
        return int(round(timestamp))

    def python_value(self, value):
        # NOTE: `long` is the Python-2 builtin; this module targets py27.
        if value is not None and isinstance(value, (int, float, long)):
            if self.resolution > 1:
                value, ticks = divmod(value, self.resolution)
                microseconds = int(ticks * self.ticks_to_microsecond)
            else:
                microseconds = 0

            if self.utc:
                value = datetime.datetime.utcfromtimestamp(value)
            else:
                value = datetime.datetime.fromtimestamp(value)

            if microseconds:
                value = value.replace(microsecond=microseconds)

        return value

    def from_timestamp(self):
        expr = ((self / Value(self.resolution, converter=False))
                if self.resolution > 1 else self)
        return self.model._meta.database.from_timestamp(expr)

    year = property(_timestamp_date_part('year'))
    month = property(_timestamp_date_part('month'))
    day = property(_timestamp_date_part('day'))
    hour = property(_timestamp_date_part('hour'))
    minute = property(_timestamp_date_part('minute'))
    second = property(_timestamp_date_part('second'))


class IPField(BigIntegerField):
    """IPv4 address stored as an unsigned 32-bit big-endian integer."""
    def db_value(self, val):
        if val is not None:
            return struct.unpack('!I', socket.inet_aton(val))[0]

    def python_value(self, val):
        if val is not None:
            return socket.inet_ntoa(struct.pack('!I', val))


class BooleanField(Field):
    field_type = 'BOOL'
    adapt = bool


class BareField(Field):
    """Untyped column (SQLite); optional adapt callable for conversion."""
    def __init__(self, adapt=None, *args, **kwargs):
        super(BareField, self).__init__(*args, **kwargs)
        if adapt is not None:
            self.adapt = adapt

    def ddl_datatype(self, ctx):
        # Bare columns carry no datatype in their DDL.
        return
ddl_datatype(self, ctx): + return + + +class ForeignKeyField(Field): + accessor_class = ForeignKeyAccessor + + def __init__(self, model, field=None, backref=None, on_delete=None, + on_update=None, deferrable=None, _deferred=None, + rel_model=None, to_field=None, object_id_name=None, + lazy_load=True, related_name=None, *args, **kwargs): + kwargs.setdefault('index', True) + + # If lazy_load is disable, we use a different descriptor/accessor that + # will ensure we don't accidentally perform a query. + if not lazy_load: + self.accessor_class = NoQueryForeignKeyAccessor + + super(ForeignKeyField, self).__init__(*args, **kwargs) + + if rel_model is not None: + __deprecated__('"rel_model" has been deprecated in favor of ' + '"model" for ForeignKeyField objects.') + model = rel_model + if to_field is not None: + __deprecated__('"to_field" has been deprecated in favor of ' + '"field" for ForeignKeyField objects.') + field = to_field + if related_name is not None: + __deprecated__('"related_name" has been deprecated in favor of ' + '"backref" for Field objects.') + backref = related_name + + self._is_self_reference = model == 'self' + self.rel_model = model + self.rel_field = field + self.declared_backref = backref + self.backref = None + self.on_delete = on_delete + self.on_update = on_update + self.deferrable = deferrable + self.deferred = _deferred + self.object_id_name = object_id_name + self.lazy_load = lazy_load + + @property + def field_type(self): + if not isinstance(self.rel_field, AutoField): + return self.rel_field.field_type + elif isinstance(self.rel_field, BigAutoField): + return BigIntegerField.field_type + return IntegerField.field_type + + def get_modifiers(self): + if not isinstance(self.rel_field, AutoField): + return self.rel_field.get_modifiers() + return super(ForeignKeyField, self).get_modifiers() + + def adapt(self, value): + return self.rel_field.adapt(value) + + def db_value(self, value): + if isinstance(value, self.rel_model): + value = 
value.get_id() + return self.rel_field.db_value(value) + + def python_value(self, value): + if isinstance(value, self.rel_model): + return value + return self.rel_field.python_value(value) + + def bind(self, model, name, set_attribute=True): + if not self.column_name: + self.column_name = name if name.endswith('_id') else name + '_id' + if not self.object_id_name: + self.object_id_name = self.column_name + if self.object_id_name == name: + self.object_id_name += '_id' + elif self.object_id_name == name: + raise ValueError('ForeignKeyField "%s"."%s" specifies an ' + 'object_id_name that conflicts with its field ' + 'name.' % (model._meta.name, name)) + if self._is_self_reference: + self.rel_model = model + if isinstance(self.rel_field, basestring): + self.rel_field = getattr(self.rel_model, self.rel_field) + elif self.rel_field is None: + self.rel_field = self.rel_model._meta.primary_key + + # Bind field before assigning backref, so field is bound when + # calling declared_backref() (if callable). 
+ super(ForeignKeyField, self).bind(model, name, set_attribute) + self.safe_name = self.object_id_name + + if callable_(self.declared_backref): + self.backref = self.declared_backref(self) + else: + self.backref, self.declared_backref = self.declared_backref, None + if not self.backref: + self.backref = '%s_set' % model._meta.name + + if set_attribute: + setattr(model, self.object_id_name, ObjectIdAccessor(self)) + if self.backref not in '!+': + setattr(self.rel_model, self.backref, BackrefAccessor(self)) + + def foreign_key_constraint(self): + parts = [ + SQL('FOREIGN KEY'), + EnclosedNodeList((self,)), + SQL('REFERENCES'), + self.rel_model, + EnclosedNodeList((self.rel_field,))] + if self.on_delete: + parts.append(SQL('ON DELETE %s' % self.on_delete)) + if self.on_update: + parts.append(SQL('ON UPDATE %s' % self.on_update)) + if self.deferrable: + parts.append(SQL('DEFERRABLE %s' % self.deferrable)) + return NodeList(parts) + + def __getattr__(self, attr): + if attr.startswith('__'): + # Prevent recursion error when deep-copying. + raise AttributeError('Cannot look-up non-existant "__" methods.') + if attr in self.rel_model._meta.fields: + return self.rel_model._meta.fields[attr] + raise AttributeError('Foreign-key has no attribute %s, nor is it a ' + 'valid field on the related model.' 
                             % attr)


class DeferredForeignKey(Field):
    """Placeholder for a ForeignKeyField whose target model is not yet defined.

    Instances register themselves in the class-level ``_unresolved`` set and
    are swapped for a real ForeignKeyField when the target model class is
    created (``resolve()`` is invoked from ModelBase.__new__).
    """
    _unresolved = set()

    def __init__(self, rel_model_name, **kwargs):
        # Keep the raw kwargs so the eventual ForeignKeyField can be built
        # with exactly the options the user supplied.
        self.field_kwargs = kwargs
        self.rel_model_name = rel_model_name.lower()
        DeferredForeignKey._unresolved.add(self)
        super(DeferredForeignKey, self).__init__(
            column_name=kwargs.get('column_name'),
            null=kwargs.get('null'))

    # Identity-based hashing so the instance can live in ``_unresolved``.
    __hash__ = object.__hash__

    def __deepcopy__(self, memo=None):
        # Deep-copying (e.g. during model inheritance) must produce a fresh
        # deferred field that re-registers itself, not a shared one.
        return DeferredForeignKey(self.rel_model_name, **self.field_kwargs)

    def set_model(self, rel_model):
        """Replace this placeholder on its model with a real ForeignKeyField."""
        field = ForeignKeyField(rel_model, _deferred=True, **self.field_kwargs)
        self.model._meta.add_field(self.name, field)

    @staticmethod
    def resolve(model_cls):
        """Resolve any pending deferred FKs whose target is ``model_cls``."""
        # Sort by declaration order so resolution is deterministic.
        unresolved = sorted(DeferredForeignKey._unresolved,
                            key=operator.attrgetter('_order'))
        for dr in unresolved:
            if dr.rel_model_name == model_cls.__name__.lower():
                dr.set_model(model_cls)
                DeferredForeignKey._unresolved.discard(dr)


class DeferredThroughModel(object):
    """Placeholder for a many-to-many "through" model defined later.

    ManyToManyFields bound with this placeholder are recorded in ``_refs``
    and finalized when ``set_model()`` supplies the real through model.
    """
    def __init__(self):
        self._refs = []

    def set_field(self, model, field, name):
        # Remember where the m2m field should be installed once resolved.
        self._refs.append((model, field, name))

    def set_model(self, through_model):
        """Bind all deferred many-to-many fields to the real through model."""
        for src_model, m2mfield, name in self._refs:
            m2mfield.through_model = through_model
            src_model._meta.add_field(name, m2mfield)


class MetaField(Field):
    """Base for fields that exist only at the Python level (no DB column)."""
    column_name = default = model = name = None
    primary_key = False


class ManyToManyFieldAccessor(FieldAccessor):
    """Descriptor that exposes a many-to-many relation on model instances.

    Reading on an instance returns either a prefetched list of related
    objects or a ManyToManyQuery; reading on the class returns the field.
    """
    def __init__(self, model, field, name):
        super(ManyToManyFieldAccessor, self).__init__(model, field, name)
        self.model = field.model
        self.rel_model = field.rel_model
        self.through_model = field.through_model
        # The through model must declare exactly one FK to each side; take
        # the first declared FK on each side.
        src_fks = self.through_model._meta.model_refs[self.model]
        dest_fks = self.through_model._meta.model_refs[self.rel_model]
        if not src_fks:
            raise ValueError('Cannot find foreign-key to "%s" on "%s" model.' %
                             (self.model, self.through_model))
        elif not dest_fks:
            raise ValueError('Cannot find foreign-key to "%s" on "%s" model.' %
                             (self.rel_model, self.through_model))
        self.src_fk = src_fks[0]
        self.dest_fk = dest_fks[0]

    def __get__(self, instance, instance_type=None, force_query=False):
        if instance is not None:
            # If the backref was prefetched (a plain list), serve the related
            # objects straight from it without issuing a query.
            if not force_query and self.src_fk.backref != '+':
                backref = getattr(instance, self.src_fk.backref)
                if isinstance(backref, list):
                    return [getattr(obj, self.dest_fk.name) for obj in backref]

            src_id = getattr(instance, self.src_fk.rel_field.name)
            return (ManyToManyQuery(instance, self, self.rel_model)
                    .join(self.through_model)
                    .join(self.model)
                    .where(self.src_fk == src_id))

        return self.field

    def __set__(self, instance, value):
        # Assignment replaces the full set of related objects.
        query = self.__get__(instance, force_query=True)
        query.add(value, clear_existing=True)


class ManyToManyField(MetaField):
    """Declarative many-to-many relation implemented via a through model.

    If no through model is given, one is generated automatically. Binding a
    non-backref instance also installs a mirror field on the related model.
    """
    accessor_class = ManyToManyFieldAccessor

    def __init__(self, model, backref=None, through_model=None, on_delete=None,
                 on_update=None, _is_backref=False):
        if through_model is not None:
            if not (isinstance(through_model, DeferredThroughModel) or
                    is_model(through_model)):
                raise TypeError('Unexpected value for through_model. Expected '
                                'Model or DeferredThroughModel.')
            # on_delete/on_update only make sense for the auto-generated
            # through model; an explicit through model defines its own FKs.
            if not _is_backref and (on_delete is not None or
                                    on_update is not None):
                raise ValueError('Cannot specify on_delete or on_update when '
                                 'through_model is specified.')
        self.rel_model = model
        self.backref = backref
        self._through_model = through_model
        self._on_delete = on_delete
        self._on_update = on_update
        self._is_backref = _is_backref

    def _get_descriptor(self):
        return ManyToManyFieldAccessor(self)

    def bind(self, model, name, set_attribute=True):
        if isinstance(self._through_model, DeferredThroughModel):
            # Defer binding until the through model is supplied.
            self._through_model.set_field(model, self, name)
            return

        super(ManyToManyField, self).bind(model, name, set_attribute)

        if not self._is_backref:
            # Install the reverse accessor on the related model.
            many_to_many_field = ManyToManyField(
                self.model,
                backref=name,
                through_model=self.through_model,
                on_delete=self._on_delete,
                on_update=self._on_update,
                _is_backref=True)
            self.backref = self.backref or model._meta.name + 's'
            self.rel_model._meta.add_field(self.backref, many_to_many_field)

    def get_models(self):
        # Return (source, destination) in declaration order regardless of
        # whether this instance is the forward field or the backref.
        return [model for _, model in sorted((
            (self._is_backref, self.model),
            (not self._is_backref, self.rel_model)))]

    @property
    def through_model(self):
        # Lazily create the through model on first access.
        if self._through_model is None:
            self._through_model = self._create_through_model()
        return self._through_model

    @through_model.setter
    def through_model(self, value):
        self._through_model = value

    def _create_through_model(self):
        """Generate a through model with an FK to each side and a unique
        composite index over the pair."""
        lhs, rhs = self.get_models()
        tables = [model._meta.table_name for model in (lhs, rhs)]

        class Meta:
            database = self.model._meta.database
            schema = self.model._meta.schema
            table_name = '%s_%s_through' % tuple(tables)
            indexes = (
                ((lhs._meta.name, rhs._meta.name),
                 True),)

        params = {'on_delete': self._on_delete, 'on_update': self._on_update}
        attrs = {
            lhs._meta.name: ForeignKeyField(lhs, **params),
            rhs._meta.name: ForeignKeyField(rhs, **params),
            'Meta': Meta}

        klass_name = '%s%sThrough' % (lhs.__name__, rhs.__name__)
        return type(klass_name, (Model,), attrs)

    def get_through_model(self):
        # XXX: Deprecated. Just use the "through_model" property.
        return self.through_model


class VirtualField(MetaField):
    """Field with no backing column that can still convert values by
    delegating to an optional wrapped field instance."""
    field_class = None

    def __init__(self, field_class=None, *args, **kwargs):
        Field = field_class if field_class is not None else self.field_class
        self.field_instance = Field() if Field is not None else None
        super(VirtualField, self).__init__(*args, **kwargs)

    def db_value(self, value):
        if self.field_instance is not None:
            return self.field_instance.db_value(value)
        return value

    def python_value(self, value):
        if self.field_instance is not None:
            return self.field_instance.python_value(value)
        return value

    def bind(self, model, name, set_attribute=True):
        self.model = model
        self.column_name = self.name = self.safe_name = name
        setattr(model, name, self.accessor_class(model, self, name))


class CompositeKey(MetaField):
    """Primary key composed of multiple fields, addressed as a tuple."""
    sequence = None

    def __init__(self, *field_names):
        self.field_names = field_names

    def __get__(self, instance, instance_type=None):
        if instance is not None:
            return tuple([getattr(instance, field_name)
                          for field_name in self.field_names])
        return self

    def __set__(self, instance, value):
        if not isinstance(value, (list, tuple)):
            raise TypeError('A list or tuple must be used to set the value of '
                            'a composite primary key.')
        if len(value) != len(self.field_names):
            raise ValueError('The length of the value must equal the number '
                             'of columns of the composite primary key.')
        for idx, field_value in enumerate(value):
            setattr(instance, self.field_names[idx], field_value)

    def __eq__(self, other):
        # Compare positionally: ``other`` is an iterable of key-part values.
        expressions = [(self.model._meta.fields[field] == value)
                       for field, value in zip(self.field_names, other)]
        return reduce(operator.and_, expressions)

    def __ne__(self, other):
        return ~(self == other)

    def __hash__(self):
        return hash((self.model.__name__, self.field_names))

    def __sql__(self, ctx):
        # If the composite PK is being selected, do not use parens. Elsewhere,
        # such as in an expression, we want to use parentheses and treat it as
        # a row value.
        parens = ctx.scope != SCOPE_SOURCE
        return ctx.sql(NodeList([self.model._meta.fields[field]
                                 for field in self.field_names], ', ', parens))

    def bind(self, model, name, set_attribute=True):
        self.model = model
        self.column_name = self.name = self.safe_name = name
        setattr(model, self.name, self)


class _SortedFieldList(object):
    """List of fields kept ordered by each field's ``_sort_key``.

    Backed by a parallel list of sort keys so insertion/lookup can use
    binary search (bisect).
    """
    __slots__ = ('_keys', '_items')

    def __init__(self):
        self._keys = []
        self._items = []

    def __getitem__(self, i):
        return self._items[i]

    def __iter__(self):
        return iter(self._items)

    def __contains__(self, item):
        # Narrow to the slice that shares the item's sort key, then do an
        # equality check within it (keys need not be unique).
        k = item._sort_key
        i = bisect_left(self._keys, k)
        j = bisect_right(self._keys, k)
        return item in self._items[i:j]

    def index(self, field):
        return self._keys.index(field._sort_key)

    def insert(self, item):
        k = item._sort_key
        i = bisect_left(self._keys, k)
        self._keys.insert(i, k)
        self._items.insert(i, item)

    def remove(self, item):
        idx = self.index(item)
        del self._items[idx]
        del self._keys[idx]


# MODELS


class SchemaManager(object):
    """Generates and executes DDL (tables, indexes, sequences, FKs) for a
    model against its database."""
    def __init__(self, model, database=None, **context_options):
        self.model = model
        self._database = database
        context_options.setdefault('scope', SCOPE_VALUES)
        self.context_options = context_options

    @property
    def database(self):
        # Fall back to the model's bound database when none was given.
        db = self._database or self.model._meta.database
        if db is None:
            raise ImproperlyConfigured('database attribute does not appear to '
                                       'be set on the model: %s' % self.model)
        return db

    @database.setter
    def database(self, value):
        self._database = value

    def _create_context(self):
        return self.database.get_sql_context(**self.context_options)

    def _create_table(self, safe=True, **options):
        """Build the CREATE TABLE context for this model."""
        is_temp = options.pop('temporary', False)
        ctx = self._create_context()
        ctx.literal('CREATE TEMPORARY TABLE ' if is_temp else 'CREATE TABLE ')
        if safe:
            ctx.literal('IF NOT EXISTS ')
        ctx.sql(self.model).literal(' ')

        columns = []
        constraints = []
        meta = self.model._meta
        if meta.composite_key:
            pk_columns = [meta.fields[field_name].column
                          for field_name in meta.primary_key.field_names]
            constraints.append(NodeList((SQL('PRIMARY KEY'),
                                         EnclosedNodeList(pk_columns))))

        for field in meta.sorted_fields:
            columns.append(field.ddl(ctx))
            if isinstance(field, ForeignKeyField) and not field.deferred:
                constraints.append(field.foreign_key_constraint())

        if meta.constraints:
            constraints.extend(meta.constraints)

        constraints.extend(self._create_table_option_sql(options))
        ctx.sql(EnclosedNodeList(columns + constraints))

        if meta.table_settings is not None:
            table_settings = ensure_tuple(meta.table_settings)
            for setting in table_settings:
                if not isinstance(setting, basestring):
                    raise ValueError('table_settings must be strings')
                ctx.literal(' ').literal(setting)

        if meta.without_rowid:
            ctx.literal(' WITHOUT ROWID')
        return ctx

    def _create_table_option_sql(self, options):
        """Render per-table options (merged with meta options) as KEY=value."""
        accum = []
        options = merge_dict(self.model._meta.options or {}, options)
        if not options:
            return accum

        for key, value in sorted(options.items()):
            if not isinstance(value, Node):
                if is_model(value):
                    value = value._meta.table
                else:
                    value = SQL(str(value))
            accum.append(NodeList((SQL(key), value), glue='='))
        return accum

    def create_table(self, safe=True, **options):
        self.database.execute(self._create_table(safe=safe, **options))

    def _create_table_as(self, table_name, query, safe=True, **meta):
        ctx = (self._create_context()
               .literal('CREATE TEMPORARY TABLE '
                        if meta.get('temporary') else 'CREATE TABLE '))
        if safe:
            ctx.literal('IF NOT EXISTS ')
        return (ctx
                .sql(Entity(table_name))
                .literal(' AS ')
                .sql(query))

    def create_table_as(self, table_name, query, safe=True, **meta):
        ctx = self._create_table_as(table_name, query, safe=safe, **meta)
        self.database.execute(ctx)

    def _drop_table(self, safe=True, **options):
        ctx = (self._create_context()
               .literal('DROP TABLE IF EXISTS ' if safe else 'DROP TABLE ')
               .sql(self.model))
        if options.get('cascade'):
            ctx = ctx.literal(' CASCADE')
        elif options.get('restrict'):
            ctx = ctx.literal(' RESTRICT')
        return ctx

    def drop_table(self, safe=True, **options):
        self.database.execute(self._drop_table(safe=safe, **options))

    def _truncate_table(self, restart_identity=False, cascade=False):
        db = self.database
        # Databases without TRUNCATE support fall back to DELETE FROM.
        if not db.truncate_table:
            return (self._create_context()
                    .literal('DELETE FROM ').sql(self.model))

        ctx = self._create_context().literal('TRUNCATE TABLE ').sql(self.model)
        if restart_identity:
            ctx = ctx.literal(' RESTART IDENTITY')
        if cascade:
            ctx = ctx.literal(' CASCADE')
        return ctx

    def truncate_table(self, restart_identity=False, cascade=False):
        self.database.execute(self._truncate_table(restart_identity, cascade))

    def _create_indexes(self, safe=True):
        return [self._create_index(index, safe)
                for index in self.model._meta.fields_to_index()]

    def _create_index(self, index, safe=True):
        if isinstance(index, Index):
            if not self.database.safe_create_index:
                index = index.safe(False)
            elif index._safe != safe:
                index = index.safe(safe)
        return self._create_context().sql(index)

    def create_indexes(self, safe=True):
        for query in self._create_indexes(safe=safe):
            self.database.execute(query)

    def _drop_indexes(self, safe=True):
        return [self._drop_index(index, safe)
                for index in self.model._meta.fields_to_index()
                if isinstance(index, Index)]

    def _drop_index(self, index, safe):
        statement = 'DROP INDEX '
        if safe and self.database.safe_drop_index:
            statement += 'IF EXISTS '
        if isinstance(index._table, Table) and index._table._schema:
            # Explicitly specify the schema if the table is schema-qualified.
            index_name = Entity(index._table._schema, index._name)
        else:
            index_name = Entity(index._name)
        return (self
                ._create_context()
                .literal(statement)
                .sql(index_name))

    def drop_indexes(self, safe=True):
        for query in self._drop_indexes(safe=safe):
            self.database.execute(query)

    def _check_sequences(self, field):
        if not field.sequence or not self.database.sequences:
            raise ValueError('Sequences are either not supported, or are not '
                             'defined for "%s".' % field.name)

    def _sequence_for_field(self, field):
        if field.model._meta.schema:
            return Entity(field.model._meta.schema, field.sequence)
        else:
            return Entity(field.sequence)

    def _create_sequence(self, field):
        self._check_sequences(field)
        if not self.database.sequence_exists(field.sequence):
            return (self
                    ._create_context()
                    .literal('CREATE SEQUENCE ')
                    .sql(self._sequence_for_field(field)))

    def create_sequence(self, field):
        seq_ctx = self._create_sequence(field)
        if seq_ctx is not None:
            self.database.execute(seq_ctx)

    def _drop_sequence(self, field):
        self._check_sequences(field)
        if self.database.sequence_exists(field.sequence):
            return (self
                    ._create_context()
                    .literal('DROP SEQUENCE ')
                    .sql(self._sequence_for_field(field)))

    def drop_sequence(self, field):
        seq_ctx = self._drop_sequence(field)
        if seq_ctx is not None:
            self.database.execute(seq_ctx)

    def _create_foreign_key(self, field):
        name = 'fk_%s_%s_refs_%s' % (field.model._meta.table_name,
                                     field.column_name,
                                     field.rel_model._meta.table_name)
        return (self
                ._create_context()
                .literal('ALTER TABLE ')
                .sql(field.model)
                .literal(' ADD CONSTRAINT ')
                .sql(Entity(_truncate_constraint_name(name)))
                .literal(' ')
                .sql(field.foreign_key_constraint()))

    def create_foreign_key(self, field):
        self.database.execute(self._create_foreign_key(field))

    def create_sequences(self):
        if self.database.sequences:
            for field in self.model._meta.sorted_fields:
                if field.sequence:
                    self.create_sequence(field)

    def create_all(self, safe=True, **table_options):
        """Create sequences, the table, and its indexes."""
        self.create_sequences()
        self.create_table(safe, **table_options)
        self.create_indexes(safe=safe)

    def drop_sequences(self):
        if self.database.sequences:
            for field in self.model._meta.sorted_fields:
                if field.sequence:
                    self.drop_sequence(field)

    def drop_all(self, safe=True, drop_sequences=True, **options):
        self.drop_table(safe, **options)
        if drop_sequences:
            self.drop_sequences()


class Metadata(object):
    """Per-model registry of fields, columns, keys, relations and options.

    One instance lives at ``Model._meta``; it maintains the sorted field
    list, default-value caches, and forward/backward FK reference maps.
    """
    def __init__(self, model, database=None, table_name=None, indexes=None,
                 primary_key=None, constraints=None, schema=None,
                 only_save_dirty=False, depends_on=None, options=None,
                 db_table=None, table_function=None, table_settings=None,
                 without_rowid=False, temporary=False, legacy_table_names=True,
                 **kwargs):
        if db_table is not None:
            __deprecated__('"db_table" has been deprecated in favor of '
                           '"table_name" for Models.')
            table_name = db_table
        self.model = model
        self.database = database

        self.fields = {}
        self.columns = {}
        self.combined = {}

        self._sorted_field_list = _SortedFieldList()
        self.sorted_fields = []
        self.sorted_field_names = []

        self.defaults = {}
        self._default_by_name = {}
        self._default_dict = {}
        self._default_callables = {}
        self._default_callable_list = []

        self.name = model.__name__.lower()
        self.table_function = table_function
        self.legacy_table_names = legacy_table_names
        if not table_name:
            table_name = (self.table_function(model)
                          if self.table_function
                          else self.make_table_name())
        self.table_name = table_name
        self._table = None

        self.indexes = list(indexes) if indexes else []
        self.constraints = constraints
        self._schema = schema
        self.primary_key = primary_key
        self.composite_key = self.auto_increment = None
        self.only_save_dirty = only_save_dirty
        self.depends_on = depends_on
        self.table_settings = table_settings
        self.without_rowid = without_rowid
        self.temporary = temporary

        self.refs = {}
        self.backrefs = {}
        self.model_refs = collections.defaultdict(list)
        self.model_backrefs = collections.defaultdict(list)
        self.manytomany = {}

        self.options = options or {}
        for key, value in kwargs.items():
            setattr(self, key, value)
        self._additional_keys = set(kwargs.keys())

        # Allow objects to register hooks that are called if the model is bound
        # to a different database. For example, BlobField uses a different
        # Python data-type depending on the db driver / python version. When
        # the database changes, we need to update any BlobField so they can use
        # the appropriate data-type.
        self._db_hooks = []

    def make_table_name(self):
        if self.legacy_table_names:
            # NOTE(review): non-raw pattern '[^\w]+' relies on Python keeping
            # unknown escapes literal; a raw string r'[^\w]+' would be safer
            # against DeprecationWarning in newer Pythons — confirm.
            return re.sub('[^\w]+', '_', self.name)
        return make_snake_case(self.model.__name__)

    def model_graph(self, refs=True, backrefs=True, depth_first=True):
        """Traverse the FK graph from this model, returning
        (fk, model, is_backref) triples (first entry has fk=None)."""
        if not refs and not backrefs:
            raise ValueError('One of `refs` or `backrefs` must be True.')

        accum = [(None, self.model, None)]
        seen = set()
        queue = collections.deque((self,))
        # pop() == DFS, popleft() == BFS.
        method = queue.pop if depth_first else queue.popleft

        while queue:
            curr = method()
            if curr in seen: continue
            seen.add(curr)

            if refs:
                for fk, model in curr.refs.items():
                    accum.append((fk, model, False))
                    queue.append(model._meta)
            if backrefs:
                for fk, model in curr.backrefs.items():
                    accum.append((fk, model, True))
                    queue.append(model._meta)

        return accum

    def add_ref(self, field):
        # Register a forward FK here and the mirror backref on the target.
        rel = field.rel_model
        self.refs[field] = rel
        self.model_refs[rel].append(field)
        rel._meta.backrefs[field] = self.model
        rel._meta.model_backrefs[self.model].append(field)

    def remove_ref(self, field):
        rel = field.rel_model
        del self.refs[field]
        self.model_refs[rel].remove(field)
        del rel._meta.backrefs[field]
        rel._meta.model_backrefs[self.model].remove(field)

    def add_manytomany(self, field):
        self.manytomany[field.name] = field

    def remove_manytomany(self, field):
        del self.manytomany[field.name]

    @property
    def table(self):
        # Lazily build (and cache) the Table helper; invalidated via
        # ``del self.table`` whenever fields/schema/database change.
        if self._table is None:
            self._table = Table(
                self.table_name,
                [field.column_name for field in self.sorted_fields],
                schema=self.schema,
                _model=self.model,
                _database=self.database)
        return self._table

    @table.setter
    def table(self, value):
        raise AttributeError('Cannot set the "table".')

    @table.deleter
    def table(self):
        self._table = None

    @property
    def schema(self):
        return self._schema

    @schema.setter
    def schema(self, value):
        self._schema = value
        del self.table

    @property
    def entity(self):
        if self._schema:
            return Entity(self._schema, self.table_name)
        else:
            return Entity(self.table_name)

    def _update_sorted_fields(self):
        self.sorted_fields = list(self._sorted_field_list)
        self.sorted_field_names = [f.name for f in self.sorted_fields]

    def get_rel_for_model(self, model):
        if isinstance(model, ModelAlias):
            model = model.model
        forwardrefs = self.model_refs.get(model, [])
        backrefs = self.model_backrefs.get(model, [])
        return (forwardrefs, backrefs)

    def add_field(self, field_name, field, set_attribute=True):
        """Bind ``field`` to the model under ``field_name`` and register it
        in all lookup structures (replacing any existing field)."""
        if field_name in self.fields:
            self.remove_field(field_name)
        elif field_name in self.manytomany:
            self.remove_manytomany(self.manytomany[field_name])

        if not isinstance(field, MetaField):
            del self.table
            field.bind(self.model, field_name, set_attribute)
            self.fields[field.name] = field
            self.columns[field.column_name] = field
            self.combined[field.name] = field
            self.combined[field.column_name] = field

            self._sorted_field_list.insert(field)
            self._update_sorted_fields()

            if field.default is not None:
                # This optimization helps speed up model instance construction.
                self.defaults[field] = field.default
                if callable_(field.default):
                    self._default_callables[field] = field.default
                    self._default_callable_list.append((field.name,
                                                        field.default))
                else:
                    self._default_dict[field] = field.default
                    self._default_by_name[field.name] = field.default
        else:
            field.bind(self.model, field_name, set_attribute)

        if isinstance(field, ForeignKeyField):
            self.add_ref(field)
        elif isinstance(field, ManyToManyField) and field.name:
            self.add_manytomany(field)

    def remove_field(self, field_name):
        """Unregister a field from every lookup structure (no-op if absent)."""
        if field_name not in self.fields:
            return

        del self.table
        original = self.fields.pop(field_name)
        del self.columns[original.column_name]
        del self.combined[field_name]
        try:
            del self.combined[original.column_name]
        except KeyError:
            pass
        self._sorted_field_list.remove(original)
        self._update_sorted_fields()

        if original.default is not None:
            del self.defaults[original]
            if self._default_callables.pop(original, None):
                for i, (name, _) in enumerate(self._default_callable_list):
                    if name == field_name:
                        self._default_callable_list.pop(i)
                        break
            else:
                self._default_dict.pop(original, None)
                self._default_by_name.pop(original.name, None)

        if isinstance(original, ForeignKeyField):
            self.remove_ref(original)

    def set_primary_key(self, name, field):
        self.composite_key = isinstance(field, CompositeKey)
        self.add_field(name, field)
        self.primary_key = field
        self.auto_increment = (
            field.auto_increment or
            bool(field.sequence))

    def get_primary_keys(self):
        if self.composite_key:
            return tuple([self.fields[field_name]
                          for field_name in self.primary_key.field_names])
        else:
            return (self.primary_key,) if self.primary_key is not False else ()

    def get_default_dict(self):
        # Callable defaults are evaluated fresh for each new instance.
        dd = self._default_by_name.copy()
        for field_name, default in self._default_callable_list:
            dd[field_name] = default()
        return dd

    def fields_to_index(self):
        """Collect ModelIndex objects from per-field flags and Meta.indexes."""
        indexes = []
        for f in self.sorted_fields:
            if f.primary_key:
                continue
            if f.index or f.unique:
                indexes.append(ModelIndex(self.model, (f,), unique=f.unique,
                                          using=f.index_type))

        for index_obj in self.indexes:
            if isinstance(index_obj, Node):
                indexes.append(index_obj)
            elif isinstance(index_obj, (list, tuple)):
                index_parts, unique = index_obj
                fields = []
                for part in index_parts:
                    if isinstance(part, basestring):
                        fields.append(self.combined[part])
                    elif isinstance(part, Node):
                        fields.append(part)
                    else:
                        raise ValueError('Expected either a field name or a '
                                         'subclass of Node. Got: %s' % part)
                indexes.append(ModelIndex(self.model, fields, unique=unique))

        return indexes

    def set_database(self, database):
        self.database = database
        self.model._schema._database = database
        del self.table

        # Apply any hooks that have been registered.
        for hook in self._db_hooks:
            hook(database)

    def set_table_name(self, table_name):
        self.table_name = table_name
        del self.table


class SubclassAwareMetadata(Metadata):
    """Metadata variant that tracks every model subclass ever created."""
    # NOTE: shared, class-level list — intentionally accumulates across all
    # models using this metadata class.
    models = []

    def __init__(self, model, *args, **kwargs):
        super(SubclassAwareMetadata, self).__init__(model, *args, **kwargs)
        self.models.append(model)

    def map_models(self, fn):
        for model in self.models:
            fn(model)


class DoesNotExist(Exception): pass


class ModelBase(type):
    """Metaclass that assembles Model subclasses: collects Meta options,
    inherits fields from bases, chooses the primary key, builds the
    per-class Metadata/SchemaManager, and resolves deferred FKs."""
    inheritable = set(['constraints', 'database', 'indexes', 'primary_key',
                       'options', 'schema', 'table_function', 'temporary',
                       'only_save_dirty', 'legacy_table_names',
                       'table_settings'])

    def __new__(cls, name, bases, attrs):
        if name == MODEL_BASE or bases[0].__name__ == MODEL_BASE:
            return super(ModelBase, cls).__new__(cls, name, bases, attrs)

        meta_options = {}
        meta = attrs.pop('Meta', None)
        if meta:
            for k, v in meta.__dict__.items():
                if not k.startswith('_'):
                    meta_options[k] = v

        pk = getattr(meta, 'primary_key', None)
        pk_name = parent_pk = None

        # Inherit any field descriptors by deep copying the underlying field
        # into the attrs of the new model, additionally see if the bases define
        # inheritable model options and swipe them.
        for b in bases:
            if not hasattr(b, '_meta'):
                continue

            base_meta = b._meta
            if parent_pk is None:
                parent_pk = deepcopy(base_meta.primary_key)
            all_inheritable = cls.inheritable | base_meta._additional_keys
            for k in base_meta.__dict__:
                if k in all_inheritable and k not in meta_options:
                    meta_options[k] = base_meta.__dict__[k]
            meta_options.setdefault('schema', base_meta.schema)

            for (k, v) in b.__dict__.items():
                if k in attrs: continue

                if isinstance(v, FieldAccessor) and not v.field.primary_key:
                    attrs[k] = deepcopy(v.field)

        sopts = meta_options.pop('schema_options', None) or {}
        Meta = meta_options.get('model_metadata_class', Metadata)
        Schema = meta_options.get('schema_manager_class', SchemaManager)

        # Construct the new class.
        cls = super(ModelBase, cls).__new__(cls, name, bases, attrs)
        cls.__data__ = cls.__rel__ = None

        cls._meta = Meta(cls, **meta_options)
        cls._schema = Schema(cls, **sopts)

        fields = []
        for key, value in cls.__dict__.items():
            if isinstance(value, Field):
                if value.primary_key and pk:
                    raise ValueError('over-determined primary key %s.' % name)
                elif value.primary_key:
                    pk, pk_name = value, key
                else:
                    fields.append((key, value))

        if pk is None:
            if parent_pk is not False:
                # Inherit the parent's PK, or synthesize an AutoField 'id'.
                pk, pk_name = ((parent_pk, parent_pk.name)
                               if parent_pk is not None else
                               (AutoField(), 'id'))
            else:
                pk = False
        elif isinstance(pk, CompositeKey):
            pk_name = '__composite_key__'
            cls._meta.composite_key = True

        if pk is not False:
            cls._meta.set_primary_key(pk_name, pk)

        for name, field in fields:
            cls._meta.add_field(name, field)

        # Create a repr and error class before finalizing.
        if hasattr(cls, '__str__') and '__repr__' not in attrs:
            setattr(cls, '__repr__', lambda self: '<%s: %s>' % (
                cls.__name__, self.__str__()))

        exc_name = '%sDoesNotExist' % cls.__name__
        exc_attrs = {'__module__': cls.__module__}
        exception_class = type(exc_name, (DoesNotExist,), exc_attrs)
        cls.DoesNotExist = exception_class

        # Call validation hook, allowing additional model validation.
        cls.validate_model()
        DeferredForeignKey.resolve(cls)
        return cls

    def __repr__(self):
        # NOTE(review): format string looks truncated — upstream peewee uses
        # '<Model: %s>' here; as written this always returns ''. Confirm
        # against the original source before relying on it.
        return '' % self.__name__

    def __iter__(self):
        return iter(self.select())

    def __getitem__(self, key):
        return self.get_by_id(key)

    def __setitem__(self, key, value):
        self.set_by_id(key, value)

    def __delitem__(self, key):
        self.delete_by_id(key)

    def __contains__(self, key):
        try:
            self.get_by_id(key)
        except self.DoesNotExist:
            return False
        else:
            return True

    def __len__(self):
        return self.select().count()
    def __bool__(self): return True
    __nonzero__ = __bool__  # Python 2.
class _BoundModelsContext(_callable_context_manager):
    """Context manager that temporarily binds a set of models to a database
    and restores each model's original database on exit."""
    def __init__(self, models, database, bind_refs, bind_backrefs):
        self.models = models
        self.database = database
        self.bind_refs = bind_refs
        self.bind_backrefs = bind_backrefs

    def __enter__(self):
        self._orig_database = []
        for model in self.models:
            # Remember each model's current database so __exit__ can restore.
            self._orig_database.append(model._meta.database)
            model.bind(self.database, self.bind_refs, self.bind_backrefs)
        return self.models

    def __exit__(self, exc_type, exc_val, exc_tb):
        for model, db in zip(self.models, self._orig_database):
            model.bind(db, self.bind_refs, self.bind_backrefs)


class Model(with_metaclass(ModelBase, Node)):
    """Base class for user-defined database models (active-record style).

    Class-level API builds queries (select/insert/update/delete); instances
    carry row data in ``__data__``, track modified fields in ``_dirty``, and
    cache related objects in ``__rel__``.
    """
    def __init__(self, *args, **kwargs):
        if kwargs.pop('__no_default__', None):
            self.__data__ = {}
        else:
            self.__data__ = self._meta.get_default_dict()
        # Newly-constructed instances are fully dirty until saved.
        self._dirty = set(self.__data__)
        self.__rel__ = {}

        for k in kwargs:
            setattr(self, k, kwargs[k])

    def __str__(self):
        return str(self._pk) if self._meta.primary_key is not False else 'n/a'

    @classmethod
    def validate_model(cls):
        # Hook for subclasses; called by the metaclass after construction.
        pass

    @classmethod
    def alias(cls, alias=None):
        return ModelAlias(cls, alias)

    @classmethod
    def select(cls, *fields):
        is_default = not fields
        if not fields:
            fields = cls._meta.sorted_fields
        return ModelSelect(cls, fields, is_default=is_default)

    @classmethod
    def _normalize_data(cls, data, kwargs):
        """Map a dict/kwargs of field-names-or-Field-objects to Field keys."""
        normalized = {}
        if data:
            if not isinstance(data, dict):
                if kwargs:
                    raise ValueError('Data cannot be mixed with keyword '
                                     'arguments: %s' % data)
                return data
            for key in data:
                try:
                    field = (key if isinstance(key, Field)
                             else cls._meta.combined[key])
                except KeyError:
                    raise ValueError('Unrecognized field name: "%s" in %s.' %
                                     (key, data))
                normalized[field] = data[key]
        if kwargs:
            for key in kwargs:
                try:
                    normalized[cls._meta.combined[key]] = kwargs[key]
                except KeyError:
                    normalized[getattr(cls, key)] = kwargs[key]
        return normalized

    @classmethod
    def update(cls, __data=None, **update):
        return ModelUpdate(cls, cls._normalize_data(__data, update))

    @classmethod
    def insert(cls, __data=None, **insert):
        return ModelInsert(cls, cls._normalize_data(__data, insert))

    @classmethod
    def insert_many(cls, rows, fields=None):
        return ModelInsert(cls, insert=rows, columns=fields)

    @classmethod
    def insert_from(cls, query, fields):
        columns = [getattr(cls, field) if isinstance(field, basestring)
                   else field for field in fields]
        return ModelInsert(cls, insert=query, columns=columns)

    @classmethod
    def replace(cls, __data=None, **insert):
        return cls.insert(__data, **insert).on_conflict('REPLACE')

    @classmethod
    def replace_many(cls, rows, fields=None):
        return (cls
                .insert_many(rows=rows, fields=fields)
                .on_conflict('REPLACE'))

    @classmethod
    def raw(cls, sql, *params):
        return ModelRaw(cls, sql, params)

    @classmethod
    def delete(cls):
        return ModelDelete(cls)

    @classmethod
    def create(cls, **query):
        inst = cls(**query)
        inst.save(force_insert=True)
        return inst

    @classmethod
    def bulk_create(cls, model_list, batch_size=None):
        """INSERT a list of unsaved instances, optionally in batches; when
        the DB supports RETURNING, write generated PKs back onto them."""
        if batch_size is not None:
            batches = chunked(model_list, batch_size)
        else:
            batches = [model_list]

        field_names = list(cls._meta.sorted_field_names)
        if cls._meta.auto_increment:
            # Let the database assign auto-incrementing PKs.
            pk_name = cls._meta.primary_key.name
            field_names.remove(pk_name)

        if cls._meta.database.returning_clause and \
           cls._meta.primary_key is not False:
            pk_fields = cls._meta.get_primary_keys()
        else:
            pk_fields = None

        fields = [cls._meta.fields[field_name] for field_name in field_names]
        for batch in batches:
            accum = ([getattr(model, f) for f in field_names]
                     for model in batch)
            res = cls.insert_many(accum, fields=fields).execute()
            if pk_fields and res is not None:
                for row, model in zip(res, batch):
                    for (pk_field, obj_id) in zip(pk_fields, row):
                        setattr(model, pk_field.name, obj_id)

    @classmethod
    def bulk_update(cls, model_list, fields, batch_size=None):
        """UPDATE the given fields of many instances using one CASE-based
        UPDATE per batch; returns the total number of rows updated."""
        if isinstance(cls._meta.primary_key, CompositeKey):
            raise ValueError('bulk_update() is not supported for models with '
                             'a composite primary key.')

        # First normalize list of fields so all are field instances.
        fields = [cls._meta.fields[f] if isinstance(f, basestring) else f
                  for f in fields]
        # Now collect list of attribute names to use for values.
        attrs = [field.object_id_name if isinstance(field, ForeignKeyField)
                 else field.name for field in fields]

        if batch_size is not None:
            batches = chunked(model_list, batch_size)
        else:
            batches = [model_list]

        n = 0
        for batch in batches:
            id_list = [model._pk for model in batch]
            update = {}
            for field, attr in zip(fields, attrs):
                accum = []
                for model in batch:
                    value = getattr(model, attr)
                    if not isinstance(value, Node):
                        value = field.to_value(value)
                    accum.append((model._pk, value))
                case = Case(cls._meta.primary_key, accum)
                update[field] = case

            n += (cls.update(update)
                  .where(cls._meta.primary_key.in_(id_list))
                  .execute())
        return n

    @classmethod
    def noop(cls):
        return NoopModelSelect(cls, ())

    @classmethod
    def get(cls, *query, **filters):
        sq = cls.select()
        if query:
            # Handle simple lookup using just the primary key.
            if len(query) == 1 and isinstance(query[0], int):
                sq = sq.where(cls._meta.primary_key == query[0])
            else:
                sq = sq.where(*query)
        if filters:
            sq = sq.filter(**filters)
        return sq.get()

    @classmethod
    def get_or_none(cls, *query, **filters):
        try:
            return cls.get(*query, **filters)
        except DoesNotExist:
            pass

    @classmethod
    def get_by_id(cls, pk):
        return cls.get(cls._meta.primary_key == pk)

    @classmethod
    def set_by_id(cls, key, value):
        if key is None:
            return cls.insert(value).execute()
        else:
            return (cls.update(value)
                    .where(cls._meta.primary_key == key).execute())

    @classmethod
    def delete_by_id(cls, pk):
        return cls.delete().where(cls._meta.primary_key == pk).execute()

    @classmethod
    def get_or_create(cls, **kwargs):
        """Fetch a row matching ``kwargs`` or create it (applying
        ``defaults``); returns (instance, created)."""
        defaults = kwargs.pop('defaults', {})
        query = cls.select()
        for field, value in kwargs.items():
            query = query.where(getattr(cls, field) == value)

        try:
            return query.get(), False
        except cls.DoesNotExist:
            try:
                if defaults:
                    kwargs.update(defaults)
                with cls._meta.database.atomic():
                    return cls.create(**kwargs), True
            except IntegrityError as exc:
                # Lost a race with a concurrent insert: re-fetch the row.
                try:
                    return query.get(), False
                except cls.DoesNotExist:
                    raise exc

    @classmethod
    def filter(cls, *dq_nodes, **filters):
        return cls.select().filter(*dq_nodes, **filters)

    def get_id(self):
        # Using getattr(self, pk-name) could accidentally trigger a query if
        # the primary-key is a foreign-key. So we use the safe_name attribute,
        # which defaults to the field-name, but will be the object_id_name for
        # foreign-key fields.
        if self._meta.primary_key is not False:
            return getattr(self, self._meta.primary_key.safe_name)

    _pk = property(get_id)

    @_pk.setter
    def _pk(self, value):
        setattr(self, self._meta.primary_key.name, value)

    def _pk_expr(self):
        return self._meta.primary_key == self._pk

    def _prune_fields(self, field_dict, only):
        # Restrict field_dict to the names listed in ``only``.
        new_data = {}
        for field in only:
            if isinstance(field, basestring):
                field = self._meta.combined[field]
            if field.name in field_dict:
                new_data[field.name] = field_dict[field.name]
        return new_data

    def _populate_unsaved_relations(self, field_dict):
        # If a related object was assigned but never saved, its FK value is
        # still None; pull the id through the accessor before saving.
        for foreign_key_field in self._meta.refs:
            foreign_key = foreign_key_field.name
            conditions = (
                foreign_key in field_dict and
                field_dict[foreign_key] is None and
                self.__rel__.get(foreign_key) is not None)
            if conditions:
                setattr(self, foreign_key, getattr(self, foreign_key))
                field_dict[foreign_key] = self.__data__[foreign_key]

    def save(self, force_insert=False, only=None):
        """Persist the instance (UPDATE if it has a PK and force_insert is
        False, otherwise INSERT); returns rows affected, or False when
        only_save_dirty found nothing to write."""
        field_dict = self.__data__.copy()
        if self._meta.primary_key is not False:
            pk_field = self._meta.primary_key
            pk_value = self._pk
        else:
            pk_field = pk_value = None
        if only:
            field_dict = self._prune_fields(field_dict, only)
        elif self._meta.only_save_dirty and not force_insert:
            field_dict = self._prune_fields(field_dict, self.dirty_fields)
            if not field_dict:
                self._dirty.clear()
                return False

        self._populate_unsaved_relations(field_dict)
        rows = 1

        if pk_value is not None and not force_insert:
            # UPDATE path: never include the PK column(s) in the SET clause.
            if self._meta.composite_key:
                for pk_part_name in pk_field.field_names:
                    field_dict.pop(pk_part_name, None)
            else:
                field_dict.pop(pk_field.name, None)
            if not field_dict:
                raise ValueError('no data to save!')
            rows = self.update(**field_dict).where(self._pk_expr()).execute()
        elif pk_field is not None:
            pk = self.insert(**field_dict).execute()
            if pk is not None and (self._meta.auto_increment or
                                   pk_value is None):
                self._pk = pk
        else:
            self.insert(**field_dict).execute()

        self._dirty.clear()
        return rows

    def is_dirty(self):
        return bool(self._dirty)

    @property
    def dirty_fields(self):
        return [f for f in self._meta.sorted_fields if f.name in self._dirty]

    def dependencies(self, search_nullable=False):
        """Yield (where-expression, fk) pairs for every row that depends on
        this instance, walking the backref graph depth-first."""
        model_class = type(self)
        stack = [(type(self), None)]
        seen = set()

        while stack:
            klass, query = stack.pop()
            if klass in seen:
                continue
            seen.add(klass)
            for fk, rel_model in klass._meta.backrefs.items():
                if rel_model is model_class or query is None:
                    node = (fk == self.__data__[fk.rel_field.name])
                else:
                    node = fk << query
                subquery = (rel_model.select(rel_model._meta.primary_key)
                            .where(node))
                if not fk.null or search_nullable:
                    stack.append((rel_model, subquery))
                yield (node, fk)

    def delete_instance(self, recursive=False, delete_nullable=False):
        """DELETE this row; with recursive=True, first null-out or delete
        dependent rows (innermost dependencies handled first)."""
        if recursive:
            dependencies = self.dependencies(delete_nullable)
            for query, fk in reversed(list(dependencies)):
                model = fk.model
                if fk.null and not delete_nullable:
                    model.update(**{fk.name: None}).where(query).execute()
                else:
                    model.delete().where(query).execute()
        return type(self).delete().where(self._pk_expr()).execute()

    def __hash__(self):
        return hash((self.__class__, self._pk))

    def __eq__(self, other):
        return (
            other.__class__ == self.__class__ and
            self._pk is not None and
            self._pk == other._pk)

    def __ne__(self, other):
        return not self == other

    def __sql__(self, ctx):
        # A model instance in a query renders as its primary-key value.
        return ctx.sql(Value(getattr(self, self._meta.primary_key.name),
                             converter=self._meta.primary_key.db_value))

    @classmethod
    def bind(cls, database, bind_refs=True, bind_backrefs=True):
        """Bind this model (and optionally related models) to ``database``;
        returns True if the database actually changed."""
        is_different = cls._meta.database is not database
        cls._meta.set_database(database)
        if bind_refs or bind_backrefs:
            G = cls._meta.model_graph(refs=bind_refs, backrefs=bind_backrefs)
            for _, model, is_backref in G:
                model._meta.set_database(database)
        return is_different

    @classmethod
    def bind_ctx(cls, database, bind_refs=True, bind_backrefs=True):
        return _BoundModelsContext((cls,), database, bind_refs, bind_backrefs)

    @classmethod
    def table_exists(cls):
        M = cls._meta
        return cls._schema.database.table_exists(M.table.__name__, M.schema)

    @classmethod
    def create_table(cls, safe=True, **options):
        if 'fail_silently' in options:
            __deprecated__('"fail_silently" has been deprecated in favor of '
                           '"safe" for the create_table() method.')
            safe = options.pop('fail_silently')

        # For DBs without IF NOT EXISTS support, emulate it with a
        # table-existence check.
        if safe and not cls._schema.database.safe_create_index \
           and cls.table_exists():
            return
        if cls._meta.temporary:
            options.setdefault('temporary', cls._meta.temporary)
        cls._schema.create_all(safe, **options)

    @classmethod
    def drop_table(cls, safe=True, drop_sequences=True, **options):
        if safe and not cls._schema.database.safe_drop_index \
           and not cls.table_exists():
            return
        if cls._meta.temporary:
            options.setdefault('temporary', cls._meta.temporary)
        cls._schema.drop_all(safe, drop_sequences, **options)

    @classmethod
    def truncate_table(cls, **options):
        cls._schema.truncate_table(**options)

    @classmethod
    def index(cls, *fields, **kwargs):
        return ModelIndex(cls, fields, **kwargs)

    @classmethod
    def add_index(cls, *fields, **kwargs):
        if len(fields) == 1 and isinstance(fields[0], (SQL, Index)):
            cls._meta.indexes.append(fields[0])
        else:
            cls._meta.indexes.append(ModelIndex(cls, fields, **kwargs))


class ModelAlias(Node):
    """Provide a separate reference to a model in a query."""
    def __init__(self, model, alias=None):
        self.__dict__['model'] = model
        self.__dict__['alias'] = alias

    def __getattr__(self, attr):
        # Hack to work-around the fact that properties or other objects
        # implementing the descriptor protocol (on the model being aliased),
        # will not work correctly when we use getattr(). So we explicitly pass
        # the model alias to the descriptor's getter.
+ try: + obj = self.model.__dict__[attr] + except KeyError: + pass + else: + if isinstance(obj, ModelDescriptor): + return obj.__get__(None, self) + + model_attr = getattr(self.model, attr) + if isinstance(model_attr, Field): + self.__dict__[attr] = FieldAlias.create(self, model_attr) + return self.__dict__[attr] + return model_attr + + def __setattr__(self, attr, value): + raise AttributeError('Cannot set attributes on model aliases.') + + def get_field_aliases(self): + return [getattr(self, n) for n in self.model._meta.sorted_field_names] + + def select(self, *selection): + if not selection: + selection = self.get_field_aliases() + return ModelSelect(self, selection) + + def __call__(self, **kwargs): + return self.model(**kwargs) + + def __sql__(self, ctx): + if ctx.scope == SCOPE_VALUES: + # Return the quoted table name. + return ctx.sql(self.model) + + if self.alias: + ctx.alias_manager[self] = self.alias + + if ctx.scope == SCOPE_SOURCE: + # Define the table and its alias. + return (ctx + .sql(self.model._meta.entity) + .literal(' AS ') + .sql(Entity(ctx.alias_manager[self]))) + else: + # Refer to the table using the alias. 
+ return ctx.sql(Entity(ctx.alias_manager[self])) + + +class FieldAlias(Field): + def __init__(self, source, field): + self.source = source + self.model = source.model + self.field = field + + @classmethod + def create(cls, source, field): + class _FieldAlias(cls, type(field)): + pass + return _FieldAlias(source, field) + + def clone(self): + return FieldAlias(self.source, self.field) + + def adapt(self, value): return self.field.adapt(value) + def python_value(self, value): return self.field.python_value(value) + def db_value(self, value): return self.field.db_value(value) + def __getattr__(self, attr): + return self.source if attr == 'model' else getattr(self.field, attr) + + def __sql__(self, ctx): + return ctx.sql(Column(self.source, self.field.column_name)) + + +def sort_models(models): + models = set(models) + seen = set() + ordering = [] + def dfs(model): + if model in models and model not in seen: + seen.add(model) + for foreign_key, rel_model in model._meta.refs.items(): + # Do not depth-first search deferred foreign-keys as this can + # cause tables to be created in the incorrect order. 
+ if not foreign_key.deferred: + dfs(rel_model) + if model._meta.depends_on: + for dependency in model._meta.depends_on: + dfs(dependency) + ordering.append(model) + + names = lambda m: (m._meta.name, m._meta.table_name) + for m in sorted(models, key=names): + dfs(m) + return ordering + + +class _ModelQueryHelper(object): + default_row_type = ROW.MODEL + + def __init__(self, *args, **kwargs): + super(_ModelQueryHelper, self).__init__(*args, **kwargs) + if not self._database: + self._database = self.model._meta.database + + @Node.copy + def objects(self, constructor=None): + self._row_type = ROW.CONSTRUCTOR + self._constructor = self.model if constructor is None else constructor + + def _get_cursor_wrapper(self, cursor): + row_type = self._row_type or self.default_row_type + if row_type == ROW.MODEL: + return self._get_model_cursor_wrapper(cursor) + elif row_type == ROW.DICT: + return ModelDictCursorWrapper(cursor, self.model, self._returning) + elif row_type == ROW.TUPLE: + return ModelTupleCursorWrapper(cursor, self.model, self._returning) + elif row_type == ROW.NAMED_TUPLE: + return ModelNamedTupleCursorWrapper(cursor, self.model, + self._returning) + elif row_type == ROW.CONSTRUCTOR: + return ModelObjectCursorWrapper(cursor, self.model, + self._returning, self._constructor) + else: + raise ValueError('Unrecognized row type: "%s".' 
% row_type) + + def _get_model_cursor_wrapper(self, cursor): + return ModelObjectCursorWrapper(cursor, self.model, [], self.model) + + +class ModelRaw(_ModelQueryHelper, RawQuery): + def __init__(self, model, sql, params, **kwargs): + self.model = model + self._returning = () + super(ModelRaw, self).__init__(sql=sql, params=params, **kwargs) + + def get(self): + try: + return self.execute()[0] + except IndexError: + sql, params = self.sql() + raise self.model.DoesNotExist('%s instance matching query does ' + 'not exist:\nSQL: %s\nParams: %s' % + (self.model, sql, params)) + + +class BaseModelSelect(_ModelQueryHelper): + def union_all(self, rhs): + return ModelCompoundSelectQuery(self.model, self, 'UNION ALL', rhs) + __add__ = union_all + + def union(self, rhs): + return ModelCompoundSelectQuery(self.model, self, 'UNION', rhs) + __or__ = union + + def intersect(self, rhs): + return ModelCompoundSelectQuery(self.model, self, 'INTERSECT', rhs) + __and__ = intersect + + def except_(self, rhs): + return ModelCompoundSelectQuery(self.model, self, 'EXCEPT', rhs) + __sub__ = except_ + + def __iter__(self): + if not self._cursor_wrapper: + self.execute() + return iter(self._cursor_wrapper) + + def prefetch(self, *subqueries): + return prefetch(self, *subqueries) + + def get(self, database=None): + clone = self.paginate(1, 1) + clone._cursor_wrapper = None + try: + return clone.execute(database)[0] + except IndexError: + sql, params = clone.sql() + raise self.model.DoesNotExist('%s instance matching query does ' + 'not exist:\nSQL: %s\nParams: %s' % + (clone.model, sql, params)) + + @Node.copy + def group_by(self, *columns): + grouping = [] + for column in columns: + if is_model(column): + grouping.extend(column._meta.sorted_fields) + elif isinstance(column, Table): + if not column._columns: + raise ValueError('Cannot pass a table to group_by() that ' + 'does not have columns explicitly ' + 'declared.') + grouping.extend([getattr(column, col_name) + for col_name in 
column._columns]) + else: + grouping.append(column) + self._group_by = grouping + + +class ModelCompoundSelectQuery(BaseModelSelect, CompoundSelectQuery): + def __init__(self, model, *args, **kwargs): + self.model = model + super(ModelCompoundSelectQuery, self).__init__(*args, **kwargs) + + def _get_model_cursor_wrapper(self, cursor): + return self.lhs._get_model_cursor_wrapper(cursor) + + +def _normalize_model_select(fields_or_models): + fields = [] + for fm in fields_or_models: + if is_model(fm): + fields.extend(fm._meta.sorted_fields) + elif isinstance(fm, ModelAlias): + fields.extend(fm.get_field_aliases()) + elif isinstance(fm, Table) and fm._columns: + fields.extend([getattr(fm, col) for col in fm._columns]) + else: + fields.append(fm) + return fields + + +class ModelSelect(BaseModelSelect, Select): + def __init__(self, model, fields_or_models, is_default=False): + self.model = self._join_ctx = model + self._joins = {} + self._is_default = is_default + fields = _normalize_model_select(fields_or_models) + super(ModelSelect, self).__init__([model], fields) + + def clone(self): + clone = super(ModelSelect, self).clone() + if clone._joins: + clone._joins = dict(clone._joins) + return clone + + def select(self, *fields_or_models): + if fields_or_models or not self._is_default: + self._is_default = False + fields = _normalize_model_select(fields_or_models) + return super(ModelSelect, self).select(*fields) + return self + + def switch(self, ctx=None): + self._join_ctx = self.model if ctx is None else ctx + return self + + def _get_model(self, src): + if is_model(src): + return src, True + elif isinstance(src, Table) and src._model: + return src._model, False + elif isinstance(src, ModelAlias): + return src.model, False + elif isinstance(src, ModelSelect): + return src.model, False + return None, False + + def _normalize_join(self, src, dest, on, attr): + # Allow "on" expression to have an alias that determines the + # destination attribute for the joined data. 
+ on_alias = isinstance(on, Alias) + if on_alias: + attr = attr or on._alias + on = on.alias() + + # Obtain references to the source and destination models being joined. + src_model, src_is_model = self._get_model(src) + dest_model, dest_is_model = self._get_model(dest) + + if src_model and dest_model: + self._join_ctx = dest + constructor = dest_model + + # In the case where the "on" clause is a Column or Field, we will + # convert that field into the appropriate predicate expression. + if not (src_is_model and dest_is_model) and isinstance(on, Column): + if on.source is src: + to_field = src_model._meta.columns[on.name] + elif on.source is dest: + to_field = dest_model._meta.columns[on.name] + else: + raise AttributeError('"on" clause Column %s does not ' + 'belong to %s or %s.' % + (on, src_model, dest_model)) + on = None + elif isinstance(on, Field): + to_field = on + on = None + else: + to_field = None + + fk_field, is_backref = self._generate_on_clause( + src_model, dest_model, to_field, on) + + if on is None: + src_attr = 'name' if src_is_model else 'column_name' + dest_attr = 'name' if dest_is_model else 'column_name' + if is_backref: + lhs = getattr(dest, getattr(fk_field, dest_attr)) + rhs = getattr(src, getattr(fk_field.rel_field, src_attr)) + else: + lhs = getattr(src, getattr(fk_field, src_attr)) + rhs = getattr(dest, getattr(fk_field.rel_field, dest_attr)) + on = (lhs == rhs) + + if not attr: + if fk_field is not None and not is_backref: + attr = fk_field.name + else: + attr = dest_model._meta.name + elif on_alias and fk_field is not None and \ + attr == fk_field.object_id_name and not is_backref: + raise ValueError('Cannot assign join alias to "%s", as this ' + 'attribute is the object_id_name for the ' + 'foreign-key field "%s"' % (attr, fk_field)) + + elif isinstance(dest, Source): + constructor = dict + attr = attr or dest._alias + if not attr and isinstance(dest, Table): + attr = attr or dest.__name__ + + return (on, attr, constructor) + + def 
_generate_on_clause(self, src, dest, to_field=None, on=None): + meta = src._meta + is_backref = fk_fields = False + + # Get all the foreign keys between source and dest, and determine if + # the join is via a back-reference. + if dest in meta.model_refs: + fk_fields = meta.model_refs[dest] + elif dest in meta.model_backrefs: + fk_fields = meta.model_backrefs[dest] + is_backref = True + + if not fk_fields: + if on is not None: + return None, False + raise ValueError('Unable to find foreign key between %s and %s. ' + 'Please specify an explicit join condition.' % + (src, dest)) + elif to_field is not None: + # If the foreign-key field was specified explicitly, remove all + # other foreign-key fields from the list. + target = (to_field.field if isinstance(to_field, FieldAlias) + else to_field) + fk_fields = [f for f in fk_fields if ( + (f is target) or + (is_backref and f.rel_field is to_field))] + + if len(fk_fields) == 1: + return fk_fields[0], is_backref + + if on is None: + # If multiple foreign-keys exist, try using the FK whose name + # matches that of the related model. If not, raise an error as this + # is ambiguous. + for fk in fk_fields: + if fk.name == dest._meta.name: + return fk, is_backref + + raise ValueError('More than one foreign key between %s and %s.' + ' Please specify which you are joining on.' % + (src, dest)) + + # If there are multiple foreign-keys to choose from and the join + # predicate is an expression, we'll try to figure out which + # foreign-key field we're joining on so that we can assign to the + # correct attribute when resolving the model graph. + to_field = None + if isinstance(on, Expression): + lhs, rhs = on.lhs, on.rhs + # Coerce to set() so that we force Python to compare using the + # object's hash rather than equality test, which returns a + # false-positive due to overriding __eq__. 
+ fk_set = set(fk_fields) + + if isinstance(lhs, Field): + lhs_f = lhs.field if isinstance(lhs, FieldAlias) else lhs + if lhs_f in fk_set: + to_field = lhs_f + elif isinstance(rhs, Field): + rhs_f = rhs.field if isinstance(rhs, FieldAlias) else rhs + if rhs_f in fk_set: + to_field = rhs_f + + return to_field, False + + @Node.copy + def join(self, dest, join_type=JOIN.INNER, on=None, src=None, attr=None): + src = self._join_ctx if src is None else src + + if join_type == JOIN.LATERAL or join_type == JOIN.LEFT_LATERAL: + on = True + elif join_type != JOIN.CROSS: + on, attr, constructor = self._normalize_join(src, dest, on, attr) + if attr: + self._joins.setdefault(src, []) + self._joins[src].append((dest, attr, constructor, join_type)) + elif on is not None: + raise ValueError('Cannot specify on clause with cross join.') + + if not self._from_list: + raise ValueError('No sources to join on.') + + item = self._from_list.pop() + self._from_list.append(Join(item, dest, join_type, on)) + + def join_from(self, src, dest, join_type=JOIN.INNER, on=None, attr=None): + return self.join(dest, join_type, on, src, attr) + + def _get_model_cursor_wrapper(self, cursor): + if len(self._from_list) == 1 and not self._joins: + return ModelObjectCursorWrapper(cursor, self.model, + self._returning, self.model) + return ModelCursorWrapper(cursor, self.model, self._returning, + self._from_list, self._joins) + + def ensure_join(self, lm, rm, on=None, **join_kwargs): + join_ctx = self._join_ctx + for dest, _, constructor, _ in self._joins.get(lm, []): + if dest == rm: + return self + return self.switch(lm).join(rm, on=on, **join_kwargs).switch(join_ctx) + + def convert_dict_to_node(self, qdict): + accum = [] + joins = [] + fks = (ForeignKeyField, BackrefAccessor) + for key, value in sorted(qdict.items()): + curr = self.model + if '__' in key and key.rsplit('__', 1)[1] in DJANGO_MAP: + key, op = key.rsplit('__', 1) + op = DJANGO_MAP[op] + elif value is None: + op = DJANGO_MAP['is'] + else: + 
op = DJANGO_MAP['eq'] + + if '__' not in key: + # Handle simplest case. This avoids joining over-eagerly when a + # direct FK lookup is all that is required. + model_attr = getattr(curr, key) + else: + for piece in key.split('__'): + for dest, attr, _, _ in self._joins.get(curr, ()): + if attr == piece or (isinstance(dest, ModelAlias) and + dest.alias == piece): + curr = dest + break + else: + model_attr = getattr(curr, piece) + if value is not None and isinstance(model_attr, fks): + curr = model_attr.rel_model + joins.append(model_attr) + accum.append(op(model_attr, value)) + return accum, joins + + def filter(self, *args, **kwargs): + # normalize args and kwargs into a new expression + dq_node = ColumnBase() + if args: + dq_node &= reduce(operator.and_, [a.clone() for a in args]) + if kwargs: + dq_node &= DQ(**kwargs) + + # dq_node should now be an Expression, lhs = Node(), rhs = ... + q = collections.deque([dq_node]) + dq_joins = [] + seen_joins = set() + while q: + curr = q.popleft() + if not isinstance(curr, Expression): + continue + for side, piece in (('lhs', curr.lhs), ('rhs', curr.rhs)): + if isinstance(piece, DQ): + query, joins = self.convert_dict_to_node(piece.query) + for join in joins: + if join not in seen_joins: + dq_joins.append(join) + seen_joins.add(join) + expression = reduce(operator.and_, query) + # Apply values from the DQ object. 
+ if piece._negated: + expression = Negated(expression) + #expression._alias = piece._alias + setattr(curr, side, expression) + else: + q.append(piece) + + dq_node = dq_node.rhs + + query = self.clone() + for field in dq_joins: + if isinstance(field, ForeignKeyField): + lm, rm = field.model, field.rel_model + field_obj = field + elif isinstance(field, BackrefAccessor): + lm, rm = field.model, field.rel_model + field_obj = field.field + query = query.ensure_join(lm, rm, field_obj) + return query.where(dq_node) + + def create_table(self, name, safe=True, **meta): + return self.model._schema.create_table_as(name, self, safe, **meta) + + def __sql_selection__(self, ctx, is_subquery=False): + if self._is_default and is_subquery and len(self._returning) > 1 and \ + self.model._meta.primary_key is not False: + return ctx.sql(self.model._meta.primary_key) + + return ctx.sql(CommaNodeList(self._returning)) + + +class NoopModelSelect(ModelSelect): + def __sql__(self, ctx): + return self.model._meta.database.get_noop_select(ctx) + + def _get_cursor_wrapper(self, cursor): + return CursorWrapper(cursor) + + +class _ModelWriteQueryHelper(_ModelQueryHelper): + def __init__(self, model, *args, **kwargs): + self.model = model + super(_ModelWriteQueryHelper, self).__init__(model, *args, **kwargs) + + def returning(self, *returning): + accum = [] + for item in returning: + if is_model(item): + accum.extend(item._meta.sorted_fields) + else: + accum.append(item) + return super(_ModelWriteQueryHelper, self).returning(*accum) + + def _set_table_alias(self, ctx): + table = self.model._meta.table + ctx.alias_manager[table] = table.__name__ + + +class ModelUpdate(_ModelWriteQueryHelper, Update): + pass + + +class ModelInsert(_ModelWriteQueryHelper, Insert): + default_row_type = ROW.TUPLE + + def __init__(self, *args, **kwargs): + super(ModelInsert, self).__init__(*args, **kwargs) + if self._returning is None and self.model._meta.database is not None: + if 
self.model._meta.database.returning_clause: + self._returning = self.model._meta.get_primary_keys() + + def returning(self, *returning): + # By default ModelInsert will yield a `tuple` containing the + # primary-key of the newly inserted row. But if we are explicitly + # specifying a returning clause and have not set a row type, we will + # default to returning model instances instead. + if returning and self._row_type is None: + self._row_type = ROW.MODEL + return super(ModelInsert, self).returning(*returning) + + def get_default_data(self): + return self.model._meta.defaults + + def get_default_columns(self): + fields = self.model._meta.sorted_fields + return fields[1:] if self.model._meta.auto_increment else fields + + +class ModelDelete(_ModelWriteQueryHelper, Delete): + pass + + +class ManyToManyQuery(ModelSelect): + def __init__(self, instance, accessor, rel, *args, **kwargs): + self._instance = instance + self._accessor = accessor + self._src_attr = accessor.src_fk.rel_field.name + self._dest_attr = accessor.dest_fk.rel_field.name + super(ManyToManyQuery, self).__init__(rel, (rel,), *args, **kwargs) + + def _id_list(self, model_or_id_list): + if isinstance(model_or_id_list[0], Model): + return [getattr(obj, self._dest_attr) for obj in model_or_id_list] + return model_or_id_list + + def add(self, value, clear_existing=False): + if clear_existing: + self.clear() + + accessor = self._accessor + src_id = getattr(self._instance, self._src_attr) + if isinstance(value, SelectQuery): + query = value.columns( + Value(src_id), + accessor.dest_fk.rel_field) + accessor.through_model.insert_from( + fields=[accessor.src_fk, accessor.dest_fk], + query=query).execute() + else: + value = ensure_tuple(value) + if not value: return + + inserts = [{ + accessor.src_fk.name: src_id, + accessor.dest_fk.name: rel_id} + for rel_id in self._id_list(value)] + accessor.through_model.insert_many(inserts).execute() + + def remove(self, value): + src_id = getattr(self._instance, 
self._src_attr) + if isinstance(value, SelectQuery): + column = getattr(value.model, self._dest_attr) + subquery = value.columns(column) + return (self._accessor.through_model + .delete() + .where( + (self._accessor.dest_fk << subquery) & + (self._accessor.src_fk == src_id)) + .execute()) + else: + value = ensure_tuple(value) + if not value: + return + return (self._accessor.through_model + .delete() + .where( + (self._accessor.dest_fk << self._id_list(value)) & + (self._accessor.src_fk == src_id)) + .execute()) + + def clear(self): + src_id = getattr(self._instance, self._src_attr) + return (self._accessor.through_model + .delete() + .where(self._accessor.src_fk == src_id) + .execute()) + + +def safe_python_value(conv_func): + def validate(value): + try: + return conv_func(value) + except (TypeError, ValueError): + return value + return validate + + +class BaseModelCursorWrapper(DictCursorWrapper): + def __init__(self, cursor, model, columns): + super(BaseModelCursorWrapper, self).__init__(cursor) + self.model = model + self.select = columns or [] + + def _initialize_columns(self): + combined = self.model._meta.combined + table = self.model._meta.table + description = self.cursor.description + + self.ncols = len(self.cursor.description) + self.columns = [] + self.converters = converters = [None] * self.ncols + self.fields = fields = [None] * self.ncols + + for idx, description_item in enumerate(description): + column = description_item[0] + dot_index = column.find('.') + if dot_index != -1: + column = column[dot_index + 1:] + + column = column.strip('"') + self.columns.append(column) + try: + raw_node = self.select[idx] + except IndexError: + if column in combined: + raw_node = node = combined[column] + else: + continue + else: + node = raw_node.unwrap() + + # Heuristics used to attempt to get the field associated with a + # given SELECT column, so that we can accurately convert the value + # returned by the database-cursor into a Python object. 
+ if isinstance(node, Field): + if raw_node._coerce: + converters[idx] = node.python_value + fields[idx] = node + if not raw_node.is_alias(): + self.columns[idx] = node.name + elif isinstance(node, Function) and node._coerce: + if node._python_value is not None: + converters[idx] = node._python_value + elif node.arguments and isinstance(node.arguments[0], Node): + # If the first argument is a field or references a column + # on a Model, try using that field's conversion function. + # This usually works, but we use "safe_python_value()" so + # that if a TypeError or ValueError occurs during + # conversion we can just fall-back to the raw cursor value. + first = node.arguments[0].unwrap() + if isinstance(first, Entity): + path = first._path[-1] # Try to look-up by name. + first = combined.get(path) + if isinstance(first, Field): + converters[idx] = safe_python_value(first.python_value) + elif column in combined: + if node._coerce: + converters[idx] = combined[column].python_value + if isinstance(node, Column) and node.source == table: + fields[idx] = combined[column] + + initialize = _initialize_columns + + def process_row(self, row): + raise NotImplementedError + + +class ModelDictCursorWrapper(BaseModelCursorWrapper): + def process_row(self, row): + result = {} + columns, converters = self.columns, self.converters + fields = self.fields + + for i in range(self.ncols): + attr = columns[i] + if attr in result: continue # Don't overwrite if we have dupes. 
+ if converters[i] is not None: + result[attr] = converters[i](row[i]) + else: + result[attr] = row[i] + + return result + + +class ModelTupleCursorWrapper(ModelDictCursorWrapper): + constructor = tuple + + def process_row(self, row): + columns, converters = self.columns, self.converters + return self.constructor([ + (converters[i](row[i]) if converters[i] is not None else row[i]) + for i in range(self.ncols)]) + + +class ModelNamedTupleCursorWrapper(ModelTupleCursorWrapper): + def initialize(self): + self._initialize_columns() + attributes = [] + for i in range(self.ncols): + attributes.append(self.columns[i]) + self.tuple_class = collections.namedtuple('Row', attributes) + self.constructor = lambda row: self.tuple_class(*row) + + +class ModelObjectCursorWrapper(ModelDictCursorWrapper): + def __init__(self, cursor, model, select, constructor): + self.constructor = constructor + self.is_model = is_model(constructor) + super(ModelObjectCursorWrapper, self).__init__(cursor, model, select) + + def process_row(self, row): + data = super(ModelObjectCursorWrapper, self).process_row(row) + if self.is_model: + # Clear out any dirty fields before returning to the user. 
+ obj = self.constructor(__no_default__=1, **data) + obj._dirty.clear() + return obj + else: + return self.constructor(**data) + + +class ModelCursorWrapper(BaseModelCursorWrapper): + def __init__(self, cursor, model, select, from_list, joins): + super(ModelCursorWrapper, self).__init__(cursor, model, select) + self.from_list = from_list + self.joins = joins + + def initialize(self): + self._initialize_columns() + selected_src = set([field.model for field in self.fields + if field is not None]) + select, columns = self.select, self.columns + + self.key_to_constructor = {self.model: self.model} + self.src_is_dest = {} + self.src_to_dest = [] + accum = collections.deque(self.from_list) + dests = set() + + while accum: + curr = accum.popleft() + if isinstance(curr, Join): + accum.append(curr.lhs) + accum.append(curr.rhs) + continue + + if curr not in self.joins: + continue + + is_dict = isinstance(curr, dict) + for key, attr, constructor, join_type in self.joins[curr]: + if key not in self.key_to_constructor: + self.key_to_constructor[key] = constructor + + # (src, attr, dest, is_dict, join_type). + self.src_to_dest.append((curr, attr, key, is_dict, + join_type)) + dests.add(key) + accum.append(key) + + # Ensure that we accommodate everything selected. + for src in selected_src: + if src not in self.key_to_constructor: + if is_model(src): + self.key_to_constructor[src] = src + elif isinstance(src, ModelAlias): + self.key_to_constructor[src] = src.model + + # Indicate which sources are also dests. 
+ for src, _, dest, _, _ in self.src_to_dest: + self.src_is_dest[src] = src in dests and (dest in selected_src + or src in selected_src) + + self.column_keys = [] + for idx, node in enumerate(select): + key = self.model + field = self.fields[idx] + if field is not None: + if isinstance(field, FieldAlias): + key = field.source + else: + key = field.model + else: + if isinstance(node, Node): + node = node.unwrap() + if isinstance(node, Column): + key = node.source + + self.column_keys.append(key) + + def process_row(self, row): + objects = {} + object_list = [] + for key, constructor in self.key_to_constructor.items(): + objects[key] = constructor(__no_default__=True) + object_list.append(objects[key]) + + set_keys = set() + for idx, key in enumerate(self.column_keys): + instance = objects[key] + column = self.columns[idx] + value = row[idx] + if value is not None: + set_keys.add(key) + if self.converters[idx]: + value = self.converters[idx](value) + + if isinstance(instance, dict): + instance[column] = value + else: + setattr(instance, column, value) + + # Need to do some analysis on the joins before this. + for (src, attr, dest, is_dict, join_type) in self.src_to_dest: + instance = objects[src] + try: + joined_instance = objects[dest] + except KeyError: + continue + + # If no fields were set on the destination instance then do not + # assign an "empty" instance. + if instance is None or dest is None or \ + (dest not in set_keys and not self.src_is_dest.get(dest)): + continue + + # If no fields were set on either the source or the destination, + # then we have nothing to do here. + if instance not in set_keys and dest not in set_keys \ + and join_type.endswith('OUTER JOIN'): + continue + + if is_dict: + instance[attr] = joined_instance + else: + setattr(instance, attr, joined_instance) + + # When instantiating models from a cursor, we clear the dirty fields. 
+ for instance in object_list: + if isinstance(instance, Model): + instance._dirty.clear() + + return objects[self.model] + + +class PrefetchQuery(collections.namedtuple('_PrefetchQuery', ( + 'query', 'fields', 'is_backref', 'rel_models', 'field_to_name', 'model'))): + def __new__(cls, query, fields=None, is_backref=None, rel_models=None, + field_to_name=None, model=None): + if fields: + if is_backref: + if rel_models is None: + rel_models = [field.model for field in fields] + foreign_key_attrs = [field.rel_field.name for field in fields] + else: + if rel_models is None: + rel_models = [field.rel_model for field in fields] + foreign_key_attrs = [field.name for field in fields] + field_to_name = list(zip(fields, foreign_key_attrs)) + model = query.model + return super(PrefetchQuery, cls).__new__( + cls, query, fields, is_backref, rel_models, field_to_name, model) + + def populate_instance(self, instance, id_map): + if self.is_backref: + for field in self.fields: + identifier = instance.__data__[field.name] + key = (field, identifier) + if key in id_map: + setattr(instance, field.name, id_map[key]) + else: + for field, attname in self.field_to_name: + identifier = instance.__data__[field.rel_field.name] + key = (field, identifier) + rel_instances = id_map.get(key, []) + for inst in rel_instances: + setattr(inst, attname, instance) + setattr(instance, field.backref, rel_instances) + + def store_instance(self, instance, id_map): + for field, attname in self.field_to_name: + identity = field.rel_field.python_value(instance.__data__[attname]) + key = (field, identity) + if self.is_backref: + id_map[key] = instance + else: + id_map.setdefault(key, []) + id_map[key].append(instance) + + +def prefetch_add_subquery(sq, subqueries): + fixed_queries = [PrefetchQuery(sq)] + for i, subquery in enumerate(subqueries): + if isinstance(subquery, tuple): + subquery, target_model = subquery + else: + target_model = None + if not isinstance(subquery, Query) and is_model(subquery) or \ 
+ isinstance(subquery, ModelAlias): + subquery = subquery.select() + subquery_model = subquery.model + fks = backrefs = None + for j in reversed(range(i + 1)): + fixed = fixed_queries[j] + last_query = fixed.query + last_model = last_obj = fixed.model + if isinstance(last_model, ModelAlias): + last_model = last_model.model + rels = subquery_model._meta.model_refs.get(last_model, []) + if rels: + fks = [getattr(subquery_model, fk.name) for fk in rels] + pks = [getattr(last_obj, fk.rel_field.name) for fk in rels] + else: + backrefs = subquery_model._meta.model_backrefs.get(last_model) + if (fks or backrefs) and ((target_model is last_obj) or + (target_model is None)): + break + + if not fks and not backrefs: + tgt_err = ' using %s' % target_model if target_model else '' + raise AttributeError('Error: unable to find foreign key for ' + 'query: %s%s' % (subquery, tgt_err)) + + dest = (target_model,) if target_model else None + + if fks: + expr = reduce(operator.or_, [ + (fk << last_query.select(pk)) + for (fk, pk) in zip(fks, pks)]) + subquery = subquery.where(expr) + fixed_queries.append(PrefetchQuery(subquery, fks, False, dest)) + elif backrefs: + expressions = [] + for backref in backrefs: + rel_field = getattr(subquery_model, backref.rel_field.name) + fk_field = getattr(last_obj, backref.name) + expressions.append(rel_field << last_query.select(fk_field)) + subquery = subquery.where(reduce(operator.or_, expressions)) + fixed_queries.append(PrefetchQuery(subquery, backrefs, True, dest)) + + return fixed_queries + + +def prefetch(sq, *subqueries): + if not subqueries: + return sq + + fixed_queries = prefetch_add_subquery(sq, subqueries) + deps = {} + rel_map = {} + for pq in reversed(fixed_queries): + query_model = pq.model + if pq.fields: + for rel_model in pq.rel_models: + rel_map.setdefault(rel_model, []) + rel_map[rel_model].append(pq) + + deps.setdefault(query_model, {}) + id_map = deps[query_model] + has_relations = bool(rel_map.get(query_model)) + + for 
instance in pq.query: + if pq.fields: + pq.store_instance(instance, id_map) + if has_relations: + for rel in rel_map[query_model]: + rel.populate_instance(instance, deps[rel.model]) + + return list(pq.query) diff --git a/python2.7libs/playhouse/__init__.py b/python2.7libs/playhouse/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python2.7libs/playhouse/apsw_ext.py b/python2.7libs/playhouse/apsw_ext.py new file mode 100644 index 0000000..654ee77 --- /dev/null +++ b/python2.7libs/playhouse/apsw_ext.py @@ -0,0 +1,146 @@ +""" +Peewee integration with APSW, "another python sqlite wrapper". + +Project page: https://rogerbinns.github.io/apsw/ + +APSW is a really neat library that provides a thin wrapper on top of SQLite's +C interface. + +Here are just a few reasons to use APSW, taken from the documentation: + +* APSW gives all functionality of SQLite, including virtual tables, virtual + file system, blob i/o, backups and file control. +* Connections can be shared across threads without any additional locking. +* Transactions are managed explicitly by your code. +* APSW can handle nested transactions. +* Unicode is handled correctly. +* APSW is faster. 
+""" +import apsw +from peewee import * +from peewee import __exception_wrapper__ +from peewee import BooleanField as _BooleanField +from peewee import DateField as _DateField +from peewee import DateTimeField as _DateTimeField +from peewee import DecimalField as _DecimalField +from peewee import TimeField as _TimeField +from peewee import logger + +from playhouse.sqlite_ext import SqliteExtDatabase + + +class APSWDatabase(SqliteExtDatabase): + server_version = tuple(int(i) for i in apsw.sqlitelibversion().split('.')) + + def __init__(self, database, **kwargs): + self._modules = {} + super(APSWDatabase, self).__init__(database, **kwargs) + + def register_module(self, mod_name, mod_inst): + self._modules[mod_name] = mod_inst + if not self.is_closed(): + self.connection().createmodule(mod_name, mod_inst) + + def unregister_module(self, mod_name): + del(self._modules[mod_name]) + + def _connect(self): + conn = apsw.Connection(self.database, **self.connect_params) + if self._timeout is not None: + conn.setbusytimeout(self._timeout * 1000) + try: + self._add_conn_hooks(conn) + except: + conn.close() + raise + return conn + + def _add_conn_hooks(self, conn): + super(APSWDatabase, self)._add_conn_hooks(conn) + self._load_modules(conn) # APSW-only. 
+ + def _load_modules(self, conn): + for mod_name, mod_inst in self._modules.items(): + conn.createmodule(mod_name, mod_inst) + return conn + + def _load_aggregates(self, conn): + for name, (klass, num_params) in self._aggregates.items(): + def make_aggregate(): + return (klass(), klass.step, klass.finalize) + conn.createaggregatefunction(name, make_aggregate) + + def _load_collations(self, conn): + for name, fn in self._collations.items(): + conn.createcollation(name, fn) + + def _load_functions(self, conn): + for name, (fn, num_params) in self._functions.items(): + conn.createscalarfunction(name, fn, num_params) + + def _load_extensions(self, conn): + conn.enableloadextension(True) + for extension in self._extensions: + conn.loadextension(extension) + + def load_extension(self, extension): + self._extensions.add(extension) + if not self.is_closed(): + conn = self.connection() + conn.enableloadextension(True) + conn.loadextension(extension) + + def last_insert_id(self, cursor, query_type=None): + return cursor.getconnection().last_insert_rowid() + + def rows_affected(self, cursor): + return cursor.getconnection().changes() + + def begin(self, lock_type='deferred'): + self.cursor().execute('begin %s;' % lock_type) + + def commit(self): + with __exception_wrapper__: + curs = self.cursor() + if curs.getconnection().getautocommit(): + return False + curs.execute('commit;') + return True + + def rollback(self): + with __exception_wrapper__: + curs = self.cursor() + if curs.getconnection().getautocommit(): + return False + curs.execute('rollback;') + return True + + def execute_sql(self, sql, params=None, commit=True): + logger.debug((sql, params)) + with __exception_wrapper__: + cursor = self.cursor() + cursor.execute(sql, params or ()) + return cursor + + +def nh(s, v): + if v is not None: + return str(v) + +class BooleanField(_BooleanField): + def db_value(self, v): + v = super(BooleanField, self).db_value(v) + if v is not None: + return v and 1 or 0 + +class 
DateField(_DateField): + db_value = nh + +class TimeField(_TimeField): + db_value = nh + +class DateTimeField(_DateTimeField): + db_value = nh + +class DecimalField(_DecimalField): + db_value = nh diff --git a/python2.7libs/playhouse/cockroachdb.py b/python2.7libs/playhouse/cockroachdb.py new file mode 100644 index 0000000..afb1107 --- /dev/null +++ b/python2.7libs/playhouse/cockroachdb.py @@ -0,0 +1,204 @@ +import functools +import re + +from peewee import * +from peewee import _atomic +from peewee import _manual +from peewee import _transaction +from peewee import ColumnMetadata # (name, data_type, null, primary_key, table, default) +from peewee import ForeignKeyMetadata # (column, dest_table, dest_column, table). +from peewee import IndexMetadata +from playhouse.pool import _PooledPostgresqlDatabase +try: + from playhouse.postgres_ext import ArrayField + from playhouse.postgres_ext import BinaryJSONField + from playhouse.postgres_ext import IntervalField + JSONField = BinaryJSONField +except ImportError: # psycopg2 not installed, ignore. + ArrayField = BinaryJSONField = IntervalField = JSONField = None + + +TXN_ERR_MSG = ('CockroachDB does not support nested transactions. You may ' + 'alternatively use the @transaction context-manager/decorator, ' + 'which only wraps the outer-most block in transactional logic. ' + 'To run a transaction with automatic retries, use the ' + 'run_transaction() helper.') + +class ExceededMaxAttempts(OperationalError): pass + + +class UUIDKeyField(UUIDField): + auto_increment = True + + def __init__(self, *args, **kwargs): + if kwargs.get('constraints'): + raise ValueError('%s cannot specify constraints.' 
% type(self)) + kwargs['constraints'] = [SQL('DEFAULT gen_random_uuid()')] + kwargs.setdefault('primary_key', True) + super(UUIDKeyField, self).__init__(*args, **kwargs) + + +class RowIDField(AutoField): + field_type = 'INT' + + def __init__(self, *args, **kwargs): + if kwargs.get('constraints'): + raise ValueError('%s cannot specify constraints.' % type(self)) + kwargs['constraints'] = [SQL('DEFAULT unique_rowid()')] + super(RowIDField, self).__init__(*args, **kwargs) + + +class CockroachDatabase(PostgresqlDatabase): + field_types = PostgresqlDatabase.field_types.copy() + field_types.update({ + 'BLOB': 'BYTES', + }) + + for_update = False + nulls_ordering = False + + def __init__(self, *args, **kwargs): + kwargs.setdefault('user', 'root') + kwargs.setdefault('port', 26257) + super(CockroachDatabase, self).__init__(*args, **kwargs) + + def _set_server_version(self, conn): + curs = conn.cursor() + curs.execute('select version()') + raw, = curs.fetchone() + match_obj = re.match('^CockroachDB.+?v(\d+)\.(\d+)\.(\d+)', raw) + if match_obj is not None: + clean = '%d%02d%02d' % tuple(int(i) for i in match_obj.groups()) + self.server_version = int(clean) # 19.1.5 -> 190105. + else: + # Fallback to use whatever cockroachdb tells us via protocol. + super(CockroachDatabase, self)._set_server_version(conn) + + def _get_pk_constraint(self, table, schema=None): + query = ('SELECT constraint_name ' + 'FROM information_schema.table_constraints ' + 'WHERE table_name = %s AND table_schema = %s ' + 'AND constraint_type = %s') + cursor = self.execute_sql(query, (table, schema or 'public', + 'PRIMARY KEY')) + row = cursor.fetchone() + return row and row[0] or None + + def get_indexes(self, table, schema=None): + # The primary-key index is returned by default, so we will just strip + # it out here. 
+ indexes = super(CockroachDatabase, self).get_indexes(table, schema) + pkc = self._get_pk_constraint(table, schema) + return [idx for idx in indexes if (not pkc) or (idx.name != pkc)] + + def conflict_statement(self, on_conflict, query): + if not on_conflict._action: return + + action = on_conflict._action.lower() + if action in ('replace', 'upsert'): + return SQL('UPSERT') + elif action not in ('ignore', 'nothing', 'update'): + raise ValueError('Un-supported action for conflict resolution. ' + 'CockroachDB supports REPLACE (UPSERT), IGNORE ' + 'and UPDATE.') + + def conflict_update(self, oc, query): + action = oc._action.lower() if oc._action else '' + if action in ('ignore', 'nothing'): + return SQL('ON CONFLICT DO NOTHING') + elif action in ('replace', 'upsert'): + # No special stuff is necessary, this is just indicated by starting + # the statement with UPSERT instead of INSERT. + return + elif oc._conflict_constraint: + raise ValueError('CockroachDB does not support the usage of a ' + 'constraint name. Use the column(s) instead.') + + return super(CockroachDatabase, self).conflict_update(oc, query) + + def extract_date(self, date_part, date_field): + return fn.extract(date_part, date_field) + + def from_timestamp(self, date_field): + # CRDB does not allow casting a decimal/float to timestamp, so we first + # cast to int, then to timestamptz. 
+ return date_field.cast('int').cast('timestamptz') + + def begin(self, system_time=None, priority=None): + super(CockroachDatabase, self).begin() + if system_time is not None: + self.execute_sql('SET TRANSACTION AS OF SYSTEM TIME %s', + (system_time,), commit=False) + if priority is not None: + priority = priority.lower() + if priority not in ('low', 'normal', 'high'): + raise ValueError('priority must be low, normal or high') + self.execute_sql('SET TRANSACTION PRIORITY %s' % priority, + commit=False) + + def atomic(self, system_time=None, priority=None): + return _crdb_atomic(self, system_time, priority) + + def transaction(self, system_time=None, priority=None): + return _transaction(self, system_time, priority) + + def savepoint(self): + raise NotImplementedError(TXN_ERR_MSG) + + def retry_transaction(self, max_attempts=None, system_time=None, + priority=None): + def deco(cb): + @functools.wraps(cb) + def new_fn(): + return run_transaction(self, cb, max_attempts, system_time, + priority) + return new_fn + return deco + + def run_transaction(self, cb, max_attempts=None, system_time=None, + priority=None): + return run_transaction(self, cb, max_attempts, system_time, priority) + + +class _crdb_atomic(_atomic): + def __enter__(self): + if self.db.transaction_depth() > 0: + if not isinstance(self.db.top_transaction(), _manual): + raise NotImplementedError(TXN_ERR_MSG) + return super(_crdb_atomic, self).__enter__() + + +def run_transaction(db, callback, max_attempts=None, system_time=None, + priority=None): + """ + Run transactional SQL in a transaction with automatic retries. + + User-provided `callback`: + * Must accept one parameter, the `db` instance representing the connection + the transaction is running under. + * Must not attempt to commit, rollback or otherwise manage transactions. + * May be called more than once. + * Should ideally only contain SQL operations. 
+ + Additionally, the database must not have any open transaction at the time + this function is called, as CRDB does not support nested transactions. + """ + max_attempts = max_attempts or -1 + with db.atomic(system_time=system_time, priority=priority) as txn: + db.execute_sql('SAVEPOINT cockroach_restart') + while max_attempts != 0: + try: + result = callback(db) + db.execute_sql('RELEASE SAVEPOINT cockroach_restart') + return result + except OperationalError as exc: + if exc.orig.pgcode == '40001': + max_attempts -= 1 + db.execute_sql('ROLLBACK TO SAVEPOINT cockroach_restart') + continue + raise + raise ExceededMaxAttempts(None, 'unable to commit transaction') + + +class PooledCockroachDatabase(_PooledPostgresqlDatabase, CockroachDatabase): + pass diff --git a/python2.7libs/playhouse/dataset.py b/python2.7libs/playhouse/dataset.py new file mode 100644 index 0000000..f5bbf8b --- /dev/null +++ b/python2.7libs/playhouse/dataset.py @@ -0,0 +1,452 @@ +import csv +import datetime +from decimal import Decimal +import json +import operator +try: + from urlparse import urlparse +except ImportError: + from urllib.parse import urlparse +import sys + +from peewee import * +from playhouse.db_url import connect +from playhouse.migrate import migrate +from playhouse.migrate import SchemaMigrator +from playhouse.reflection import Introspector + +if sys.version_info[0] == 3: + basestring = str + from functools import reduce + def open_file(f, mode): + return open(f, mode, encoding='utf8') +else: + open_file = open + + +class DataSet(object): + def __init__(self, url, bare_fields=False): + if isinstance(url, Database): + self._url = None + self._database = url + self._database_path = self._database.database + else: + self._url = url + parse_result = urlparse(url) + self._database_path = parse_result.path[1:] + + # Connect to the database. + self._database = connect(url) + + self._database.connect() + + # Introspect the database and generate models. 
+        self._introspector = Introspector.from_database(self._database)
+        self._models = self._introspector.generate_models(
+            skip_invalid=True,
+            literal_column_names=True,
+            bare_fields=bare_fields)
+        self._migrator = SchemaMigrator.from_database(self._database)
+
+        class BaseModel(Model):
+            class Meta:
+                database = self._database
+        self._base_model = BaseModel
+        self._export_formats = self.get_export_formats()
+        self._import_formats = self.get_import_formats()
+
+    def __repr__(self):
+        return '<DataSet: %s>' % self._database_path
+
+    def get_export_formats(self):
+        return {
+            'csv': CSVExporter,
+            'json': JSONExporter,
+            'tsv': TSVExporter}
+
+    def get_import_formats(self):
+        return {
+            'csv': CSVImporter,
+            'json': JSONImporter,
+            'tsv': TSVImporter}
+
+    def __getitem__(self, table):
+        if table not in self._models and table in self.tables:
+            self.update_cache(table)
+        return Table(self, table, self._models.get(table))
+
+    @property
+    def tables(self):
+        return self._database.get_tables()
+
+    def __contains__(self, table):
+        return table in self.tables
+
+    def connect(self):
+        self._database.connect()
+
+    def close(self):
+        self._database.close()
+
+    def update_cache(self, table=None):
+        if table:
+            dependencies = [table]
+            if table in self._models:
+                model_class = self._models[table]
+                dependencies.extend([
+                    related._meta.table_name for _, related, _ in
+                    model_class._meta.model_graph()])
+            else:
+                dependencies.extend(self.get_table_dependencies(table))
+        else:
+            dependencies = None  # Update all tables.
+ self._models = {} + updated = self._introspector.generate_models( + skip_invalid=True, + table_names=dependencies, + literal_column_names=True) + self._models.update(updated) + + def get_table_dependencies(self, table): + stack = [table] + accum = [] + seen = set() + while stack: + table = stack.pop() + for fk_meta in self._database.get_foreign_keys(table): + dest = fk_meta.dest_table + if dest not in seen: + stack.append(dest) + accum.append(dest) + return accum + + def __enter__(self): + self.connect() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if not self._database.is_closed(): + self.close() + + def query(self, sql, params=None, commit=True): + return self._database.execute_sql(sql, params, commit) + + def transaction(self): + if self._database.transaction_depth() == 0: + return self._database.transaction() + else: + return self._database.savepoint() + + def _check_arguments(self, filename, file_obj, format, format_dict): + if filename and file_obj: + raise ValueError('file is over-specified. Please use either ' + 'filename or file_obj, but not both.') + if not filename and not file_obj: + raise ValueError('A filename or file-like object must be ' + 'specified.') + if format not in format_dict: + valid_formats = ', '.join(sorted(format_dict.keys())) + raise ValueError('Unsupported format "%s". Use one of %s.' 
% (
+            format, valid_formats))
+
+    def freeze(self, query, format='csv', filename=None, file_obj=None,
+               **kwargs):
+        self._check_arguments(filename, file_obj, format, self._export_formats)
+        if filename:
+            file_obj = open_file(filename, 'w')
+
+        exporter = self._export_formats[format](query)
+        exporter.export(file_obj, **kwargs)
+
+        if filename:
+            file_obj.close()
+
+    def thaw(self, table, format='csv', filename=None, file_obj=None,
+             strict=False, **kwargs):
+        self._check_arguments(filename, file_obj, format, self._export_formats)
+        if filename:
+            file_obj = open_file(filename, 'r')
+
+        importer = self._import_formats[format](self[table], strict)
+        count = importer.load(file_obj, **kwargs)
+
+        if filename:
+            file_obj.close()
+
+        return count
+
+
+class Table(object):
+    def __init__(self, dataset, name, model_class):
+        self.dataset = dataset
+        self.name = name
+        if model_class is None:
+            model_class = self._create_model()
+            model_class.create_table()
+        self.dataset._models[name] = model_class
+
+    @property
+    def model_class(self):
+        return self.dataset._models[self.name]
+
+    def __repr__(self):
+        return '<Table: %s>' % self.name
+
+    def __len__(self):
+        return self.find().count()
+
+    def __iter__(self):
+        return iter(self.find().iterator())
+
+    def _create_model(self):
+        class Meta:
+            table_name = self.name
+        return type(
+            str(self.name),
+            (self.dataset._base_model,),
+            {'Meta': Meta})
+
+    def create_index(self, columns, unique=False):
+        self.dataset._database.create_index(
+            self.model_class,
+            columns,
+            unique=unique)
+
+    def _guess_field_type(self, value):
+        if isinstance(value, basestring):
+            return TextField
+        if isinstance(value, (datetime.date, datetime.datetime)):
+            return DateTimeField
+        elif value is True or value is False:
+            return BooleanField
+        elif isinstance(value, int):
+            return IntegerField
+        elif isinstance(value, float):
+            return FloatField
+        elif isinstance(value, Decimal):
+            return DecimalField
+        return TextField
+
+    @property
+    def columns(self):
+
return [f.name for f in self.model_class._meta.sorted_fields] + + def _migrate_new_columns(self, data): + new_keys = set(data) - set(self.model_class._meta.fields) + if new_keys: + operations = [] + for key in new_keys: + field_class = self._guess_field_type(data[key]) + field = field_class(null=True) + operations.append( + self.dataset._migrator.add_column(self.name, key, field)) + field.bind(self.model_class, key) + + migrate(*operations) + + self.dataset.update_cache(self.name) + + def __getitem__(self, item): + try: + return self.model_class[item] + except self.model_class.DoesNotExist: + pass + + def __setitem__(self, item, value): + if not isinstance(value, dict): + raise ValueError('Table.__setitem__() value must be a dict') + + pk = self.model_class._meta.primary_key + value[pk.name] = item + + try: + with self.dataset.transaction() as txn: + self.insert(**value) + except IntegrityError: + self.dataset.update_cache(self.name) + self.update(columns=[pk.name], **value) + + def __delitem__(self, item): + del self.model_class[item] + + def insert(self, **data): + self._migrate_new_columns(data) + return self.model_class.insert(**data).execute() + + def _apply_where(self, query, filters, conjunction=None): + conjunction = conjunction or operator.and_ + if filters: + expressions = [ + (self.model_class._meta.fields[column] == value) + for column, value in filters.items()] + query = query.where(reduce(conjunction, expressions)) + return query + + def update(self, columns=None, conjunction=None, **data): + self._migrate_new_columns(data) + filters = {} + if columns: + for column in columns: + filters[column] = data.pop(column) + + return self._apply_where( + self.model_class.update(**data), + filters, + conjunction).execute() + + def _query(self, **query): + return self._apply_where(self.model_class.select(), query) + + def find(self, **query): + return self._query(**query).dicts() + + def find_one(self, **query): + try: + return self.find(**query).get() + except 
self.model_class.DoesNotExist: + return None + + def all(self): + return self.find() + + def delete(self, **query): + return self._apply_where(self.model_class.delete(), query).execute() + + def freeze(self, *args, **kwargs): + return self.dataset.freeze(self.all(), *args, **kwargs) + + def thaw(self, *args, **kwargs): + return self.dataset.thaw(self.name, *args, **kwargs) + + +class Exporter(object): + def __init__(self, query): + self.query = query + + def export(self, file_obj): + raise NotImplementedError + + +class JSONExporter(Exporter): + def __init__(self, query, iso8601_datetimes=False): + super(JSONExporter, self).__init__(query) + self.iso8601_datetimes = iso8601_datetimes + + def _make_default(self): + datetime_types = (datetime.datetime, datetime.date, datetime.time) + + if self.iso8601_datetimes: + def default(o): + if isinstance(o, datetime_types): + return o.isoformat() + elif isinstance(o, Decimal): + return str(o) + raise TypeError('Unable to serialize %r as JSON' % o) + else: + def default(o): + if isinstance(o, datetime_types + (Decimal,)): + return str(o) + raise TypeError('Unable to serialize %r as JSON' % o) + return default + + def export(self, file_obj, **kwargs): + json.dump( + list(self.query), + file_obj, + default=self._make_default(), + **kwargs) + + +class CSVExporter(Exporter): + def export(self, file_obj, header=True, **kwargs): + writer = csv.writer(file_obj, **kwargs) + tuples = self.query.tuples().execute() + tuples.initialize() + if header and getattr(tuples, 'columns', None): + writer.writerow([column for column in tuples.columns]) + for row in tuples: + writer.writerow(row) + + +class TSVExporter(CSVExporter): + def export(self, file_obj, header=True, **kwargs): + kwargs.setdefault('delimiter', '\t') + return super(TSVExporter, self).export(file_obj, header, **kwargs) + + +class Importer(object): + def __init__(self, table, strict=False): + self.table = table + self.strict = strict + + model = self.table.model_class + 
self.columns = model._meta.columns + self.columns.update(model._meta.fields) + + def load(self, file_obj): + raise NotImplementedError + + +class JSONImporter(Importer): + def load(self, file_obj, **kwargs): + data = json.load(file_obj, **kwargs) + count = 0 + + for row in data: + if self.strict: + obj = {} + for key in row: + field = self.columns.get(key) + if field is not None: + obj[field.name] = field.python_value(row[key]) + else: + obj = row + + if obj: + self.table.insert(**obj) + count += 1 + + return count + + +class CSVImporter(Importer): + def load(self, file_obj, header=True, **kwargs): + count = 0 + reader = csv.reader(file_obj, **kwargs) + if header: + try: + header_keys = next(reader) + except StopIteration: + return count + + if self.strict: + header_fields = [] + for idx, key in enumerate(header_keys): + if key in self.columns: + header_fields.append((idx, self.columns[key])) + else: + header_fields = list(enumerate(header_keys)) + else: + header_fields = list(enumerate(self.model._meta.sorted_fields)) + + if not header_fields: + return count + + for row in reader: + obj = {} + for idx, field in header_fields: + if self.strict: + obj[field.name] = field.python_value(row[idx]) + else: + obj[field] = row[idx] + + self.table.insert(**obj) + count += 1 + + return count + + +class TSVImporter(CSVImporter): + def load(self, file_obj, header=True, **kwargs): + kwargs.setdefault('delimiter', '\t') + return super(TSVImporter, self).load(file_obj, header, **kwargs) diff --git a/python2.7libs/playhouse/db_url.py b/python2.7libs/playhouse/db_url.py new file mode 100644 index 0000000..7176c80 --- /dev/null +++ b/python2.7libs/playhouse/db_url.py @@ -0,0 +1,130 @@ +try: + from urlparse import parse_qsl, unquote, urlparse +except ImportError: + from urllib.parse import parse_qsl, unquote, urlparse + +from peewee import * +from playhouse.cockroachdb import CockroachDatabase +from playhouse.cockroachdb import PooledCockroachDatabase +from playhouse.pool import 
PooledMySQLDatabase +from playhouse.pool import PooledPostgresqlDatabase +from playhouse.pool import PooledSqliteDatabase +from playhouse.pool import PooledSqliteExtDatabase +from playhouse.sqlite_ext import SqliteExtDatabase + + +schemes = { + 'cockroachdb': CockroachDatabase, + 'cockroachdb+pool': PooledCockroachDatabase, + 'crdb': CockroachDatabase, + 'crdb+pool': PooledCockroachDatabase, + 'mysql': MySQLDatabase, + 'mysql+pool': PooledMySQLDatabase, + 'postgres': PostgresqlDatabase, + 'postgresql': PostgresqlDatabase, + 'postgres+pool': PooledPostgresqlDatabase, + 'postgresql+pool': PooledPostgresqlDatabase, + 'sqlite': SqliteDatabase, + 'sqliteext': SqliteExtDatabase, + 'sqlite+pool': PooledSqliteDatabase, + 'sqliteext+pool': PooledSqliteExtDatabase, +} + +def register_database(db_class, *names): + global schemes + for name in names: + schemes[name] = db_class + +def parseresult_to_dict(parsed, unquote_password=False): + + # urlparse in python 2.6 is broken so query will be empty and instead + # appended to path complete with '?' + path_parts = parsed.path[1:].split('?') + try: + query = path_parts[1] + except IndexError: + query = parsed.query + + connect_kwargs = {'database': path_parts[0]} + if parsed.username: + connect_kwargs['user'] = parsed.username + if parsed.password: + connect_kwargs['password'] = parsed.password + if unquote_password: + connect_kwargs['password'] = unquote(connect_kwargs['password']) + if parsed.hostname: + connect_kwargs['host'] = parsed.hostname + if parsed.port: + connect_kwargs['port'] = parsed.port + + # Adjust parameters for MySQL. 
+ if parsed.scheme == 'mysql' and 'password' in connect_kwargs: + connect_kwargs['passwd'] = connect_kwargs.pop('password') + elif 'sqlite' in parsed.scheme and not connect_kwargs['database']: + connect_kwargs['database'] = ':memory:' + + # Get additional connection args from the query string + qs_args = parse_qsl(query, keep_blank_values=True) + for key, value in qs_args: + if value.lower() == 'false': + value = False + elif value.lower() == 'true': + value = True + elif value.isdigit(): + value = int(value) + elif '.' in value and all(p.isdigit() for p in value.split('.', 1)): + try: + value = float(value) + except ValueError: + pass + elif value.lower() in ('null', 'none'): + value = None + + connect_kwargs[key] = value + + return connect_kwargs + +def parse(url, unquote_password=False): + parsed = urlparse(url) + return parseresult_to_dict(parsed, unquote_password) + +def connect(url, unquote_password=False, **connect_params): + parsed = urlparse(url) + connect_kwargs = parseresult_to_dict(parsed, unquote_password) + connect_kwargs.update(connect_params) + database_class = schemes.get(parsed.scheme) + + if database_class is None: + if database_class in schemes: + raise RuntimeError('Attempted to use "%s" but a required library ' + 'could not be imported.' % parsed.scheme) + else: + raise RuntimeError('Unrecognized or unsupported scheme: "%s".' % + parsed.scheme) + + return database_class(**connect_kwargs) + +# Conditionally register additional databases. 
+try: + from playhouse.pool import PooledPostgresqlExtDatabase +except ImportError: + pass +else: + register_database( + PooledPostgresqlExtDatabase, + 'postgresext+pool', + 'postgresqlext+pool') + +try: + from playhouse.apsw_ext import APSWDatabase +except ImportError: + pass +else: + register_database(APSWDatabase, 'apsw') + +try: + from playhouse.postgres_ext import PostgresqlExtDatabase +except ImportError: + pass +else: + register_database(PostgresqlExtDatabase, 'postgresext', 'postgresqlext') diff --git a/python2.7libs/playhouse/fields.py b/python2.7libs/playhouse/fields.py new file mode 100644 index 0000000..fce1a3d --- /dev/null +++ b/python2.7libs/playhouse/fields.py @@ -0,0 +1,64 @@ +try: + import bz2 +except ImportError: + bz2 = None +try: + import zlib +except ImportError: + zlib = None +try: + import cPickle as pickle +except ImportError: + import pickle +import sys + +from peewee import BlobField +from peewee import buffer_type + + +PY2 = sys.version_info[0] == 2 + + +class CompressedField(BlobField): + ZLIB = 'zlib' + BZ2 = 'bz2' + algorithm_to_import = { + ZLIB: zlib, + BZ2: bz2, + } + + def __init__(self, compression_level=6, algorithm=ZLIB, *args, + **kwargs): + self.compression_level = compression_level + if algorithm not in self.algorithm_to_import: + raise ValueError('Unrecognized algorithm %s' % algorithm) + compress_module = self.algorithm_to_import[algorithm] + if compress_module is None: + raise ValueError('Missing library required for %s.' 
% algorithm) + + self.algorithm = algorithm + self.compress = compress_module.compress + self.decompress = compress_module.decompress + super(CompressedField, self).__init__(*args, **kwargs) + + def python_value(self, value): + if value is not None: + return self.decompress(value) + + def db_value(self, value): + if value is not None: + return self._constructor( + self.compress(value, self.compression_level)) + + +class PickleField(BlobField): + def python_value(self, value): + if value is not None: + if isinstance(value, buffer_type): + value = bytes(value) + return pickle.loads(value) + + def db_value(self, value): + if value is not None: + pickled = pickle.dumps(value, pickle.HIGHEST_PROTOCOL) + return self._constructor(pickled) diff --git a/python2.7libs/playhouse/flask_utils.py b/python2.7libs/playhouse/flask_utils.py new file mode 100644 index 0000000..76a2a62 --- /dev/null +++ b/python2.7libs/playhouse/flask_utils.py @@ -0,0 +1,185 @@ +import math +import sys + +from flask import abort +from flask import render_template +from flask import request +from peewee import Database +from peewee import DoesNotExist +from peewee import Model +from peewee import Proxy +from peewee import SelectQuery +from playhouse.db_url import connect as db_url_connect + + +class PaginatedQuery(object): + def __init__(self, query_or_model, paginate_by, page_var='page', page=None, + check_bounds=False): + self.paginate_by = paginate_by + self.page_var = page_var + self.page = page or None + self.check_bounds = check_bounds + + if isinstance(query_or_model, SelectQuery): + self.query = query_or_model + self.model = self.query.model + else: + self.model = query_or_model + self.query = self.model.select() + + def get_page(self): + if self.page is not None: + return self.page + + curr_page = request.args.get(self.page_var) + if curr_page and curr_page.isdigit(): + return max(1, int(curr_page)) + return 1 + + def get_page_count(self): + if not hasattr(self, '_page_count'): + 
self._page_count = int(math.ceil( + float(self.query.count()) / self.paginate_by)) + return self._page_count + + def get_object_list(self): + if self.check_bounds and self.get_page() > self.get_page_count(): + abort(404) + return self.query.paginate(self.get_page(), self.paginate_by) + + +def get_object_or_404(query_or_model, *query): + if not isinstance(query_or_model, SelectQuery): + query_or_model = query_or_model.select() + try: + return query_or_model.where(*query).get() + except DoesNotExist: + abort(404) + +def object_list(template_name, query, context_variable='object_list', + paginate_by=20, page_var='page', page=None, check_bounds=True, + **kwargs): + paginated_query = PaginatedQuery( + query, + paginate_by=paginate_by, + page_var=page_var, + page=page, + check_bounds=check_bounds) + kwargs[context_variable] = paginated_query.get_object_list() + return render_template( + template_name, + pagination=paginated_query, + page=paginated_query.get_page(), + **kwargs) + +def get_current_url(): + if not request.query_string: + return request.path + return '%s?%s' % (request.path, request.query_string) + +def get_next_url(default='/'): + if request.args.get('next'): + return request.args['next'] + elif request.form.get('next'): + return request.form['next'] + return default + +class FlaskDB(object): + def __init__(self, app=None, database=None, model_class=Model): + self.database = None # Reference to actual Peewee database instance. + self.base_model_class = model_class + self._app = app + self._db = database # dict, url, Database, or None (default). 
+ if app is not None: + self.init_app(app) + + def init_app(self, app): + self._app = app + + if self._db is None: + if 'DATABASE' in app.config: + initial_db = app.config['DATABASE'] + elif 'DATABASE_URL' in app.config: + initial_db = app.config['DATABASE_URL'] + else: + raise ValueError('Missing required configuration data for ' + 'database: DATABASE or DATABASE_URL.') + else: + initial_db = self._db + + self._load_database(app, initial_db) + self._register_handlers(app) + + def _load_database(self, app, config_value): + if isinstance(config_value, Database): + database = config_value + elif isinstance(config_value, dict): + database = self._load_from_config_dict(dict(config_value)) + else: + # Assume a database connection URL. + database = db_url_connect(config_value) + + if isinstance(self.database, Proxy): + self.database.initialize(database) + else: + self.database = database + + def _load_from_config_dict(self, config_dict): + try: + name = config_dict.pop('name') + engine = config_dict.pop('engine') + except KeyError: + raise RuntimeError('DATABASE configuration must specify a ' + '`name` and `engine`.') + + if '.' 
in engine: + path, class_name = engine.rsplit('.', 1) + else: + path, class_name = 'peewee', engine + + try: + __import__(path) + module = sys.modules[path] + database_class = getattr(module, class_name) + assert issubclass(database_class, Database) + except ImportError: + raise RuntimeError('Unable to import %s' % engine) + except AttributeError: + raise RuntimeError('Database engine not found %s' % engine) + except AssertionError: + raise RuntimeError('Database engine not a subclass of ' + 'peewee.Database: %s' % engine) + + return database_class(name, **config_dict) + + def _register_handlers(self, app): + app.before_request(self.connect_db) + app.teardown_request(self.close_db) + + def get_model_class(self): + if self.database is None: + raise RuntimeError('Database must be initialized.') + + class BaseModel(self.base_model_class): + class Meta: + database = self.database + + return BaseModel + + @property + def Model(self): + if self._app is None: + database = getattr(self, 'database', None) + if database is None: + self.database = Proxy() + + if not hasattr(self, '_model_class'): + self._model_class = self.get_model_class() + return self._model_class + + def connect_db(self): + self.database.connect() + + def close_db(self, exc): + if not self.database.is_closed(): + self.database.close() diff --git a/python2.7libs/playhouse/hybrid.py b/python2.7libs/playhouse/hybrid.py new file mode 100644 index 0000000..50531cc --- /dev/null +++ b/python2.7libs/playhouse/hybrid.py @@ -0,0 +1,53 @@ +from peewee import ModelDescriptor + + +# Hybrid methods/attributes, based on similar functionality in SQLAlchemy: +# http://docs.sqlalchemy.org/en/improve_toc/orm/extensions/hybrid.html +class hybrid_method(ModelDescriptor): + def __init__(self, func, expr=None): + self.func = func + self.expr = expr or func + + def __get__(self, instance, instance_type): + if instance is None: + return self.expr.__get__(instance_type, instance_type.__class__) + return 
self.func.__get__(instance, instance_type) + + def expression(self, expr): + self.expr = expr + return self + + +class hybrid_property(ModelDescriptor): + def __init__(self, fget, fset=None, fdel=None, expr=None): + self.fget = fget + self.fset = fset + self.fdel = fdel + self.expr = expr or fget + + def __get__(self, instance, instance_type): + if instance is None: + return self.expr(instance_type) + return self.fget(instance) + + def __set__(self, instance, value): + if self.fset is None: + raise AttributeError('Cannot set attribute.') + self.fset(instance, value) + + def __delete__(self, instance): + if self.fdel is None: + raise AttributeError('Cannot delete attribute.') + self.fdel(instance) + + def setter(self, fset): + self.fset = fset + return self + + def deleter(self, fdel): + self.fdel = fdel + return self + + def expression(self, expr): + self.expr = expr + return self diff --git a/python2.7libs/playhouse/kv.py b/python2.7libs/playhouse/kv.py new file mode 100644 index 0000000..742b49c --- /dev/null +++ b/python2.7libs/playhouse/kv.py @@ -0,0 +1,172 @@ +import operator + +from peewee import * +from peewee import Expression +from playhouse.fields import PickleField +try: + from playhouse.sqlite_ext import CSqliteExtDatabase as SqliteExtDatabase +except ImportError: + from playhouse.sqlite_ext import SqliteExtDatabase + + +Sentinel = type('Sentinel', (object,), {}) + + +class KeyValue(object): + """ + Persistent dictionary. + + :param Field key_field: field to use for key. Defaults to CharField. + :param Field value_field: field to use for value. Defaults to PickleField. + :param bool ordered: data should be returned in key-sorted order. + :param Database database: database where key/value data is stored. + :param str table_name: table name for data. 
+ """ + def __init__(self, key_field=None, value_field=None, ordered=False, + database=None, table_name='keyvalue'): + if key_field is None: + key_field = CharField(max_length=255, primary_key=True) + if not key_field.primary_key: + raise ValueError('key_field must have primary_key=True.') + + if value_field is None: + value_field = PickleField() + + self._key_field = key_field + self._value_field = value_field + self._ordered = ordered + self._database = database or SqliteExtDatabase(':memory:') + self._table_name = table_name + if isinstance(self._database, PostgresqlDatabase): + self.upsert = self._postgres_upsert + self.update = self._postgres_update + else: + self.upsert = self._upsert + self.update = self._update + + self.model = self.create_model() + self.key = self.model.key + self.value = self.model.value + + # Ensure table exists. + self.model.create_table() + + def create_model(self): + class KeyValue(Model): + key = self._key_field + value = self._value_field + class Meta: + database = self._database + table_name = self._table_name + return KeyValue + + def query(self, *select): + query = self.model.select(*select).tuples() + if self._ordered: + query = query.order_by(self.key) + return query + + def convert_expression(self, expr): + if not isinstance(expr, Expression): + return (self.key == expr), True + return expr, False + + def __contains__(self, key): + expr, _ = self.convert_expression(key) + return self.model.select().where(expr).exists() + + def __len__(self): + return len(self.model) + + def __getitem__(self, expr): + converted, is_single = self.convert_expression(expr) + query = self.query(self.value).where(converted) + item_getter = operator.itemgetter(0) + result = [item_getter(row) for row in query] + if len(result) == 0 and is_single: + raise KeyError(expr) + elif is_single: + return result[0] + return result + + def _upsert(self, key, value): + (self.model + .insert(key=key, value=value) + .on_conflict('replace') + .execute()) + + def 
_postgres_upsert(self, key, value): + (self.model + .insert(key=key, value=value) + .on_conflict(conflict_target=[self.key], + preserve=[self.value]) + .execute()) + + def __setitem__(self, expr, value): + if isinstance(expr, Expression): + self.model.update(value=value).where(expr).execute() + else: + self.upsert(expr, value) + + def __delitem__(self, expr): + converted, _ = self.convert_expression(expr) + self.model.delete().where(converted).execute() + + def __iter__(self): + return iter(self.query().execute()) + + def keys(self): + return map(operator.itemgetter(0), self.query(self.key)) + + def values(self): + return map(operator.itemgetter(0), self.query(self.value)) + + def items(self): + return iter(self.query().execute()) + + def _update(self, __data=None, **mapping): + if __data is not None: + mapping.update(__data) + return (self.model + .insert_many(list(mapping.items()), + fields=[self.key, self.value]) + .on_conflict('replace') + .execute()) + + def _postgres_update(self, __data=None, **mapping): + if __data is not None: + mapping.update(__data) + return (self.model + .insert_many(list(mapping.items()), + fields=[self.key, self.value]) + .on_conflict(conflict_target=[self.key], + preserve=[self.value]) + .execute()) + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + def setdefault(self, key, default=None): + try: + return self[key] + except KeyError: + self[key] = default + return default + + def pop(self, key, default=Sentinel): + with self._database.atomic(): + try: + result = self[key] + except KeyError: + if default is Sentinel: + raise + return default + del self[key] + + return result + + def clear(self): + self.model.delete().execute() diff --git a/python2.7libs/playhouse/migrate.py b/python2.7libs/playhouse/migrate.py new file mode 100644 index 0000000..28ffa52 --- /dev/null +++ b/python2.7libs/playhouse/migrate.py @@ -0,0 +1,881 @@ +""" +Lightweight schema migrations. 
+ +NOTE: Currently tested with SQLite and Postgresql. MySQL may be missing some +features. + +Example Usage +------------- + +Instantiate a migrator: + + # Postgres example: + my_db = PostgresqlDatabase(...) + migrator = PostgresqlMigrator(my_db) + + # SQLite example: + my_db = SqliteDatabase('my_database.db') + migrator = SqliteMigrator(my_db) + +Then you will use the `migrate` function to run various `Operation`s which +are generated by the migrator: + + migrate( + migrator.add_column('some_table', 'column_name', CharField(default='')) + ) + +Migrations are not run inside a transaction, so if you wish the migration to +run in a transaction you will need to wrap the call to `migrate` in a +transaction block, e.g.: + + with my_db.transaction(): + migrate(...) + +Supported Operations +-------------------- + +Add new field(s) to an existing model: + + # Create your field instances. For non-null fields you must specify a + # default value. + pubdate_field = DateTimeField(null=True) + comment_field = TextField(default='') + + # Run the migration, specifying the database table, field name and field. + migrate( + migrator.add_column('comment_tbl', 'pub_date', pubdate_field), + migrator.add_column('comment_tbl', 'comment', comment_field), + ) + +Renaming a field: + + # Specify the table, original name of the column, and its new name. + migrate( + migrator.rename_column('story', 'pub_date', 'publish_date'), + migrator.rename_column('story', 'mod_date', 'modified_date'), + ) + +Dropping a field: + + migrate( + migrator.drop_column('story', 'some_old_field'), + ) + +Making a field nullable or not nullable: + + # Note that when making a field not null that field must not have any + # NULL values present. + migrate( + # Make `pub_date` allow NULL values. + migrator.drop_not_null('story', 'pub_date'), + + # Prevent `modified_date` from containing NULL values. 
+ migrator.add_not_null('story', 'modified_date'), + ) + +Renaming a table: + + migrate( + migrator.rename_table('story', 'stories_tbl'), + ) + +Adding an index: + + # Specify the table, column names, and whether the index should be + # UNIQUE or not. + migrate( + # Create an index on the `pub_date` column. + migrator.add_index('story', ('pub_date',), False), + + # Create a multi-column index on the `pub_date` and `status` fields. + migrator.add_index('story', ('pub_date', 'status'), False), + + # Create a unique index on the category and title fields. + migrator.add_index('story', ('category_id', 'title'), True), + ) + +Dropping an index: + + # Specify the index name. + migrate(migrator.drop_index('story', 'story_pub_date_status')) + +Adding or dropping table constraints: + +.. code-block:: python + + # Add a CHECK() constraint to enforce the price cannot be negative. + migrate(migrator.add_constraint( + 'products', + 'price_check', + Check('price >= 0'))) + + # Remove the price check constraint. + migrate(migrator.drop_constraint('products', 'price_check')) + + # Add a UNIQUE constraint on the first and last names. 
+ migrate(migrator.add_unique('person', 'first_name', 'last_name')) +""" +from collections import namedtuple +import functools +import hashlib +import re + +from peewee import * +from peewee import CommaNodeList +from peewee import EnclosedNodeList +from peewee import Entity +from peewee import Expression +from peewee import Node +from peewee import NodeList +from peewee import OP +from peewee import callable_ +from peewee import sort_models +from peewee import _truncate_constraint_name +try: + from playhouse.cockroachdb import CockroachDatabase +except ImportError: + CockroachDatabase = None + + +class Operation(object): + """Encapsulate a single schema altering operation.""" + def __init__(self, migrator, method, *args, **kwargs): + self.migrator = migrator + self.method = method + self.args = args + self.kwargs = kwargs + + def execute(self, node): + self.migrator.database.execute(node) + + def _handle_result(self, result): + if isinstance(result, (Node, Context)): + self.execute(result) + elif isinstance(result, Operation): + result.run() + elif isinstance(result, (list, tuple)): + for item in result: + self._handle_result(item) + + def run(self): + kwargs = self.kwargs.copy() + kwargs['with_context'] = True + method = getattr(self.migrator, self.method) + self._handle_result(method(*self.args, **kwargs)) + + +def operation(fn): + @functools.wraps(fn) + def inner(self, *args, **kwargs): + with_context = kwargs.pop('with_context', False) + if with_context: + return fn(self, *args, **kwargs) + return Operation(self, fn.__name__, *args, **kwargs) + return inner + + +def make_index_name(table_name, columns): + index_name = '_'.join((table_name,) + tuple(columns)) + if len(index_name) > 64: + index_hash = hashlib.md5(index_name.encode('utf-8')).hexdigest() + index_name = '%s_%s' % (index_name[:56], index_hash[:7]) + return index_name + + +class SchemaMigrator(object): + explicit_create_foreign_key = False + explicit_delete_foreign_key = False + + def __init__(self, 
database): + self.database = database + + def make_context(self): + return self.database.get_sql_context() + + @classmethod + def from_database(cls, database): + if CockroachDatabase and isinstance(database, CockroachDatabase): + return CockroachDBMigrator(database) + elif isinstance(database, PostgresqlDatabase): + return PostgresqlMigrator(database) + elif isinstance(database, MySQLDatabase): + return MySQLMigrator(database) + elif isinstance(database, SqliteDatabase): + return SqliteMigrator(database) + raise ValueError('Unsupported database: %s' % database) + + @operation + def apply_default(self, table, column_name, field): + default = field.default + if callable_(default): + default = default() + + return (self.make_context() + .literal('UPDATE ') + .sql(Entity(table)) + .literal(' SET ') + .sql(Expression( + Entity(column_name), + OP.EQ, + field.db_value(default), + flat=True))) + + def _alter_table(self, ctx, table): + return ctx.literal('ALTER TABLE ').sql(Entity(table)) + + def _alter_column(self, ctx, table, column): + return (self + ._alter_table(ctx, table) + .literal(' ALTER COLUMN ') + .sql(Entity(column))) + + @operation + def alter_add_column(self, table, column_name, field): + # Make field null at first. 
+ ctx = self.make_context() + field_null, field.null = field.null, True + field.name = field.column_name = column_name + (self + ._alter_table(ctx, table) + .literal(' ADD COLUMN ') + .sql(field.ddl(ctx))) + + field.null = field_null + if isinstance(field, ForeignKeyField): + self.add_inline_fk_sql(ctx, field) + return ctx + + @operation + def add_constraint(self, table, name, constraint): + return (self + ._alter_table(self.make_context(), table) + .literal(' ADD CONSTRAINT ') + .sql(Entity(name)) + .literal(' ') + .sql(constraint)) + + @operation + def add_unique(self, table, *column_names): + constraint_name = 'uniq_%s' % '_'.join(column_names) + constraint = NodeList(( + SQL('UNIQUE'), + EnclosedNodeList([Entity(column) for column in column_names]))) + return self.add_constraint(table, constraint_name, constraint) + + @operation + def drop_constraint(self, table, name): + return (self + ._alter_table(self.make_context(), table) + .literal(' DROP CONSTRAINT ') + .sql(Entity(name))) + + def add_inline_fk_sql(self, ctx, field): + ctx = (ctx + .literal(' REFERENCES ') + .sql(Entity(field.rel_model._meta.table_name)) + .literal(' ') + .sql(EnclosedNodeList((Entity(field.rel_field.column_name),)))) + if field.on_delete is not None: + ctx = ctx.literal(' ON DELETE %s' % field.on_delete) + if field.on_update is not None: + ctx = ctx.literal(' ON UPDATE %s' % field.on_update) + return ctx + + @operation + def add_foreign_key_constraint(self, table, column_name, rel, rel_column, + on_delete=None, on_update=None): + constraint = 'fk_%s_%s_refs_%s' % (table, column_name, rel) + ctx = (self + .make_context() + .literal('ALTER TABLE ') + .sql(Entity(table)) + .literal(' ADD CONSTRAINT ') + .sql(Entity(_truncate_constraint_name(constraint))) + .literal(' FOREIGN KEY ') + .sql(EnclosedNodeList((Entity(column_name),))) + .literal(' REFERENCES ') + .sql(Entity(rel)) + .literal(' (') + .sql(Entity(rel_column)) + .literal(')')) + if on_delete is not None: + ctx = ctx.literal(' ON 
DELETE %s' % on_delete) + if on_update is not None: + ctx = ctx.literal(' ON UPDATE %s' % on_update) + return ctx + + @operation + def add_column(self, table, column_name, field): + # Adding a column is complicated by the fact that if there are rows + # present and the field is non-null, then we need to first add the + # column as a nullable field, then set the value, then add a not null + # constraint. + if not field.null and field.default is None: + raise ValueError('%s is not null but has no default' % column_name) + + is_foreign_key = isinstance(field, ForeignKeyField) + if is_foreign_key and not field.rel_field: + raise ValueError('Foreign keys must specify a `field`.') + + operations = [self.alter_add_column(table, column_name, field)] + + # In the event the field is *not* nullable, update with the default + # value and set not null. + if not field.null: + operations.extend([ + self.apply_default(table, column_name, field), + self.add_not_null(table, column_name)]) + + if is_foreign_key and self.explicit_create_foreign_key: + operations.append( + self.add_foreign_key_constraint( + table, + column_name, + field.rel_model._meta.table_name, + field.rel_field.column_name, + field.on_delete, + field.on_update)) + + if field.index or field.unique: + using = getattr(field, 'index_type', None) + operations.append(self.add_index(table, (column_name,), + field.unique, using)) + + return operations + + @operation + def drop_foreign_key_constraint(self, table, column_name): + raise NotImplementedError + + @operation + def drop_column(self, table, column_name, cascade=True): + ctx = self.make_context() + (self._alter_table(ctx, table) + .literal(' DROP COLUMN ') + .sql(Entity(column_name))) + + if cascade: + ctx.literal(' CASCADE') + + fk_columns = [ + foreign_key.column + for foreign_key in self.database.get_foreign_keys(table)] + if column_name in fk_columns and self.explicit_delete_foreign_key: + return [self.drop_foreign_key_constraint(table, column_name), ctx] + + 
return ctx + + @operation + def rename_column(self, table, old_name, new_name): + return (self + ._alter_table(self.make_context(), table) + .literal(' RENAME COLUMN ') + .sql(Entity(old_name)) + .literal(' TO ') + .sql(Entity(new_name))) + + @operation + def add_not_null(self, table, column): + return (self + ._alter_column(self.make_context(), table, column) + .literal(' SET NOT NULL')) + + @operation + def drop_not_null(self, table, column): + return (self + ._alter_column(self.make_context(), table, column) + .literal(' DROP NOT NULL')) + + @operation + def alter_column_type(self, table, column, field, cast=None): + # ALTER TABLE
ALTER COLUMN + ctx = self.make_context() + ctx = (self + ._alter_column(ctx, table, column) + .literal(' TYPE ') + .sql(field.ddl_datatype(ctx))) + if cast is not None: + if not isinstance(cast, Node): + cast = SQL(cast) + ctx = ctx.literal(' USING ').sql(cast) + return ctx + + @operation + def rename_table(self, old_name, new_name): + return (self + ._alter_table(self.make_context(), old_name) + .literal(' RENAME TO ') + .sql(Entity(new_name))) + + @operation + def add_index(self, table, columns, unique=False, using=None): + ctx = self.make_context() + index_name = make_index_name(table, columns) + table_obj = Table(table) + cols = [getattr(table_obj.c, column) for column in columns] + index = Index(index_name, table_obj, cols, unique=unique, using=using) + return ctx.sql(index) + + @operation + def drop_index(self, table, index_name): + return (self + .make_context() + .literal('DROP INDEX ') + .sql(Entity(index_name))) + + +class PostgresqlMigrator(SchemaMigrator): + def _primary_key_columns(self, tbl): + query = """ + SELECT pg_attribute.attname + FROM pg_index, pg_class, pg_attribute + WHERE + pg_class.oid = '%s'::regclass AND + indrelid = pg_class.oid AND + pg_attribute.attrelid = pg_class.oid AND + pg_attribute.attnum = any(pg_index.indkey) AND + indisprimary; + """ + cursor = self.database.execute_sql(query % tbl) + return [row[0] for row in cursor.fetchall()] + + @operation + def set_search_path(self, schema_name): + return (self + .make_context() + .literal('SET search_path TO %s' % schema_name)) + + @operation + def rename_table(self, old_name, new_name): + pk_names = self._primary_key_columns(old_name) + ParentClass = super(PostgresqlMigrator, self) + + operations = [ + ParentClass.rename_table(old_name, new_name, with_context=True)] + + if len(pk_names) == 1: + # Check for existence of primary key sequence. 
+ seq_name = '%s_%s_seq' % (old_name, pk_names[0]) + query = """ + SELECT 1 + FROM information_schema.sequences + WHERE LOWER(sequence_name) = LOWER(%s) + """ + cursor = self.database.execute_sql(query, (seq_name,)) + if bool(cursor.fetchone()): + new_seq_name = '%s_%s_seq' % (new_name, pk_names[0]) + operations.append(ParentClass.rename_table( + seq_name, new_seq_name)) + + return operations + + +class CockroachDBMigrator(PostgresqlMigrator): + explicit_create_foreign_key = True + + def add_inline_fk_sql(self, ctx, field): + pass + + @operation + def drop_index(self, table, index_name): + return (self + .make_context() + .literal('DROP INDEX ') + .sql(Entity(index_name)) + .literal(' CASCADE')) + + +class MySQLColumn(namedtuple('_Column', ('name', 'definition', 'null', 'pk', + 'default', 'extra'))): + @property + def is_pk(self): + return self.pk == 'PRI' + + @property + def is_unique(self): + return self.pk == 'UNI' + + @property + def is_null(self): + return self.null == 'YES' + + def sql(self, column_name=None, is_null=None): + if is_null is None: + is_null = self.is_null + if column_name is None: + column_name = self.name + parts = [ + Entity(column_name), + SQL(self.definition)] + if self.is_unique: + parts.append(SQL('UNIQUE')) + if is_null: + parts.append(SQL('NULL')) + else: + parts.append(SQL('NOT NULL')) + if self.is_pk: + parts.append(SQL('PRIMARY KEY')) + if self.extra: + parts.append(SQL(self.extra)) + return NodeList(parts) + + +class MySQLMigrator(SchemaMigrator): + explicit_create_foreign_key = True + explicit_delete_foreign_key = True + + def _alter_column(self, ctx, table, column): + return (self + ._alter_table(ctx, table) + .literal(' MODIFY ') + .sql(Entity(column))) + + @operation + def rename_table(self, old_name, new_name): + return (self + .make_context() + .literal('RENAME TABLE ') + .sql(Entity(old_name)) + .literal(' TO ') + .sql(Entity(new_name))) + + def _get_column_definition(self, table, column_name): + cursor = 
self.database.execute_sql('DESCRIBE `%s`;' % table) + rows = cursor.fetchall() + for row in rows: + column = MySQLColumn(*row) + if column.name == column_name: + return column + return False + + def get_foreign_key_constraint(self, table, column_name): + cursor = self.database.execute_sql( + ('SELECT constraint_name ' + 'FROM information_schema.key_column_usage WHERE ' + 'table_schema = DATABASE() AND ' + 'table_name = %s AND ' + 'column_name = %s AND ' + 'referenced_table_name IS NOT NULL AND ' + 'referenced_column_name IS NOT NULL;'), + (table, column_name)) + result = cursor.fetchone() + if not result: + raise AttributeError( + 'Unable to find foreign key constraint for ' + '"%s" on table "%s".' % (table, column_name)) + return result[0] + + @operation + def drop_foreign_key_constraint(self, table, column_name): + fk_constraint = self.get_foreign_key_constraint(table, column_name) + return (self + ._alter_table(self.make_context(), table) + .literal(' DROP FOREIGN KEY ') + .sql(Entity(fk_constraint))) + + def add_inline_fk_sql(self, ctx, field): + pass + + @operation + def add_not_null(self, table, column): + column_def = self._get_column_definition(table, column) + add_not_null = (self + ._alter_table(self.make_context(), table) + .literal(' MODIFY ') + .sql(column_def.sql(is_null=False))) + + fk_objects = dict( + (fk.column, fk) + for fk in self.database.get_foreign_keys(table)) + if column not in fk_objects: + return add_not_null + + fk_metadata = fk_objects[column] + return (self.drop_foreign_key_constraint(table, column), + add_not_null, + self.add_foreign_key_constraint( + table, + column, + fk_metadata.dest_table, + fk_metadata.dest_column)) + + @operation + def drop_not_null(self, table, column): + column = self._get_column_definition(table, column) + if column.is_pk: + raise ValueError('Primary keys can not be null') + return (self + ._alter_table(self.make_context(), table) + .literal(' MODIFY ') + .sql(column.sql(is_null=True))) + + @operation + def 
rename_column(self, table, old_name, new_name): + fk_objects = dict( + (fk.column, fk) + for fk in self.database.get_foreign_keys(table)) + is_foreign_key = old_name in fk_objects + + column = self._get_column_definition(table, old_name) + rename_ctx = (self + ._alter_table(self.make_context(), table) + .literal(' CHANGE ') + .sql(Entity(old_name)) + .literal(' ') + .sql(column.sql(column_name=new_name))) + if is_foreign_key: + fk_metadata = fk_objects[old_name] + return [ + self.drop_foreign_key_constraint(table, old_name), + rename_ctx, + self.add_foreign_key_constraint( + table, + new_name, + fk_metadata.dest_table, + fk_metadata.dest_column), + ] + else: + return rename_ctx + + @operation + def alter_column_type(self, table, column, field, cast=None): + if cast is not None: + raise ValueError('alter_column_type() does not support cast with ' + 'MySQL.') + ctx = self.make_context() + return (self + ._alter_table(ctx, table) + .literal(' MODIFY ') + .sql(Entity(column)) + .literal(' ') + .sql(field.ddl(ctx))) + + @operation + def drop_index(self, table, index_name): + return (self + .make_context() + .literal('DROP INDEX ') + .sql(Entity(index_name)) + .literal(' ON ') + .sql(Entity(table))) + + +class SqliteMigrator(SchemaMigrator): + """ + SQLite supports a subset of ALTER TABLE queries, view the docs for the + full details http://sqlite.org/lang_altertable.html + """ + column_re = re.compile('(.+?)\((.+)\)') + column_split_re = re.compile(r'(?:[^,(]|\([^)]*\))+') + column_name_re = re.compile('["`\']?([\w]+)') + fk_re = re.compile('FOREIGN KEY\s+\("?([\w]+)"?\)\s+', re.I) + + def _get_column_names(self, table): + res = self.database.execute_sql('select * from "%s" limit 1' % table) + return [item[0] for item in res.description] + + def _get_create_table(self, table): + res = self.database.execute_sql( + ('select name, sql from sqlite_master ' + 'where type=? 
and LOWER(name)=?'), + ['table', table.lower()]) + return res.fetchone() + + @operation + def _update_column(self, table, column_to_update, fn): + columns = set(column.name.lower() + for column in self.database.get_columns(table)) + if column_to_update.lower() not in columns: + raise ValueError('Column "%s" does not exist on "%s"' % + (column_to_update, table)) + + # Get the SQL used to create the given table. + table, create_table = self._get_create_table(table) + + # Get the indexes and SQL to re-create indexes. + indexes = self.database.get_indexes(table) + + # Find any foreign keys we may need to remove. + self.database.get_foreign_keys(table) + + # Make sure the create_table does not contain any newlines or tabs, + # allowing the regex to work correctly. + create_table = re.sub(r'\s+', ' ', create_table) + + # Parse out the `CREATE TABLE` and column list portions of the query. + raw_create, raw_columns = self.column_re.search(create_table).groups() + + # Clean up the individual column definitions. + split_columns = self.column_split_re.findall(raw_columns) + column_defs = [col.strip() for col in split_columns] + + new_column_defs = [] + new_column_names = [] + original_column_names = [] + constraint_terms = ('foreign ', 'primary ', 'constraint ') + + for column_def in column_defs: + column_name, = self.column_name_re.match(column_def).groups() + + if column_name == column_to_update: + new_column_def = fn(column_name, column_def) + if new_column_def: + new_column_defs.append(new_column_def) + original_column_names.append(column_name) + column_name, = self.column_name_re.match( + new_column_def).groups() + new_column_names.append(column_name) + else: + new_column_defs.append(column_def) + + # Avoid treating constraints as columns. + if not column_def.lower().startswith(constraint_terms): + new_column_names.append(column_name) + original_column_names.append(column_name) + + # Create a mapping of original columns to new columns. 
+ original_to_new = dict(zip(original_column_names, new_column_names)) + new_column = original_to_new.get(column_to_update) + + fk_filter_fn = lambda column_def: column_def + if not new_column: + # Remove any foreign keys associated with this column. + fk_filter_fn = lambda column_def: None + elif new_column != column_to_update: + # Update any foreign keys for this column. + fk_filter_fn = lambda column_def: self.fk_re.sub( + 'FOREIGN KEY ("%s") ' % new_column, + column_def) + + cleaned_columns = [] + for column_def in new_column_defs: + match = self.fk_re.match(column_def) + if match is not None and match.groups()[0] == column_to_update: + column_def = fk_filter_fn(column_def) + if column_def: + cleaned_columns.append(column_def) + + # Update the name of the new CREATE TABLE query. + temp_table = table + '__tmp__' + rgx = re.compile('("?)%s("?)' % table, re.I) + create = rgx.sub( + '\\1%s\\2' % temp_table, + raw_create) + + # Create the new table. + columns = ', '.join(cleaned_columns) + queries = [ + NodeList([SQL('DROP TABLE IF EXISTS'), Entity(temp_table)]), + SQL('%s (%s)' % (create.strip(), columns))] + + # Populate new table. + populate_table = NodeList(( + SQL('INSERT INTO'), + Entity(temp_table), + EnclosedNodeList([Entity(col) for col in new_column_names]), + SQL('SELECT'), + CommaNodeList([Entity(col) for col in original_column_names]), + SQL('FROM'), + Entity(table))) + drop_original = NodeList([SQL('DROP TABLE'), Entity(table)]) + + # Drop existing table and rename temp table. + queries += [ + populate_table, + drop_original, + self.rename_table(temp_table, table)] + + # Re-create user-defined indexes. User-defined indexes will have a + # non-empty SQL attribute. 
+ for index in filter(lambda idx: idx.sql, indexes): + if column_to_update not in index.columns: + queries.append(SQL(index.sql)) + elif new_column: + sql = self._fix_index(index.sql, column_to_update, new_column) + if sql is not None: + queries.append(SQL(sql)) + + return queries + + def _fix_index(self, sql, column_to_update, new_column): + # Split on the name of the column to update. If it splits into two + # pieces, then there's no ambiguity and we can simply replace the + # old with the new. + parts = sql.split(column_to_update) + if len(parts) == 2: + return sql.replace(column_to_update, new_column) + + # Find the list of columns in the index expression. + lhs, rhs = sql.rsplit('(', 1) + + # Apply the same "split in two" logic to the column list portion of + # the query. + if len(rhs.split(column_to_update)) == 2: + return '%s(%s' % (lhs, rhs.replace(column_to_update, new_column)) + + # Strip off the trailing parentheses and go through each column. + parts = rhs.rsplit(')', 1)[0].split(',') + columns = [part.strip('"`[]\' ') for part in parts] + + # `columns` looks something like: ['status', 'timestamp" DESC'] + # https://www.sqlite.org/lang_keywords.html + # Strip out any junk after the column name. 
+ clean = [] + for column in columns: + if re.match('%s(?:[\'"`\]]?\s|$)' % column_to_update, column): + column = new_column + column[len(column_to_update):] + clean.append(column) + + return '%s(%s)' % (lhs, ', '.join('"%s"' % c for c in clean)) + + @operation + def drop_column(self, table, column_name, cascade=True): + return self._update_column(table, column_name, lambda a, b: None) + + @operation + def rename_column(self, table, old_name, new_name): + def _rename(column_name, column_def): + return column_def.replace(column_name, new_name) + return self._update_column(table, old_name, _rename) + + @operation + def add_not_null(self, table, column): + def _add_not_null(column_name, column_def): + return column_def + ' NOT NULL' + return self._update_column(table, column, _add_not_null) + + @operation + def drop_not_null(self, table, column): + def _drop_not_null(column_name, column_def): + return column_def.replace('NOT NULL', '') + return self._update_column(table, column, _drop_not_null) + + @operation + def alter_column_type(self, table, column, field, cast=None): + if cast is not None: + raise ValueError('alter_column_type() does not support cast with ' + 'Sqlite.') + ctx = self.make_context() + def _alter_column_type(column_name, column_def): + node_list = field.ddl(ctx) + sql, _ = ctx.sql(Entity(column)).sql(node_list).query() + return sql + return self._update_column(table, column, _alter_column_type) + + @operation + def add_constraint(self, table, name, constraint): + raise NotImplementedError + + @operation + def drop_constraint(self, table, name): + raise NotImplementedError + + @operation + def add_foreign_key_constraint(self, table, column_name, field, + on_delete=None, on_update=None): + raise NotImplementedError + + +def migrate(*operations, **kwargs): + for operation in operations: + operation.run() diff --git a/python2.7libs/playhouse/mysql_ext.py b/python2.7libs/playhouse/mysql_ext.py new file mode 100644 index 0000000..9ee2655 --- /dev/null +++ 
b/python2.7libs/playhouse/mysql_ext.py @@ -0,0 +1,49 @@ +import json + +try: + import mysql.connector as mysql_connector +except ImportError: + mysql_connector = None + +from peewee import ImproperlyConfigured +from peewee import MySQLDatabase +from peewee import NodeList +from peewee import SQL +from peewee import TextField +from peewee import fn + + +class MySQLConnectorDatabase(MySQLDatabase): + def _connect(self): + if mysql_connector is None: + raise ImproperlyConfigured('MySQL connector not installed!') + return mysql_connector.connect(db=self.database, **self.connect_params) + + def cursor(self, commit=None): + if self.is_closed(): + if self.autoconnect: + self.connect() + else: + raise InterfaceError('Error, database connection not opened.') + return self._state.conn.cursor(buffered=True) + + +class JSONField(TextField): + field_type = 'JSON' + + def db_value(self, value): + if value is not None: + return json.dumps(value) + + def python_value(self, value): + if value is not None: + return json.loads(value) + + +def Match(columns, expr, modifier=None): + if isinstance(columns, (list, tuple)): + match = fn.MATCH(*columns) # Tuple of one or more columns / fields. + else: + match = fn.MATCH(columns) # Single column / field. + args = expr if modifier is None else NodeList((expr, SQL(modifier))) + return NodeList((match, fn.AGAINST(args))) diff --git a/python2.7libs/playhouse/pool.py b/python2.7libs/playhouse/pool.py new file mode 100644 index 0000000..2ee3b48 --- /dev/null +++ b/python2.7libs/playhouse/pool.py @@ -0,0 +1,318 @@ +""" +Lightweight connection pooling for peewee. + +In a multi-threaded application, up to `max_connections` will be opened. Each +thread (or, if using gevent, greenlet) will have its own connection. + +In a single-threaded application, only one connection will be created. It will +be continually recycled until either it exceeds the stale timeout or is closed +explicitly (using `.manual_close()`). 
+ +By default, all your application needs to do is ensure that connections are +closed when you are finished with them, and they will be returned to the pool. +For web applications, this typically means that at the beginning of a request, +you will open a connection, and when you return a response, you will close the +connection. + +Simple Postgres pool example code: + + # Use the special postgresql extensions. + from playhouse.pool import PooledPostgresqlExtDatabase + + db = PooledPostgresqlExtDatabase( + 'my_app', + max_connections=32, + stale_timeout=300, # 5 minutes. + user='postgres') + + class BaseModel(Model): + class Meta: + database = db + +That's it! +""" +import heapq +import logging +import random +import time +from collections import namedtuple +from itertools import chain + +try: + from psycopg2.extensions import TRANSACTION_STATUS_IDLE + from psycopg2.extensions import TRANSACTION_STATUS_INERROR + from psycopg2.extensions import TRANSACTION_STATUS_UNKNOWN +except ImportError: + TRANSACTION_STATUS_IDLE = \ + TRANSACTION_STATUS_INERROR = \ + TRANSACTION_STATUS_UNKNOWN = None + +from peewee import MySQLDatabase +from peewee import PostgresqlDatabase +from peewee import SqliteDatabase + +logger = logging.getLogger('peewee.pool') + + +def make_int(val): + if val is not None and not isinstance(val, (int, float)): + return int(val) + return val + + +class MaxConnectionsExceeded(ValueError): pass + + +PoolConnection = namedtuple('PoolConnection', ('timestamp', 'connection', + 'checked_out')) + + +class PooledDatabase(object): + def __init__(self, database, max_connections=20, stale_timeout=None, + timeout=None, **kwargs): + self._max_connections = make_int(max_connections) + self._stale_timeout = make_int(stale_timeout) + self._wait_timeout = make_int(timeout) + if self._wait_timeout == 0: + self._wait_timeout = float('inf') + + # Available / idle connections stored in a heap, sorted oldest first. 
+ self._connections = [] + + # Mapping of connection id to PoolConnection. Ordinarily we would want + # to use something like a WeakKeyDictionary, but Python typically won't + # allow us to create weak references to connection objects. + self._in_use = {} + + # Use the memory address of the connection as the key in the event the + # connection object is not hashable. Connections will not get + # garbage-collected, however, because a reference to them will persist + # in "_in_use" as long as the conn has not been closed. + self.conn_key = id + + super(PooledDatabase, self).__init__(database, **kwargs) + + def init(self, database, max_connections=None, stale_timeout=None, + timeout=None, **connect_kwargs): + super(PooledDatabase, self).init(database, **connect_kwargs) + if max_connections is not None: + self._max_connections = make_int(max_connections) + if stale_timeout is not None: + self._stale_timeout = make_int(stale_timeout) + if timeout is not None: + self._wait_timeout = make_int(timeout) + if self._wait_timeout == 0: + self._wait_timeout = float('inf') + + def connect(self, reuse_if_open=False): + if not self._wait_timeout: + return super(PooledDatabase, self).connect(reuse_if_open) + + expires = time.time() + self._wait_timeout + while expires > time.time(): + try: + ret = super(PooledDatabase, self).connect(reuse_if_open) + except MaxConnectionsExceeded: + time.sleep(0.1) + else: + return ret + raise MaxConnectionsExceeded('Max connections exceeded, timed out ' + 'attempting to connect.') + + def _connect(self): + while True: + try: + # Remove the oldest connection from the heap. + ts, conn = heapq.heappop(self._connections) + key = self.conn_key(conn) + except IndexError: + ts = conn = None + logger.debug('No connection available in pool.') + break + else: + if self._is_closed(conn): + # This connection was closed, but since it was not stale + # it got added back to the queue of available conns. 
We + # then closed it and marked it as explicitly closed, so + # it's safe to throw it away now. + # (Because Database.close() calls Database._close()). + logger.debug('Connection %s was closed.', key) + ts = conn = None + elif self._stale_timeout and self._is_stale(ts): + # If we are attempting to check out a stale connection, + # then close it. We don't need to mark it in the "closed" + # set, because it is not in the list of available conns + # anymore. + logger.debug('Connection %s was stale, closing.', key) + self._close(conn, True) + ts = conn = None + else: + break + + if conn is None: + if self._max_connections and ( + len(self._in_use) >= self._max_connections): + raise MaxConnectionsExceeded('Exceeded maximum connections.') + conn = super(PooledDatabase, self)._connect() + ts = time.time() - random.random() / 1000 + key = self.conn_key(conn) + logger.debug('Created new connection %s.', key) + + self._in_use[key] = PoolConnection(ts, conn, time.time()) + return conn + + def _is_stale(self, timestamp): + # Called on check-out and check-in to ensure the connection has + # not outlived the stale timeout. + return (time.time() - timestamp) > self._stale_timeout + + def _is_closed(self, conn): + return False + + def _can_reuse(self, conn): + # Called on check-in to make sure the connection can be re-used. + return True + + def _close(self, conn, close_conn=False): + key = self.conn_key(conn) + if close_conn: + super(PooledDatabase, self)._close(conn) + elif key in self._in_use: + pool_conn = self._in_use.pop(key) + if self._stale_timeout and self._is_stale(pool_conn.timestamp): + logger.debug('Closing stale connection %s.', key) + super(PooledDatabase, self)._close(conn) + elif self._can_reuse(conn): + logger.debug('Returning %s to pool.', key) + heapq.heappush(self._connections, (pool_conn.timestamp, conn)) + else: + logger.debug('Closed %s.', key) + + def manual_close(self): + """ + Close the underlying connection without returning it to the pool. 
+ """ + if self.is_closed(): + return False + + # Obtain reference to the connection in-use by the calling thread. + conn = self.connection() + + # A connection will only be re-added to the available list if it is + # marked as "in use" at the time it is closed. We will explicitly + # remove it from the "in use" list, call "close()" for the + # side-effects, and then explicitly close the connection. + self._in_use.pop(self.conn_key(conn), None) + self.close() + self._close(conn, close_conn=True) + + def close_idle(self): + # Close any open connections that are not currently in-use. + with self._lock: + for _, conn in self._connections: + self._close(conn, close_conn=True) + self._connections = [] + + def close_stale(self, age=600): + # Close any connections that are in-use but were checked out quite some + # time ago and can be considered stale. + with self._lock: + in_use = {} + cutoff = time.time() - age + n = 0 + for key, pool_conn in self._in_use.items(): + if pool_conn.checked_out < cutoff: + self._close(pool_conn.connection, close_conn=True) + n += 1 + else: + in_use[key] = pool_conn + self._in_use = in_use + return n + + def close_all(self): + # Close all connections -- available and in-use. Warning: may break any + # active connections used by other threads. 
+ self.close() + with self._lock: + for _, conn in self._connections: + self._close(conn, close_conn=True) + for pool_conn in self._in_use.values(): + self._close(pool_conn.connection, close_conn=True) + self._connections = [] + self._in_use = {} + + +class PooledMySQLDatabase(PooledDatabase, MySQLDatabase): + def _is_closed(self, conn): + try: + conn.ping(False) + except: + return True + else: + return False + + +class _PooledPostgresqlDatabase(PooledDatabase): + def _is_closed(self, conn): + if conn.closed: + return True + + txn_status = conn.get_transaction_status() + if txn_status == TRANSACTION_STATUS_UNKNOWN: + return True + elif txn_status != TRANSACTION_STATUS_IDLE: + conn.rollback() + return False + + def _can_reuse(self, conn): + txn_status = conn.get_transaction_status() + # Do not return connection in an error state, as subsequent queries + # will all fail. If the status is unknown then we lost the connection + # to the server and the connection should not be re-used. + if txn_status == TRANSACTION_STATUS_UNKNOWN: + return False + elif txn_status == TRANSACTION_STATUS_INERROR: + conn.reset() + elif txn_status != TRANSACTION_STATUS_IDLE: + conn.rollback() + return True + +class PooledPostgresqlDatabase(_PooledPostgresqlDatabase, PostgresqlDatabase): + pass + +try: + from playhouse.postgres_ext import PostgresqlExtDatabase + + class PooledPostgresqlExtDatabase(_PooledPostgresqlDatabase, PostgresqlExtDatabase): + pass +except ImportError: + PooledPostgresqlExtDatabase = None + + +class _PooledSqliteDatabase(PooledDatabase): + def _is_closed(self, conn): + try: + conn.total_changes + except: + return True + else: + return False + +class PooledSqliteDatabase(_PooledSqliteDatabase, SqliteDatabase): + pass + +try: + from playhouse.sqlite_ext import SqliteExtDatabase + + class PooledSqliteExtDatabase(_PooledSqliteDatabase, SqliteExtDatabase): + pass +except ImportError: + PooledSqliteExtDatabase = None + +try: + from playhouse.sqlite_ext import 
CSqliteExtDatabase + + class PooledCSqliteExtDatabase(_PooledSqliteDatabase, CSqliteExtDatabase): + pass +except ImportError: + PooledCSqliteExtDatabase = None diff --git a/python2.7libs/playhouse/postgres_ext.py b/python2.7libs/playhouse/postgres_ext.py new file mode 100644 index 0000000..d2b3766 --- /dev/null +++ b/python2.7libs/playhouse/postgres_ext.py @@ -0,0 +1,480 @@ +""" +Collection of postgres-specific extensions, currently including: + +* Support for hstore, a key/value type storage +""" +import json +import logging +import uuid + +from peewee import * +from peewee import ColumnBase +from peewee import Expression +from peewee import Node +from peewee import NodeList +from peewee import SENTINEL +from peewee import __exception_wrapper__ + +try: + from psycopg2cffi import compat + compat.register() +except ImportError: + pass + +try: + from psycopg2.extras import register_hstore +except ImportError: + def register_hstore(c, globally): + pass +try: + from psycopg2.extras import Json +except: + Json = None + + +logger = logging.getLogger('peewee') + + +HCONTAINS_DICT = '@>' +HCONTAINS_KEYS = '?&' +HCONTAINS_KEY = '?' +HCONTAINS_ANY_KEY = '?|' +HKEY = '->' +HUPDATE = '||' +ACONTAINS = '@>' +ACONTAINS_ANY = '&&' +TS_MATCH = '@@' +JSONB_CONTAINS = '@>' +JSONB_CONTAINED_BY = '<@' +JSONB_CONTAINS_KEY = '?' +JSONB_CONTAINS_ANY_KEY = '?|' +JSONB_CONTAINS_ALL_KEYS = '?&' +JSONB_EXISTS = '?' 
+JSONB_REMOVE = '-' + + +class _LookupNode(ColumnBase): + def __init__(self, node, parts): + self.node = node + self.parts = parts + super(_LookupNode, self).__init__() + + def clone(self): + return type(self)(self.node, list(self.parts)) + + +class _JsonLookupBase(_LookupNode): + def __init__(self, node, parts, as_json=False): + super(_JsonLookupBase, self).__init__(node, parts) + self._as_json = as_json + + def clone(self): + return type(self)(self.node, list(self.parts), self._as_json) + + @Node.copy + def as_json(self, as_json=True): + self._as_json = as_json + + def concat(self, rhs): + return Expression(self.as_json(True), OP.CONCAT, Json(rhs)) + + def contains(self, other): + clone = self.as_json(True) + if isinstance(other, (list, dict)): + return Expression(clone, JSONB_CONTAINS, Json(other)) + return Expression(clone, JSONB_EXISTS, other) + + def contains_any(self, *keys): + return Expression( + self.as_json(True), + JSONB_CONTAINS_ANY_KEY, + Value(list(keys), unpack=False)) + + def contains_all(self, *keys): + return Expression( + self.as_json(True), + JSONB_CONTAINS_ALL_KEYS, + Value(list(keys), unpack=False)) + + def has_key(self, key): + return Expression(self.as_json(True), JSONB_CONTAINS_KEY, key) + + +class JsonLookup(_JsonLookupBase): + def __getitem__(self, value): + return JsonLookup(self.node, self.parts + [value], self._as_json) + + def __sql__(self, ctx): + ctx.sql(self.node) + for part in self.parts[:-1]: + ctx.literal('->').sql(part) + if self.parts: + (ctx + .literal('->' if self._as_json else '->>') + .sql(self.parts[-1])) + + return ctx + + +class JsonPath(_JsonLookupBase): + def __sql__(self, ctx): + return (ctx + .sql(self.node) + .literal('#>' if self._as_json else '#>>') + .sql(Value('{%s}' % ','.join(map(str, self.parts))))) + + +class ObjectSlice(_LookupNode): + @classmethod + def create(cls, node, value): + if isinstance(value, slice): + parts = [value.start or 0, value.stop or 0] + elif isinstance(value, int): + parts = [value] + 
else: + parts = map(int, value.split(':')) + return cls(node, parts) + + def __sql__(self, ctx): + return (ctx + .sql(self.node) + .literal('[%s]' % ':'.join(str(p + 1) for p in self.parts))) + + def __getitem__(self, value): + return ObjectSlice.create(self, value) + + +class IndexedFieldMixin(object): + default_index_type = 'GIN' + + def __init__(self, *args, **kwargs): + kwargs.setdefault('index', True) # By default, use an index. + super(IndexedFieldMixin, self).__init__(*args, **kwargs) + + +class ArrayField(IndexedFieldMixin, Field): + passthrough = True + + def __init__(self, field_class=IntegerField, field_kwargs=None, + dimensions=1, convert_values=False, *args, **kwargs): + self.__field = field_class(**(field_kwargs or {})) + self.dimensions = dimensions + self.convert_values = convert_values + self.field_type = self.__field.field_type + super(ArrayField, self).__init__(*args, **kwargs) + + def bind(self, model, name, set_attribute=True): + ret = super(ArrayField, self).bind(model, name, set_attribute) + self.__field.bind(model, '__array_%s' % name, False) + return ret + + def ddl_datatype(self, ctx): + data_type = self.__field.ddl_datatype(ctx) + return NodeList((data_type, SQL('[]' * self.dimensions)), glue='') + + def db_value(self, value): + if value is None or isinstance(value, Node): + return value + elif self.convert_values: + return self._process(self.__field.db_value, value, self.dimensions) + else: + return value if isinstance(value, list) else list(value) + + def python_value(self, value): + if self.convert_values and value is not None: + conv = self.__field.python_value + if isinstance(value, list): + return self._process(conv, value, self.dimensions) + else: + return conv(value) + else: + return value + + def _process(self, conv, value, dimensions): + dimensions -= 1 + if dimensions == 0: + return [conv(v) for v in value] + else: + return [self._process(conv, v, dimensions) for v in value] + + def __getitem__(self, value): + return 
ObjectSlice.create(self, value) + + def _e(op): + def inner(self, rhs): + return Expression(self, op, ArrayValue(self, rhs)) + return inner + __eq__ = _e(OP.EQ) + __ne__ = _e(OP.NE) + __gt__ = _e(OP.GT) + __ge__ = _e(OP.GTE) + __lt__ = _e(OP.LT) + __le__ = _e(OP.LTE) + __hash__ = Field.__hash__ + + def contains(self, *items): + return Expression(self, ACONTAINS, ArrayValue(self, items)) + + def contains_any(self, *items): + return Expression(self, ACONTAINS_ANY, ArrayValue(self, items)) + + +class ArrayValue(Node): + def __init__(self, field, value): + self.field = field + self.value = value + + def __sql__(self, ctx): + return (ctx + .sql(Value(self.value, unpack=False)) + .literal('::') + .sql(self.field.ddl_datatype(ctx))) + + +class DateTimeTZField(DateTimeField): + field_type = 'TIMESTAMPTZ' + + +class HStoreField(IndexedFieldMixin, Field): + field_type = 'HSTORE' + __hash__ = Field.__hash__ + + def __getitem__(self, key): + return Expression(self, HKEY, Value(key)) + + def keys(self): + return fn.akeys(self) + + def values(self): + return fn.avals(self) + + def items(self): + return fn.hstore_to_matrix(self) + + def slice(self, *args): + return fn.slice(self, Value(list(args), unpack=False)) + + def exists(self, key): + return fn.exist(self, key) + + def defined(self, key): + return fn.defined(self, key) + + def update(self, **data): + return Expression(self, HUPDATE, data) + + def delete(self, *keys): + return fn.delete(self, Value(list(keys), unpack=False)) + + def contains(self, value): + if isinstance(value, dict): + rhs = Value(value, unpack=False) + return Expression(self, HCONTAINS_DICT, rhs) + elif isinstance(value, (list, tuple)): + rhs = Value(value, unpack=False) + return Expression(self, HCONTAINS_KEYS, rhs) + return Expression(self, HCONTAINS_KEY, value) + + def contains_any(self, *keys): + return Expression(self, HCONTAINS_ANY_KEY, Value(list(keys), + unpack=False)) + + +class JSONField(Field): + field_type = 'JSON' + _json_datatype = 'json' + + 
def __init__(self, dumps=None, *args, **kwargs): + if Json is None: + raise Exception('Your version of psycopg2 does not support JSON.') + self.dumps = dumps or json.dumps + super(JSONField, self).__init__(*args, **kwargs) + + def db_value(self, value): + if value is None: + return value + if not isinstance(value, Json): + return Cast(self.dumps(value), self._json_datatype) + return value + + def __getitem__(self, value): + return JsonLookup(self, [value]) + + def path(self, *keys): + return JsonPath(self, keys) + + def concat(self, value): + return super(JSONField, self).concat(Json(value)) + + +def cast_jsonb(node): + return NodeList((node, SQL('::jsonb')), glue='') + + +class BinaryJSONField(IndexedFieldMixin, JSONField): + field_type = 'JSONB' + _json_datatype = 'jsonb' + __hash__ = Field.__hash__ + + def contains(self, other): + if isinstance(other, (list, dict)): + return Expression(self, JSONB_CONTAINS, Json(other)) + elif isinstance(other, JSONField): + return Expression(self, JSONB_CONTAINS, other) + return Expression(cast_jsonb(self), JSONB_EXISTS, other) + + def contained_by(self, other): + return Expression(cast_jsonb(self), JSONB_CONTAINED_BY, Json(other)) + + def contains_any(self, *items): + return Expression( + cast_jsonb(self), + JSONB_CONTAINS_ANY_KEY, + Value(list(items), unpack=False)) + + def contains_all(self, *items): + return Expression( + cast_jsonb(self), + JSONB_CONTAINS_ALL_KEYS, + Value(list(items), unpack=False)) + + def has_key(self, key): + return Expression(cast_jsonb(self), JSONB_CONTAINS_KEY, key) + + def remove(self, *items): + return Expression( + cast_jsonb(self), + JSONB_REMOVE, + Value(list(items), unpack=False)) + + +class TSVectorField(IndexedFieldMixin, TextField): + field_type = 'TSVECTOR' + __hash__ = Field.__hash__ + + def match(self, query, language=None, plain=False): + params = (language, query) if language is not None else (query,) + func = fn.plainto_tsquery if plain else fn.to_tsquery + return Expression(self, 
TS_MATCH, func(*params)) + + +def Match(field, query, language=None): + params = (language, query) if language is not None else (query,) + field_params = (language, field) if language is not None else (field,) + return Expression( + fn.to_tsvector(*field_params), + TS_MATCH, + fn.to_tsquery(*params)) + + +class IntervalField(Field): + field_type = 'INTERVAL' + + +class FetchManyCursor(object): + __slots__ = ('cursor', 'array_size', 'exhausted', 'iterable') + + def __init__(self, cursor, array_size=None): + self.cursor = cursor + self.array_size = array_size or cursor.itersize + self.exhausted = False + self.iterable = self.row_gen() + + @property + def description(self): + return self.cursor.description + + def close(self): + self.cursor.close() + + def row_gen(self): + while True: + rows = self.cursor.fetchmany(self.array_size) + if not rows: + return + for row in rows: + yield row + + def fetchone(self): + if self.exhausted: + return + try: + return next(self.iterable) + except StopIteration: + self.exhausted = True + + +class ServerSideQuery(Node): + def __init__(self, query, array_size=None): + self.query = query + self.array_size = array_size + self._cursor_wrapper = None + + def __sql__(self, ctx): + return self.query.__sql__(ctx) + + def __iter__(self): + if self._cursor_wrapper is None: + self._execute(self.query._database) + return iter(self._cursor_wrapper.iterator()) + + def _execute(self, database): + if self._cursor_wrapper is None: + cursor = database.execute(self.query, named_cursor=True, + array_size=self.array_size) + self._cursor_wrapper = self.query._get_cursor_wrapper(cursor) + return self._cursor_wrapper + + +def ServerSide(query, database=None, array_size=None): + if database is None: + database = query._database + with database.transaction(): + server_side_query = ServerSideQuery(query, array_size=array_size) + for row in server_side_query: + yield row + + +class _empty_object(object): + __slots__ = () + def __nonzero__(self): + return False 
+ __bool__ = __nonzero__ + +__named_cursor__ = _empty_object() + + +class PostgresqlExtDatabase(PostgresqlDatabase): + def __init__(self, *args, **kwargs): + self._register_hstore = kwargs.pop('register_hstore', False) + self._server_side_cursors = kwargs.pop('server_side_cursors', False) + super(PostgresqlExtDatabase, self).__init__(*args, **kwargs) + + def _connect(self): + conn = super(PostgresqlExtDatabase, self)._connect() + if self._register_hstore: + register_hstore(conn, globally=True) + return conn + + def cursor(self, commit=None): + if self.is_closed(): + if self.autoconnect: + self.connect() + else: + raise InterfaceError('Error, database connection not opened.') + if commit is __named_cursor__: + return self._state.conn.cursor(name=str(uuid.uuid1())) + return self._state.conn.cursor() + + def execute(self, query, commit=SENTINEL, named_cursor=False, + array_size=None, **context_options): + ctx = self.get_sql_context(**context_options) + sql, params = ctx.sql(query).query() + named_cursor = named_cursor or (self._server_side_cursors and + sql[:6].lower() == 'select') + if named_cursor: + commit = __named_cursor__ + cursor = self.execute_sql(sql, params, commit=commit) + if named_cursor: + cursor = FetchManyCursor(cursor, array_size) + return cursor diff --git a/python2.7libs/playhouse/reflection.py b/python2.7libs/playhouse/reflection.py new file mode 100644 index 0000000..aac24e7 --- /dev/null +++ b/python2.7libs/playhouse/reflection.py @@ -0,0 +1,832 @@ +try: + from collections import OrderedDict +except ImportError: + OrderedDict = dict +from collections import namedtuple +from inspect import isclass +import re + +from peewee import * +from peewee import _StringField +from peewee import _query_val_transform +from peewee import CommaNodeList +from peewee import SCOPE_VALUES +from peewee import make_snake_case +from peewee import text_type +try: + from pymysql.constants import FIELD_TYPE +except ImportError: + try: + from MySQLdb.constants import 
FIELD_TYPE + except ImportError: + FIELD_TYPE = None +try: + from playhouse import postgres_ext +except ImportError: + postgres_ext = None +try: + from playhouse.cockroachdb import CockroachDatabase +except ImportError: + CockroachDatabase = None + +RESERVED_WORDS = set([ + 'and', 'as', 'assert', 'break', 'class', 'continue', 'def', 'del', 'elif', + 'else', 'except', 'exec', 'finally', 'for', 'from', 'global', 'if', + 'import', 'in', 'is', 'lambda', 'not', 'or', 'pass', 'print', 'raise', + 'return', 'try', 'while', 'with', 'yield', +]) + + +class UnknownField(object): + pass + + +class Column(object): + """ + Store metadata about a database column. + """ + primary_key_types = (IntegerField, AutoField) + + def __init__(self, name, field_class, raw_column_type, nullable, + primary_key=False, column_name=None, index=False, + unique=False, default=None, extra_parameters=None): + self.name = name + self.field_class = field_class + self.raw_column_type = raw_column_type + self.nullable = nullable + self.primary_key = primary_key + self.column_name = column_name + self.index = index + self.unique = unique + self.default = default + self.extra_parameters = extra_parameters + + # Foreign key metadata. + self.rel_model = None + self.related_name = None + self.to_field = None + + def __repr__(self): + attrs = [ + 'field_class', + 'raw_column_type', + 'nullable', + 'primary_key', + 'column_name'] + keyword_args = ', '.join( + '%s=%s' % (attr, getattr(self, attr)) + for attr in attrs) + return 'Column(%s, %s)' % (self.name, keyword_args) + + def get_field_parameters(self): + params = {} + if self.extra_parameters is not None: + params.update(self.extra_parameters) + + # Set up default attributes. 
+ if self.nullable: + params['null'] = True + if self.field_class is ForeignKeyField or self.name != self.column_name: + params['column_name'] = "'%s'" % self.column_name + if self.primary_key and not issubclass(self.field_class, AutoField): + params['primary_key'] = True + if self.default is not None: + params['constraints'] = '[SQL("DEFAULT %s")]' % self.default + + # Handle ForeignKeyField-specific attributes. + if self.is_foreign_key(): + params['model'] = self.rel_model + if self.to_field: + params['field'] = "'%s'" % self.to_field + if self.related_name: + params['backref'] = "'%s'" % self.related_name + + # Handle indexes on column. + if not self.is_primary_key(): + if self.unique: + params['unique'] = 'True' + elif self.index and not self.is_foreign_key(): + params['index'] = 'True' + + return params + + def is_primary_key(self): + return self.field_class is AutoField or self.primary_key + + def is_foreign_key(self): + return self.field_class is ForeignKeyField + + def is_self_referential_fk(self): + return (self.field_class is ForeignKeyField and + self.rel_model == "'self'") + + def set_foreign_key(self, foreign_key, model_names, dest=None, + related_name=None): + self.foreign_key = foreign_key + self.field_class = ForeignKeyField + if foreign_key.dest_table == foreign_key.table: + self.rel_model = "'self'" + else: + self.rel_model = model_names[foreign_key.dest_table] + self.to_field = dest and dest.name or None + self.related_name = related_name or None + + def get_field(self): + # Generate the field definition for this column. 
+ field_params = {} + for key, value in self.get_field_parameters().items(): + if isclass(value) and issubclass(value, Field): + value = value.__name__ + field_params[key] = value + + param_str = ', '.join('%s=%s' % (k, v) + for k, v in sorted(field_params.items())) + field = '%s = %s(%s)' % ( + self.name, + self.field_class.__name__, + param_str) + + if self.field_class is UnknownField: + field = '%s # %s' % (field, self.raw_column_type) + + return field + + +class Metadata(object): + column_map = {} + extension_import = '' + + def __init__(self, database): + self.database = database + self.requires_extension = False + + def execute(self, sql, *params): + return self.database.execute_sql(sql, params) + + def get_columns(self, table, schema=None): + metadata = OrderedDict( + (metadata.name, metadata) + for metadata in self.database.get_columns(table, schema)) + + # Look up the actual column type for each column. + column_types, extra_params = self.get_column_types(table, schema) + + # Look up the primary keys. 
+ pk_names = self.get_primary_keys(table, schema) + if len(pk_names) == 1: + pk = pk_names[0] + if column_types[pk] is IntegerField: + column_types[pk] = AutoField + elif column_types[pk] is BigIntegerField: + column_types[pk] = BigAutoField + + columns = OrderedDict() + for name, column_data in metadata.items(): + field_class = column_types[name] + default = self._clean_default(field_class, column_data.default) + + columns[name] = Column( + name, + field_class=field_class, + raw_column_type=column_data.data_type, + nullable=column_data.null, + primary_key=column_data.primary_key, + column_name=name, + default=default, + extra_parameters=extra_params.get(name)) + + return columns + + def get_column_types(self, table, schema=None): + raise NotImplementedError + + def _clean_default(self, field_class, default): + if default is None or field_class in (AutoField, BigAutoField) or \ + default.lower() == 'null': + return + if issubclass(field_class, _StringField) and \ + isinstance(default, text_type) and not default.startswith("'"): + default = "'%s'" % default + return default or "''" + + def get_foreign_keys(self, table, schema=None): + return self.database.get_foreign_keys(table, schema) + + def get_primary_keys(self, table, schema=None): + return self.database.get_primary_keys(table, schema) + + def get_indexes(self, table, schema=None): + return self.database.get_indexes(table, schema) + + +class PostgresqlMetadata(Metadata): + column_map = { + 16: BooleanField, + 17: BlobField, + 20: BigIntegerField, + 21: SmallIntegerField, + 23: IntegerField, + 25: TextField, + 700: FloatField, + 701: DoubleField, + 1042: CharField, # blank-padded CHAR + 1043: CharField, + 1082: DateField, + 1114: DateTimeField, + 1184: DateTimeField, + 1083: TimeField, + 1266: TimeField, + 1700: DecimalField, + 2950: UUIDField, # UUID + } + array_types = { + 1000: BooleanField, + 1001: BlobField, + 1005: SmallIntegerField, + 1007: IntegerField, + 1009: TextField, + 1014: CharField, + 1015: 
CharField, + 1016: BigIntegerField, + 1115: DateTimeField, + 1182: DateField, + 1183: TimeField, + } + extension_import = 'from playhouse.postgres_ext import *' + + def __init__(self, database): + super(PostgresqlMetadata, self).__init__(database) + + if postgres_ext is not None: + # Attempt to add types like HStore and JSON. + cursor = self.execute('select oid, typname, format_type(oid, NULL)' + ' from pg_type;') + results = cursor.fetchall() + + for oid, typname, formatted_type in results: + if typname == 'json': + self.column_map[oid] = postgres_ext.JSONField + elif typname == 'jsonb': + self.column_map[oid] = postgres_ext.BinaryJSONField + elif typname == 'hstore': + self.column_map[oid] = postgres_ext.HStoreField + elif typname == 'tsvector': + self.column_map[oid] = postgres_ext.TSVectorField + + for oid in self.array_types: + self.column_map[oid] = postgres_ext.ArrayField + + def get_column_types(self, table, schema): + column_types = {} + extra_params = {} + extension_types = set(( + postgres_ext.ArrayField, + postgres_ext.BinaryJSONField, + postgres_ext.JSONField, + postgres_ext.TSVectorField, + postgres_ext.HStoreField)) if postgres_ext is not None else set() + + # Look up the actual column type for each column. + identifier = '%s.%s' % (schema, table) + cursor = self.execute( + 'SELECT attname, atttypid FROM pg_catalog.pg_attribute ' + 'WHERE attrelid = %s::regclass AND attnum > %s', identifier, 0) + + # Store column metadata in dictionary keyed by column name. 
class CockroachDBMetadata(PostgresqlMetadata):
    """Introspection metadata for CockroachDB.

    CRDB treats INT the same as BIGINT, so the bigint type OIDs are mapped
    to plain IntegerField rather than BigIntegerField.
    """
    column_map = dict(PostgresqlMetadata.column_map)
    column_map[20] = IntegerField
    array_types = dict(PostgresqlMetadata.array_types)
    array_types[1016] = IntegerField
    extension_import = 'from playhouse.cockroachdb import *'

    def __init__(self, database):
        # Deliberately bypass PostgresqlMetadata.__init__: CRDB only picks
        # up the jsonb / array extension types, not hstore or tsvector.
        Metadata.__init__(self, database)
        self.requires_extension = True

        if postgres_ext is None:
            return

        cursor = self.execute('select oid, typname, format_type(oid, NULL)'
                              ' from pg_type;')
        for oid, typname, _formatted in cursor.fetchall():
            if typname == 'jsonb':
                self.column_map[oid] = postgres_ext.BinaryJSONField

        for oid in self.array_types:
            self.column_map[oid] = postgres_ext.ArrayField
class SqliteMetadata(Metadata):
    """Introspection metadata for SQLite databases.

    SQLite stores declared column types as free-form text, so types are
    matched by lower-cased declared name, with a regex fallback for
    parameterized declarations such as ``VARCHAR(255)``.
    """
    column_map = {
        'bigint': BigIntegerField,
        'blob': BlobField,
        'bool': BooleanField,
        'boolean': BooleanField,
        'char': CharField,
        'date': DateField,
        'datetime': DateTimeField,
        'decimal': DecimalField,
        'float': FloatField,
        'integer': IntegerField,
        'integer unsigned': IntegerField,
        'int': IntegerField,
        'long': BigIntegerField,
        'numeric': DecimalField,
        'real': FloatField,
        'smallinteger': IntegerField,
        'smallint': IntegerField,
        'smallint unsigned': IntegerField,
        'text': TextField,
        'time': TimeField,
        'varchar': CharField,
    }

    # Raw string literals so regex escapes like \s and \( are passed to the
    # re module verbatim (non-raw forms emit invalid-escape warnings on
    # modern CPython). The resulting patterns are byte-identical.
    begin = r'(?:["\[\(]+)?'
    end = r'(?:["\]\)]+)?'
    re_foreign_key = (
        r'(?:FOREIGN KEY\s*)?'
        r'{begin}(.+?){end}\s+(?:.+\s+)?'
        r'references\s+{begin}(.+?){end}'
        r'\s*\(["|\[]?(.+?)["|\]]?\)').format(begin=begin, end=end)
    re_varchar = r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$'

    def _map_col(self, column_type):
        """Map a raw SQLite declared type to a peewee field class.

        Falls back to CharField for varchar-like declarations, BareField
        for empty declared types, and UnknownField for anything else.
        """
        raw_column_type = column_type.lower()
        if raw_column_type in self.column_map:
            field_class = self.column_map[raw_column_type]
        elif re.search(self.re_varchar, raw_column_type):
            field_class = CharField
        else:
            # Strip any parenthesized size/precision, e.g. "decimal(10,2)".
            column_type = re.sub(r'\(.+\)', '', raw_column_type)
            if column_type == '':
                field_class = BareField
            else:
                field_class = self.column_map.get(column_type, UnknownField)
        return field_class

    def get_column_types(self, table, schema=None):
        """Return ``({column_name: field_class}, {})`` for *table*.

        SQLite columns carry no extra field parameters, so the second
        element of the tuple is always an empty dict.
        """
        column_types = {}
        for column in self.database.get_columns(table):
            column_types[column.name] = self._map_col(column.data_type)
        return column_types, {}
[] + for index in self.indexes[table]: + if len(index.columns) > 1: + field_names = [self.columns[table][column].name + for column in index.columns + if column in self.columns[table]] + accum.append((field_names, index.unique)) + return accum + + def column_indexes(self, table): + accum = {} + for index in self.indexes[table]: + if len(index.columns) == 1: + accum[index.columns[0]] = index.unique + return accum + + +class Introspector(object): + pk_classes = [AutoField, IntegerField] + + def __init__(self, metadata, schema=None): + self.metadata = metadata + self.schema = schema + + def __repr__(self): + return '' % self.metadata.database + + @classmethod + def from_database(cls, database, schema=None): + if CockroachDatabase and isinstance(database, CockroachDatabase): + metadata = CockroachDBMetadata(database) + elif isinstance(database, PostgresqlDatabase): + metadata = PostgresqlMetadata(database) + elif isinstance(database, MySQLDatabase): + metadata = MySQLMetadata(database) + elif isinstance(database, SqliteDatabase): + metadata = SqliteMetadata(database) + else: + raise ValueError('Introspection not supported for %r' % database) + return cls(metadata, schema=schema) + + def get_database_class(self): + return type(self.metadata.database) + + def get_database_name(self): + return self.metadata.database.database + + def get_database_kwargs(self): + return self.metadata.database.connect_params + + def get_additional_imports(self): + if self.metadata.requires_extension: + return '\n' + self.metadata.extension_import + return '' + + def make_model_name(self, table, snake_case=True): + if snake_case: + table = make_snake_case(table) + model = re.sub('[^\w]+', '', table) + model_name = ''.join(sub.title() for sub in model.split('_')) + if not model_name[0].isalpha(): + model_name = 'T' + model_name + return model_name + + def make_column_name(self, column, is_foreign_key=False, snake_case=True): + column = column.strip() + if snake_case: + column = 
make_snake_case(column) + column = column.lower() + if is_foreign_key: + # Strip "_id" from foreign keys, unless the foreign-key happens to + # be named "_id", in which case the name is retained. + column = re.sub('_id$', '', column) or column + + # Remove characters that are invalid for Python identifiers. + column = re.sub('[^\w]+', '_', column) + if column in RESERVED_WORDS: + column += '_' + if len(column) and column[0].isdigit(): + column = '_' + column + return column + + def introspect(self, table_names=None, literal_column_names=False, + include_views=False, snake_case=True): + # Retrieve all the tables in the database. + tables = self.metadata.database.get_tables(schema=self.schema) + if include_views: + views = self.metadata.database.get_views(schema=self.schema) + tables.extend([view.name for view in views]) + + if table_names is not None: + tables = [table for table in tables if table in table_names] + table_set = set(tables) + + # Store a mapping of table name -> dictionary of columns. + columns = {} + + # Store a mapping of table name -> set of primary key columns. + primary_keys = {} + + # Store a mapping of table -> foreign keys. + foreign_keys = {} + + # Store a mapping of table name -> model name. + model_names = {} + + # Store a mapping of table name -> indexes. + indexes = {} + + # Gather the columns for each table. + for table in tables: + table_indexes = self.metadata.get_indexes(table, self.schema) + table_columns = self.metadata.get_columns(table, self.schema) + try: + foreign_keys[table] = self.metadata.get_foreign_keys( + table, self.schema) + except ValueError as exc: + err(*exc.args) + foreign_keys[table] = [] + else: + # If there is a possibility we could exclude a dependent table, + # ensure that we introspect it so FKs will work. 
+ if table_names is not None: + for foreign_key in foreign_keys[table]: + if foreign_key.dest_table not in table_set: + tables.append(foreign_key.dest_table) + table_set.add(foreign_key.dest_table) + + model_names[table] = self.make_model_name(table, snake_case) + + # Collect sets of all the column names as well as all the + # foreign-key column names. + lower_col_names = set(column_name.lower() + for column_name in table_columns) + fks = set(fk_col.column for fk_col in foreign_keys[table]) + + for col_name, column in table_columns.items(): + if literal_column_names: + new_name = re.sub('[^\w]+', '_', col_name) + else: + new_name = self.make_column_name(col_name, col_name in fks, + snake_case) + + # If we have two columns, "parent" and "parent_id", ensure + # that when we don't introduce naming conflicts. + lower_name = col_name.lower() + if lower_name.endswith('_id') and new_name in lower_col_names: + new_name = col_name.lower() + + column.name = new_name + + for index in table_indexes: + if len(index.columns) == 1: + column = index.columns[0] + if column in table_columns: + table_columns[column].unique = index.unique + table_columns[column].index = True + + primary_keys[table] = self.metadata.get_primary_keys( + table, self.schema) + columns[table] = table_columns + indexes[table] = table_indexes + + # Gather all instances where we might have a `related_name` conflict, + # either due to multiple FKs on a table pointing to the same table, + # or a related_name that would conflict with an existing field. 
+ related_names = {} + sort_fn = lambda foreign_key: foreign_key.column + for table in tables: + models_referenced = set() + for foreign_key in sorted(foreign_keys[table], key=sort_fn): + try: + column = columns[table][foreign_key.column] + except KeyError: + continue + + dest_table = foreign_key.dest_table + if dest_table in models_referenced: + related_names[column] = '%s_%s_set' % ( + dest_table, + column.name) + else: + models_referenced.add(dest_table) + + # On the second pass convert all foreign keys. + for table in tables: + for foreign_key in foreign_keys[table]: + src = columns[foreign_key.table][foreign_key.column] + try: + dest = columns[foreign_key.dest_table][ + foreign_key.dest_column] + except KeyError: + dest = None + + src.set_foreign_key( + foreign_key=foreign_key, + model_names=model_names, + dest=dest, + related_name=related_names.get(src)) + + return DatabaseMetadata( + columns, + primary_keys, + foreign_keys, + model_names, + indexes) + + def generate_models(self, skip_invalid=False, table_names=None, + literal_column_names=False, bare_fields=False, + include_views=False): + database = self.introspect(table_names, literal_column_names, + include_views) + models = {} + + class BaseModel(Model): + class Meta: + database = self.metadata.database + schema = self.schema + + def _create_model(table, models): + for foreign_key in database.foreign_keys[table]: + dest = foreign_key.dest_table + + if dest not in models and dest != table: + _create_model(dest, models) + + primary_keys = [] + columns = database.columns[table] + for column_name, column in columns.items(): + if column.primary_key: + primary_keys.append(column.name) + + multi_column_indexes = database.multi_column_indexes(table) + column_indexes = database.column_indexes(table) + + class Meta: + indexes = multi_column_indexes + table_name = table + + # Fix models with multi-column primary keys. 
+ composite_key = False + if len(primary_keys) == 0: + primary_keys = columns.keys() + if len(primary_keys) > 1: + Meta.primary_key = CompositeKey(*[ + field.name for col, field in columns.items() + if col in primary_keys]) + composite_key = True + + attrs = {'Meta': Meta} + for column_name, column in columns.items(): + FieldClass = column.field_class + if FieldClass is not ForeignKeyField and bare_fields: + FieldClass = BareField + elif FieldClass is UnknownField: + FieldClass = BareField + + params = { + 'column_name': column_name, + 'null': column.nullable} + if column.primary_key and composite_key: + if FieldClass is AutoField: + FieldClass = IntegerField + params['primary_key'] = False + elif column.primary_key and FieldClass is not AutoField: + params['primary_key'] = True + if column.is_foreign_key(): + if column.is_self_referential_fk(): + params['model'] = 'self' + else: + dest_table = column.foreign_key.dest_table + params['model'] = models[dest_table] + if column.to_field: + params['field'] = column.to_field + + # Generate a unique related name. + params['backref'] = '%s_%s_rel' % (table, column_name) + + if column.default is not None: + constraint = SQL('DEFAULT %s' % column.default) + params['constraints'] = [constraint] + + if column_name in column_indexes and not \ + column.is_primary_key(): + if column_indexes[column_name]: + params['unique'] = True + elif not column.is_foreign_key(): + params['index'] = True + + attrs[column.name] = FieldClass(**params) + + try: + models[table] = type(str(table), (BaseModel,), attrs) + except ValueError: + if not skip_invalid: + raise + + # Actually generate Model classes. 
def print_model(model, indexes=True, inline_indexes=False):
    """Print a human-readable summary of *model*: its fields, keys, and
    (optionally) its indexes."""
    print(model._meta.name)
    for field in model._meta.sorted_fields:
        parts = [' %s %s' % (field.name, field.field_type)]
        if field.primary_key:
            parts.append(' PK')
        elif inline_indexes:
            if field.unique:
                parts.append(' UNIQUE')
            elif field.index:
                parts.append(' INDEX')
        if isinstance(field, ForeignKeyField):
            parts.append(' FK: %s.%s' % (field.rel_model.__name__,
                                         field.rel_field.name))
        print(''.join(parts))

    if not indexes:
        return
    index_list = model._meta.fields_to_index()
    if not index_list:
        return

    print('\nindex(es)')
    for index in index_list:
        parts = [' ']
        ctx = model._meta.database.get_sql_context()
        with ctx.scope_values(param='%s', quote='""'):
            ctx.sql(CommaNodeList(index._expressions))
            if index._where:
                ctx.literal(' WHERE ')
                ctx.sql(index._where)
            sql, params = ctx.query()

        clean = sql % tuple(map(_query_val_transform, params))
        parts.append(clean.replace('"', ''))

        if index._unique:
            parts.append(' UNIQUE')
        print(''.join(parts))
+ match_obj = re.match('^(.+?\()(.+)(\).*)', sql) + create, columns, extra = match_obj.groups() + indented = ',\n'.join(' %s' % column for column in columns.split(', ')) + + clean = '\n'.join((create, indented, extra)).strip() + return clean % tuple(map(_query_val_transform, params)) + +def print_table_sql(model): + print(get_table_sql(model)) diff --git a/python2.7libs/playhouse/shortcuts.py b/python2.7libs/playhouse/shortcuts.py new file mode 100644 index 0000000..1772cf1 --- /dev/null +++ b/python2.7libs/playhouse/shortcuts.py @@ -0,0 +1,228 @@ +from peewee import * +from peewee import Alias +from peewee import SENTINEL +from peewee import callable_ + + +_clone_set = lambda s: set(s) if s else set() + + +def model_to_dict(model, recurse=True, backrefs=False, only=None, + exclude=None, seen=None, extra_attrs=None, + fields_from_query=None, max_depth=None, manytomany=False): + """ + Convert a model instance (and any related objects) to a dictionary. + + :param bool recurse: Whether foreign-keys should be recursed. + :param bool backrefs: Whether lists of related objects should be recursed. + :param only: A list (or set) of field instances indicating which fields + should be included. + :param exclude: A list (or set) of field instances that should be + excluded from the dictionary. + :param list extra_attrs: Names of model instance attributes or methods + that should be included. + :param SelectQuery fields_from_query: Query that was source of model. Take + fields explicitly selected by the query and serialize them. + :param int max_depth: Maximum depth to recurse, value <= 0 means no max. + :param bool manytomany: Process many-to-many fields. 
+ """ + max_depth = -1 if max_depth is None else max_depth + if max_depth == 0: + recurse = False + + only = _clone_set(only) + extra_attrs = _clone_set(extra_attrs) + should_skip = lambda n: (n in exclude) or (only and (n not in only)) + + if fields_from_query is not None: + for item in fields_from_query._returning: + if isinstance(item, Field): + only.add(item) + elif isinstance(item, Alias): + extra_attrs.add(item._alias) + + data = {} + exclude = _clone_set(exclude) + seen = _clone_set(seen) + exclude |= seen + model_class = type(model) + + if manytomany: + for name, m2m in model._meta.manytomany.items(): + if should_skip(name): + continue + + exclude.update((m2m, m2m.rel_model._meta.manytomany[m2m.backref])) + for fkf in m2m.through_model._meta.refs: + exclude.add(fkf) + + accum = [] + for rel_obj in getattr(model, name): + accum.append(model_to_dict( + rel_obj, + recurse=recurse, + backrefs=backrefs, + only=only, + exclude=exclude, + max_depth=max_depth - 1)) + data[name] = accum + + for field in model._meta.sorted_fields: + if should_skip(field): + continue + + field_data = model.__data__.get(field.name) + if isinstance(field, ForeignKeyField) and recurse: + if field_data is not None: + seen.add(field) + rel_obj = getattr(model, field.name) + field_data = model_to_dict( + rel_obj, + recurse=recurse, + backrefs=backrefs, + only=only, + exclude=exclude, + seen=seen, + max_depth=max_depth - 1) + else: + field_data = None + + data[field.name] = field_data + + if extra_attrs: + for attr_name in extra_attrs: + attr = getattr(model, attr_name) + if callable_(attr): + data[attr_name] = attr() + else: + data[attr_name] = attr + + if backrefs and recurse: + for foreign_key, rel_model in model._meta.backrefs.items(): + if foreign_key.backref == '+': continue + descriptor = getattr(model_class, foreign_key.backref) + if descriptor in exclude or foreign_key in exclude: + continue + if only and (descriptor not in only) and (foreign_key not in only): + continue + + accum = 
def update_model_from_dict(instance, data, ignore_unknown=False):
    """Apply the values in *data* to *instance* and return it.

    Nested dicts on foreign-key fields update (or create) the related
    instance recursively; lists/tuples on backrefs build child instances.
    Unknown keys raise AttributeError unless *ignore_unknown* is set, in
    which case they are assigned as plain attributes.
    """
    meta = instance._meta
    backrefs = {fk.backref: fk for fk in meta.backrefs}

    for key, value in data.items():
        is_backref = False
        if key in meta.combined:
            field = meta.combined[key]
        elif key in backrefs:
            field = backrefs[key]
            is_backref = True
        elif ignore_unknown:
            setattr(instance, key, value)
            continue
        else:
            raise AttributeError('Unrecognized attribute "%s" for model '
                                 'class %s.' % (key, type(instance)))

        is_fk = isinstance(field, ForeignKeyField)
        if is_fk and not is_backref and isinstance(value, dict):
            # Nested dict on a FK: recursively update the related instance,
            # creating one if it is not already cached on the model.
            try:
                rel_instance = instance.__rel__[field.name]
            except KeyError:
                rel_instance = field.rel_model()
            updated = update_model_from_dict(rel_instance, value,
                                             ignore_unknown)
            setattr(instance, field.name, updated)
        elif is_backref and isinstance(value, (list, tuple)):
            # List of dicts on a backref: build child rows linked back to
            # this instance.
            children = [dict_to_model(field.model, row, ignore_unknown)
                        for row in value]
            for child in children:
                setattr(child, field.name, instance)
            setattr(instance, field.backref, children)
        else:
            setattr(instance, field.name, value)

    return instance
If your application makes use + of long-lived connections, you may find your connections are closed after + a period of no activity. This mixin will attempt to reconnect automatically + when these errors occur. + + This mixin class probably should not be used with Postgres (unless you + REALLY know what you are doing) and definitely has no business being used + with Sqlite. If you wish to use with Postgres, you will need to adapt the + `reconnect_errors` attribute to something appropriate for Postgres. + """ + reconnect_errors = ( + # Error class, error message fragment (or empty string for all). + (OperationalError, '2006'), # MySQL server has gone away. + (OperationalError, '2013'), # Lost connection to MySQL server. + (OperationalError, '2014'), # Commands out of sync. + + # mysql-connector raises a slightly different error when an idle + # connection is terminated by the server. This is equivalent to 2013. + (OperationalError, 'MySQL Connection not available.'), + ) + + def __init__(self, *args, **kwargs): + super(ReconnectMixin, self).__init__(*args, **kwargs) + + # Normalize the reconnect errors to a more efficient data-structure. 
+ self._reconnect_errors = {} + for exc_class, err_fragment in self.reconnect_errors: + self._reconnect_errors.setdefault(exc_class, []) + self._reconnect_errors[exc_class].append(err_fragment.lower()) + + def execute_sql(self, sql, params=None, commit=SENTINEL): + try: + return super(ReconnectMixin, self).execute_sql(sql, params, commit) + except Exception as exc: + exc_class = type(exc) + if exc_class not in self._reconnect_errors: + raise exc + + exc_repr = str(exc).lower() + for err_fragment in self._reconnect_errors[exc_class]: + if err_fragment in exc_repr: + break + else: + raise exc + + if not self.is_closed(): + self.close() + self.connect() + + return super(ReconnectMixin, self).execute_sql(sql, params, commit) diff --git a/python2.7libs/playhouse/signals.py b/python2.7libs/playhouse/signals.py new file mode 100644 index 0000000..4e92872 --- /dev/null +++ b/python2.7libs/playhouse/signals.py @@ -0,0 +1,79 @@ +""" +Provide django-style hooks for model events. +""" +from peewee import Model as _Model + + +class Signal(object): + def __init__(self): + self._flush() + + def _flush(self): + self._receivers = set() + self._receiver_list = [] + + def connect(self, receiver, name=None, sender=None): + name = name or receiver.__name__ + key = (name, sender) + if key not in self._receivers: + self._receivers.add(key) + self._receiver_list.append((name, receiver, sender)) + else: + raise ValueError('receiver named %s (for sender=%s) already ' + 'connected' % (name, sender or 'any')) + + def disconnect(self, receiver=None, name=None, sender=None): + if receiver: + name = name or receiver.__name__ + if not name: + raise ValueError('a receiver or a name must be provided') + + key = (name, sender) + if key not in self._receivers: + raise ValueError('receiver named %s for sender=%s not found.' 
% + (name, sender or 'any')) + + self._receivers.remove(key) + self._receiver_list = [(n, r, s) for n, r, s in self._receiver_list + if n != name and s != sender] + + def __call__(self, name=None, sender=None): + def decorator(fn): + self.connect(fn, name, sender) + return fn + return decorator + + def send(self, instance, *args, **kwargs): + sender = type(instance) + responses = [] + for n, r, s in self._receiver_list: + if s is None or isinstance(instance, s): + responses.append((r, r(sender, instance, *args, **kwargs))) + return responses + + +pre_save = Signal() +post_save = Signal() +pre_delete = Signal() +post_delete = Signal() +pre_init = Signal() + + +class Model(_Model): + def __init__(self, *args, **kwargs): + super(Model, self).__init__(*args, **kwargs) + pre_init.send(self) + + def save(self, *args, **kwargs): + pk_value = self._pk if self._meta.primary_key else True + created = kwargs.get('force_insert', False) or not bool(pk_value) + pre_save.send(self, created=created) + ret = super(Model, self).save(*args, **kwargs) + post_save.send(self, created=created) + return ret + + def delete_instance(self, *args, **kwargs): + pre_delete.send(self) + ret = super(Model, self).delete_instance(*args, **kwargs) + post_delete.send(self) + return ret diff --git a/python2.7libs/playhouse/sqlcipher_ext.py b/python2.7libs/playhouse/sqlcipher_ext.py new file mode 100644 index 0000000..9bad1ec --- /dev/null +++ b/python2.7libs/playhouse/sqlcipher_ext.py @@ -0,0 +1,103 @@ +""" +Peewee integration with pysqlcipher. + +Project page: https://github.com/leapcode/pysqlcipher/ + +**WARNING!!! EXPERIMENTAL!!!** + +* Although this extention's code is short, it has not been properly + peer-reviewed yet and may have introduced vulnerabilities. + +Also note that this code relies on pysqlcipher and sqlcipher, and +the code there might have vulnerabilities as well, but since these +are widely used crypto modules, we can expect "short zero days" there. 
+ +Example usage: + + from peewee.playground.ciphersql_ext import SqlCipherDatabase + db = SqlCipherDatabase('/path/to/my.db', passphrase="don'tuseme4real") + +* `passphrase`: should be "long enough". + Note that *length beats vocabulary* (much exponential), and even + a lowercase-only passphrase like easytorememberyethardforotherstoguess + packs more noise than 8 random printable characters and *can* be memorized. + +When opening an existing database, passphrase should be the one used when the +database was created. If the passphrase is incorrect, an exception will only be +raised **when you access the database**. + +If you need to ask for an interactive passphrase, here's example code you can +put after the `db = ...` line: + + try: # Just access the database so that it checks the encryption. + db.get_tables() + # We're looking for a DatabaseError with a specific error message. + except peewee.DatabaseError as e: + # Check whether the message *means* "passphrase is wrong" + if e.args[0] == 'file is encrypted or is not a database': + raise Exception('Developer should Prompt user for passphrase ' + 'again.') + else: + # A different DatabaseError. Raise it. 
+ raise e + +See a more elaborate example with this code at +https://gist.github.com/thedod/11048875 +""" +import datetime +import decimal +import sys + +from peewee import * +from playhouse.sqlite_ext import SqliteExtDatabase +if sys.version_info[0] != 3: + from pysqlcipher import dbapi2 as sqlcipher +else: + try: + from sqlcipher3 import dbapi2 as sqlcipher + except ImportError: + from pysqlcipher3 import dbapi2 as sqlcipher + +sqlcipher.register_adapter(decimal.Decimal, str) +sqlcipher.register_adapter(datetime.date, str) +sqlcipher.register_adapter(datetime.time, str) + + +class _SqlCipherDatabase(object): + def _connect(self): + params = dict(self.connect_params) + passphrase = params.pop('passphrase', '').replace("'", "''") + + conn = sqlcipher.connect(self.database, isolation_level=None, **params) + try: + if passphrase: + conn.execute("PRAGMA key='%s'" % passphrase) + self._add_conn_hooks(conn) + except: + conn.close() + raise + return conn + + def set_passphrase(self, passphrase): + if not self.is_closed(): + raise ImproperlyConfigured('Cannot set passphrase when database ' + 'is open. 
To change passphrase of an ' + 'open database use the rekey() method.') + + self.connect_params['passphrase'] = passphrase + + def rekey(self, passphrase): + if self.is_closed(): + self.connect() + + self.execute_sql("PRAGMA rekey='%s'" % passphrase.replace("'", "''")) + self.connect_params['passphrase'] = passphrase + return True + + +class SqlCipherDatabase(_SqlCipherDatabase, SqliteDatabase): + pass + + +class SqlCipherExtDatabase(_SqlCipherDatabase, SqliteExtDatabase): + pass diff --git a/python2.7libs/playhouse/sqlite_changelog.py b/python2.7libs/playhouse/sqlite_changelog.py new file mode 100644 index 0000000..b036af2 --- /dev/null +++ b/python2.7libs/playhouse/sqlite_changelog.py @@ -0,0 +1,123 @@ +from peewee import * +from playhouse.sqlite_ext import JSONField + + +class BaseChangeLog(Model): + timestamp = DateTimeField(constraints=[SQL('DEFAULT CURRENT_TIMESTAMP')]) + action = TextField() + table = TextField() + primary_key = IntegerField() + changes = JSONField() + + +class ChangeLog(object): + # Model class that will serve as the base for the changelog. This model + # will be subclassed and mapped to your application database. + base_model = BaseChangeLog + + # Template for the triggers that handle updating the changelog table. 
+ # table: table name + # action: insert / update / delete + # new_old: NEW or OLD (OLD is for DELETE) + # primary_key: table primary key column name + # column_array: output of build_column_array() + # change_table: changelog table name + template = """CREATE TRIGGER IF NOT EXISTS %(table)s_changes_%(action)s + AFTER %(action)s ON %(table)s + BEGIN + INSERT INTO %(change_table)s + ("action", "table", "primary_key", "changes") + SELECT + '%(action)s', '%(table)s', %(new_old)s."%(primary_key)s", "changes" + FROM ( + SELECT json_group_object( + col, + json_array("oldval", "newval")) AS "changes" + FROM ( + SELECT json_extract(value, '$[0]') as "col", + json_extract(value, '$[1]') as "oldval", + json_extract(value, '$[2]') as "newval" + FROM json_each(json_array(%(column_array)s)) + WHERE "oldval" IS NOT "newval" + ) + ); + END;""" + + drop_template = 'DROP TRIGGER IF EXISTS %(table)s_changes_%(action)s' + + _actions = ('INSERT', 'UPDATE', 'DELETE') + + def __init__(self, db, table_name='changelog'): + self.db = db + self.table_name = table_name + + def _build_column_array(self, model, use_old, use_new, skip_fields=None): + # Builds a list of SQL expressions for each field we are tracking. This + # is used as the data source for change tracking in our trigger. + col_array = [] + for field in model._meta.sorted_fields: + if field.primary_key: + continue + + if skip_fields is not None and field.name in skip_fields: + continue + + column = field.column_name + new = 'NULL' if not use_new else 'NEW."%s"' % column + old = 'NULL' if not use_old else 'OLD."%s"' % column + + if isinstance(field, JSONField): + # Ensure that values are cast to JSON so that the serialization + # is preserved when calculating the old / new. 
+ if use_old: old = 'json(%s)' % old + if use_new: new = 'json(%s)' % new + + col_array.append("json_array('%s', %s, %s)" % (column, old, new)) + + return ', '.join(col_array) + + def trigger_sql(self, model, action, skip_fields=None): + assert action in self._actions + use_old = action != 'INSERT' + use_new = action != 'DELETE' + cols = self._build_column_array(model, use_old, use_new, skip_fields) + return self.template % { + 'table': model._meta.table_name, + 'action': action, + 'new_old': 'NEW' if action != 'DELETE' else 'OLD', + 'primary_key': model._meta.primary_key.column_name, + 'column_array': cols, + 'change_table': self.table_name} + + def drop_trigger_sql(self, model, action): + assert action in self._actions + return self.drop_template % { + 'table': model._meta.table_name, + 'action': action} + + @property + def model(self): + if not hasattr(self, '_changelog_model'): + class ChangeLog(self.base_model): + class Meta: + database = self.db + table_name = self.table_name + self._changelog_model = ChangeLog + + return self._changelog_model + + def install(self, model, skip_fields=None, drop=True, insert=True, + update=True, delete=True, create_table=True): + ChangeLog = self.model + if create_table: + ChangeLog.create_table() + + actions = list(zip((insert, update, delete), self._actions)) + if drop: + for _, action in actions: + self.db.execute_sql(self.drop_trigger_sql(model, action)) + + for enabled, action in actions: + if enabled: + sql = self.trigger_sql(model, action, skip_fields) + self.db.execute_sql(sql) diff --git a/python2.7libs/playhouse/sqlite_ext.py b/python2.7libs/playhouse/sqlite_ext.py new file mode 100644 index 0000000..36d9425 --- /dev/null +++ b/python2.7libs/playhouse/sqlite_ext.py @@ -0,0 +1,1294 @@ +import json +import math +import re +import struct +import sys + +from peewee import * +from peewee import ColumnBase +from peewee import EnclosedNodeList +from peewee import Entity +from peewee import Expression +from peewee import 
Node +from peewee import NodeList +from peewee import OP +from peewee import VirtualField +from peewee import merge_dict +from peewee import sqlite3 +try: + from playhouse._sqlite_ext import ( + backup, + backup_to_file, + Blob, + ConnectionHelper, + register_bloomfilter, + register_hash_functions, + register_rank_functions, + sqlite_get_db_status, + sqlite_get_status, + TableFunction, + ZeroBlob, + ) + CYTHON_SQLITE_EXTENSIONS = True +except ImportError: + CYTHON_SQLITE_EXTENSIONS = False + + +if sys.version_info[0] == 3: + basestring = str + + +FTS3_MATCHINFO = 'pcx' +FTS4_MATCHINFO = 'pcnalx' +if sqlite3 is not None: + FTS_VERSION = 4 if sqlite3.sqlite_version_info[:3] >= (3, 7, 4) else 3 +else: + FTS_VERSION = 3 + +FTS5_MIN_SQLITE_VERSION = (3, 9, 0) + + +class RowIDField(AutoField): + auto_increment = True + column_name = name = required_name = 'rowid' + + def bind(self, model, name, *args): + if name != self.required_name: + raise ValueError('%s must be named "%s".' % + (type(self), self.required_name)) + super(RowIDField, self).bind(model, name, *args) + + +class DocIDField(RowIDField): + column_name = name = required_name = 'docid' + + +class AutoIncrementField(AutoField): + def ddl(self, ctx): + node_list = super(AutoIncrementField, self).ddl(ctx) + return NodeList((node_list, SQL('AUTOINCREMENT'))) + + +class TDecimalField(DecimalField): + field_type = 'TEXT' + def get_modifiers(self): pass + + +class JSONPath(ColumnBase): + def __init__(self, field, path=None): + super(JSONPath, self).__init__() + self._field = field + self._path = path or () + + @property + def path(self): + return Value('$%s' % ''.join(self._path)) + + def __getitem__(self, idx): + if isinstance(idx, int): + item = '[%s]' % idx + else: + item = '.%s' % idx + return JSONPath(self._field, self._path + (item,)) + + def set(self, value, as_json=None): + if as_json or isinstance(value, (list, dict)): + value = fn.json(self._field._json_dumps(value)) + return fn.json_set(self._field, 
self.path, value) + + def update(self, value): + return self.set(fn.json_patch(self, self._field._json_dumps(value))) + + def remove(self): + return fn.json_remove(self._field, self.path) + + def json_type(self): + return fn.json_type(self._field, self.path) + + def length(self): + return fn.json_array_length(self._field, self.path) + + def children(self): + return fn.json_each(self._field, self.path) + + def tree(self): + return fn.json_tree(self._field, self.path) + + def __sql__(self, ctx): + return ctx.sql(fn.json_extract(self._field, self.path) + if self._path else self._field) + + +class JSONField(TextField): + field_type = 'JSON' + unpack = False + + def __init__(self, json_dumps=None, json_loads=None, **kwargs): + self._json_dumps = json_dumps or json.dumps + self._json_loads = json_loads or json.loads + super(JSONField, self).__init__(**kwargs) + + def python_value(self, value): + if value is not None: + try: + return self._json_loads(value) + except (TypeError, ValueError): + return value + + def db_value(self, value): + if value is not None: + if not isinstance(value, Node): + value = fn.json(self._json_dumps(value)) + return value + + def _e(op): + def inner(self, rhs): + if isinstance(rhs, (list, dict)): + rhs = Value(rhs, converter=self.db_value, unpack=False) + return Expression(self, op, rhs) + return inner + __eq__ = _e(OP.EQ) + __ne__ = _e(OP.NE) + __gt__ = _e(OP.GT) + __ge__ = _e(OP.GTE) + __lt__ = _e(OP.LT) + __le__ = _e(OP.LTE) + __hash__ = Field.__hash__ + + def __getitem__(self, item): + return JSONPath(self)[item] + + def set(self, value, as_json=None): + return JSONPath(self).set(value, as_json) + + def update(self, data): + return JSONPath(self).update(data) + + def remove(self): + return JSONPath(self).remove() + + def json_type(self): + return fn.json_type(self) + + def length(self): + return fn.json_array_length(self) + + def children(self): + """ + Schema of `json_each` and `json_tree`: + + key, + value, + type TEXT (object, array, 
string, etc), + atom (value for primitive/scalar types, NULL for array and object) + id INTEGER (unique identifier for element) + parent INTEGER (unique identifier of parent element or NULL) + fullkey TEXT (full path describing element) + path TEXT (path to the container of the current element) + json JSON hidden (1st input parameter to function) + root TEXT hidden (2nd input parameter, path at which to start) + """ + return fn.json_each(self) + + def tree(self): + return fn.json_tree(self) + + +class SearchField(Field): + def __init__(self, unindexed=False, column_name=None, **k): + if k: + raise ValueError('SearchField does not accept these keyword ' + 'arguments: %s.' % sorted(k)) + super(SearchField, self).__init__(unindexed=unindexed, + column_name=column_name, null=True) + + def match(self, term): + return match(self, term) + + +class VirtualTableSchemaManager(SchemaManager): + def _create_virtual_table(self, safe=True, **options): + options = self.model.clean_options( + merge_dict(self.model._meta.options, options)) + + # Structure: + # CREATE VIRTUAL TABLE + # USING + # ([prefix_arguments, ...] fields, ... [arguments, ...], [options...]) + ctx = self._create_context() + ctx.literal('CREATE VIRTUAL TABLE ') + if safe: + ctx.literal('IF NOT EXISTS ') + (ctx + .sql(self.model) + .literal(' USING ')) + + ext_module = self.model._meta.extension_module + if isinstance(ext_module, Node): + return ctx.sql(ext_module) + + ctx.sql(SQL(ext_module)).literal(' ') + arguments = [] + meta = self.model._meta + + if meta.prefix_arguments: + arguments.extend([SQL(a) for a in meta.prefix_arguments]) + + # Constraints, data-types, foreign and primary keys are all omitted. 
+ for field in meta.sorted_fields: + if isinstance(field, (RowIDField)) or field._hidden: + continue + field_def = [Entity(field.column_name)] + if field.unindexed: + field_def.append(SQL('UNINDEXED')) + arguments.append(NodeList(field_def)) + + if meta.arguments: + arguments.extend([SQL(a) for a in meta.arguments]) + + if options: + arguments.extend(self._create_table_option_sql(options)) + return ctx.sql(EnclosedNodeList(arguments)) + + def _create_table(self, safe=True, **options): + if issubclass(self.model, VirtualModel): + return self._create_virtual_table(safe, **options) + + return super(VirtualTableSchemaManager, self)._create_table( + safe, **options) + + +class VirtualModel(Model): + class Meta: + arguments = None + extension_module = None + prefix_arguments = None + primary_key = False + schema_manager_class = VirtualTableSchemaManager + + @classmethod + def clean_options(cls, options): + return options + + +class BaseFTSModel(VirtualModel): + @classmethod + def clean_options(cls, options): + content = options.get('content') + prefix = options.get('prefix') + tokenize = options.get('tokenize') + + if isinstance(content, basestring) and content == '': + # Special-case content-less full-text search tables. + options['content'] = "''" + elif isinstance(content, Field): + # Special-case to ensure fields are fully-qualified. + options['content'] = Entity(content.model._meta.table_name, + content.column_name) + + if prefix: + if isinstance(prefix, (list, tuple)): + prefix = ','.join([str(i) for i in prefix]) + options['prefix'] = "'%s'" % prefix.strip("' ") + + if tokenize and cls._meta.extension_module.lower() == 'fts5': + # Tokenizers need to be in quoted string for FTS5, but not for FTS3 + # or FTS4. + options['tokenize'] = '"%s"' % tokenize + + return options + + +class FTSModel(BaseFTSModel): + """ + VirtualModel class for creating tables that use either the FTS3 or FTS4 + search extensions. 
Peewee automatically determines which version of the + FTS extension is supported and will use FTS4 if possible. + """ + # FTS3/4 uses "docid" in the same way a normal table uses "rowid". + docid = DocIDField() + + class Meta: + extension_module = 'FTS%s' % FTS_VERSION + + @classmethod + def _fts_cmd(cls, cmd): + tbl = cls._meta.table_name + res = cls._meta.database.execute_sql( + "INSERT INTO %s(%s) VALUES('%s');" % (tbl, tbl, cmd)) + return res.fetchone() + + @classmethod + def optimize(cls): + return cls._fts_cmd('optimize') + + @classmethod + def rebuild(cls): + return cls._fts_cmd('rebuild') + + @classmethod + def integrity_check(cls): + return cls._fts_cmd('integrity-check') + + @classmethod + def merge(cls, blocks=200, segments=8): + return cls._fts_cmd('merge=%s,%s' % (blocks, segments)) + + @classmethod + def automerge(cls, state=True): + return cls._fts_cmd('automerge=%s' % (state and '1' or '0')) + + @classmethod + def match(cls, term): + """ + Generate a `MATCH` expression appropriate for searching this table. 
+ """ + return match(cls._meta.entity, term) + + @classmethod + def rank(cls, *weights): + matchinfo = fn.matchinfo(cls._meta.entity, FTS3_MATCHINFO) + return fn.fts_rank(matchinfo, *weights) + + @classmethod + def bm25(cls, *weights): + match_info = fn.matchinfo(cls._meta.entity, FTS4_MATCHINFO) + return fn.fts_bm25(match_info, *weights) + + @classmethod + def bm25f(cls, *weights): + match_info = fn.matchinfo(cls._meta.entity, FTS4_MATCHINFO) + return fn.fts_bm25f(match_info, *weights) + + @classmethod + def lucene(cls, *weights): + match_info = fn.matchinfo(cls._meta.entity, FTS4_MATCHINFO) + return fn.fts_lucene(match_info, *weights) + + @classmethod + def _search(cls, term, weights, with_score, score_alias, score_fn, + explicit_ordering): + if not weights: + rank = score_fn() + elif isinstance(weights, dict): + weight_args = [] + for field in cls._meta.sorted_fields: + # Attempt to get the specified weight of the field by looking + # it up using it's field instance followed by name. 
+ field_weight = weights.get(field, weights.get(field.name, 1.0)) + weight_args.append(field_weight) + rank = score_fn(*weight_args) + else: + rank = score_fn(*weights) + + selection = () + order_by = rank + if with_score: + selection = (cls, rank.alias(score_alias)) + if with_score and not explicit_ordering: + order_by = SQL(score_alias) + + return (cls + .select(*selection) + .where(cls.match(term)) + .order_by(order_by)) + + @classmethod + def search(cls, term, weights=None, with_score=False, score_alias='score', + explicit_ordering=False): + """Full-text search using selected `term`.""" + return cls._search( + term, + weights, + with_score, + score_alias, + cls.rank, + explicit_ordering) + + @classmethod + def search_bm25(cls, term, weights=None, with_score=False, + score_alias='score', explicit_ordering=False): + """Full-text search for selected `term` using BM25 algorithm.""" + return cls._search( + term, + weights, + with_score, + score_alias, + cls.bm25, + explicit_ordering) + + @classmethod + def search_bm25f(cls, term, weights=None, with_score=False, + score_alias='score', explicit_ordering=False): + """Full-text search for selected `term` using BM25 algorithm.""" + return cls._search( + term, + weights, + with_score, + score_alias, + cls.bm25f, + explicit_ordering) + + @classmethod + def search_lucene(cls, term, weights=None, with_score=False, + score_alias='score', explicit_ordering=False): + """Full-text search for selected `term` using BM25 algorithm.""" + return cls._search( + term, + weights, + with_score, + score_alias, + cls.lucene, + explicit_ordering) + + +_alphabet = 'abcdefghijklmnopqrstuvwxyz' +_alphanum = (set('\t ,"(){}*:_+0123456789') | + set(_alphabet) | + set(_alphabet.upper()) | + set((chr(26),))) +_invalid_ascii = set(chr(p) for p in range(128) if chr(p) not in _alphanum) +_quote_re = re.compile('(?:[^\s"]|"(?:\\.|[^"])*")+') + + +class FTS5Model(BaseFTSModel): + """ + Requires SQLite >= 3.9.0. 
+ + Table options: + + content: table name of external content, or empty string for "contentless" + content_rowid: column name of external content primary key + prefix: integer(s). Ex: '2' or '2 3 4' + tokenize: porter, unicode61, ascii. Ex: 'porter unicode61' + + The unicode tokenizer supports the following parameters: + + * remove_diacritics (1 or 0, default is 1) + * tokenchars (string of characters, e.g. '-_' + * separators (string of characters) + + Parameters are passed as alternating parameter name and value, so: + + {'tokenize': "unicode61 remove_diacritics 0 tokenchars '-_'"} + + Content-less tables: + + If you don't need the full-text content in it's original form, you can + specify a content-less table. Searches and auxiliary functions will work + as usual, but the only values returned when SELECT-ing can be rowid. Also + content-less tables do not support UPDATE or DELETE. + + External content tables: + + You can set up triggers to sync these, e.g. + + -- Create a table. And an external content fts5 table to index it. + CREATE TABLE tbl(a INTEGER PRIMARY KEY, b); + CREATE VIRTUAL TABLE ft USING fts5(b, content='tbl', content_rowid='a'); + + -- Triggers to keep the FTS index up to date. + CREATE TRIGGER tbl_ai AFTER INSERT ON tbl BEGIN + INSERT INTO ft(rowid, b) VALUES (new.a, new.b); + END; + CREATE TRIGGER tbl_ad AFTER DELETE ON tbl BEGIN + INSERT INTO ft(fts_idx, rowid, b) VALUES('delete', old.a, old.b); + END; + CREATE TRIGGER tbl_au AFTER UPDATE ON tbl BEGIN + INSERT INTO ft(fts_idx, rowid, b) VALUES('delete', old.a, old.b); + INSERT INTO ft(rowid, b) VALUES (new.a, new.b); + END; + + Built-in auxiliary functions: + + * bm25(tbl[, weight_0, ... weight_n]) + * highlight(tbl, col_idx, prefix, suffix) + * snippet(tbl, col_idx, prefix, suffix, ?, max_tokens) + """ + # FTS5 does not support declared primary keys, but we can use the + # implicit rowid. 
+ rowid = RowIDField() + + class Meta: + extension_module = 'fts5' + + _error_messages = { + 'field_type': ('Besides the implicit `rowid` column, all columns must ' + 'be instances of SearchField'), + 'index': 'Secondary indexes are not supported for FTS5 models', + 'pk': 'FTS5 models must use the default `rowid` primary key', + } + + @classmethod + def validate_model(cls): + # Perform FTS5-specific validation and options post-processing. + if cls._meta.primary_key.name != 'rowid': + raise ImproperlyConfigured(cls._error_messages['pk']) + for field in cls._meta.fields.values(): + if not isinstance(field, (SearchField, RowIDField)): + raise ImproperlyConfigured(cls._error_messages['field_type']) + if cls._meta.indexes: + raise ImproperlyConfigured(cls._error_messages['index']) + + @classmethod + def fts5_installed(cls): + if sqlite3.sqlite_version_info[:3] < FTS5_MIN_SQLITE_VERSION: + return False + + # Test in-memory DB to determine if the FTS5 extension is installed. + tmp_db = sqlite3.connect(':memory:') + try: + tmp_db.execute('CREATE VIRTUAL TABLE fts5test USING fts5 (data);') + except: + try: + tmp_db.enable_load_extension(True) + tmp_db.load_extension('fts5') + except: + return False + else: + cls._meta.database.load_extension('fts5') + finally: + tmp_db.close() + + return True + + @staticmethod + def validate_query(query): + """ + Simple helper function to indicate whether a search query is a + valid FTS5 query. Note: this simply looks at the characters being + used, and is not guaranteed to catch all problematic queries. + """ + tokens = _quote_re.findall(query) + for token in tokens: + if token.startswith('"') and token.endswith('"'): + continue + if set(token) & _invalid_ascii: + return False + return True + + @staticmethod + def clean_query(query, replace=chr(26)): + """ + Clean a query of invalid tokens. 
+ """ + accum = [] + any_invalid = False + tokens = _quote_re.findall(query) + for token in tokens: + if token.startswith('"') and token.endswith('"'): + accum.append(token) + continue + token_set = set(token) + invalid_for_token = token_set & _invalid_ascii + if invalid_for_token: + any_invalid = True + for c in invalid_for_token: + token = token.replace(c, replace) + accum.append(token) + + if any_invalid: + return ' '.join(accum) + return query + + @classmethod + def match(cls, term): + """ + Generate a `MATCH` expression appropriate for searching this table. + """ + return match(cls._meta.entity, term) + + @classmethod + def rank(cls, *args): + return cls.bm25(*args) if args else SQL('rank') + + @classmethod + def bm25(cls, *weights): + return fn.bm25(cls._meta.entity, *weights) + + @classmethod + def search(cls, term, weights=None, with_score=False, score_alias='score', + explicit_ordering=False): + """Full-text search using selected `term`.""" + return cls.search_bm25( + FTS5Model.clean_query(term), + weights, + with_score, + score_alias, + explicit_ordering) + + @classmethod + def search_bm25(cls, term, weights=None, with_score=False, + score_alias='score', explicit_ordering=False): + """Full-text search using selected `term`.""" + if not weights: + rank = SQL('rank') + elif isinstance(weights, dict): + weight_args = [] + for field in cls._meta.sorted_fields: + if isinstance(field, SearchField) and not field.unindexed: + weight_args.append( + weights.get(field, weights.get(field.name, 1.0))) + rank = fn.bm25(cls._meta.entity, *weight_args) + else: + rank = fn.bm25(cls._meta.entity, *weights) + + selection = () + order_by = rank + if with_score: + selection = (cls, rank.alias(score_alias)) + if with_score and not explicit_ordering: + order_by = SQL(score_alias) + + return (cls + .select(*selection) + .where(cls.match(FTS5Model.clean_query(term))) + .order_by(order_by)) + + @classmethod + def _fts_cmd_sql(cls, cmd, **extra_params): + tbl = cls._meta.entity + 
columns = [tbl] + values = [cmd] + for key, value in extra_params.items(): + columns.append(Entity(key)) + values.append(value) + + return NodeList(( + SQL('INSERT INTO'), + cls._meta.entity, + EnclosedNodeList(columns), + SQL('VALUES'), + EnclosedNodeList(values))) + + @classmethod + def _fts_cmd(cls, cmd, **extra_params): + query = cls._fts_cmd_sql(cmd, **extra_params) + return cls._meta.database.execute(query) + + @classmethod + def automerge(cls, level): + if not (0 <= level <= 16): + raise ValueError('level must be between 0 and 16') + return cls._fts_cmd('automerge', rank=level) + + @classmethod + def merge(cls, npages): + return cls._fts_cmd('merge', rank=npages) + + @classmethod + def set_pgsz(cls, pgsz): + return cls._fts_cmd('pgsz', rank=pgsz) + + @classmethod + def set_rank(cls, rank_expression): + return cls._fts_cmd('rank', rank=rank_expression) + + @classmethod + def delete_all(cls): + return cls._fts_cmd('delete-all') + + @classmethod + def VocabModel(cls, table_type='row', table=None): + if table_type not in ('row', 'col', 'instance'): + raise ValueError('table_type must be either "row", "col" or ' + '"instance".') + + attr = '_vocab_model_%s' % table_type + + if not hasattr(cls, attr): + class Meta: + database = cls._meta.database + table_name = table or cls._meta.table_name + '_v' + extension_module = fn.fts5vocab( + cls._meta.entity, + SQL(table_type)) + + attrs = { + 'term': VirtualField(TextField), + 'doc': IntegerField(), + 'cnt': IntegerField(), + 'rowid': RowIDField(), + 'Meta': Meta, + } + if table_type == 'col': + attrs['col'] = VirtualField(TextField) + elif table_type == 'instance': + attrs['offset'] = VirtualField(IntegerField) + + class_name = '%sVocab' % cls.__name__ + setattr(cls, attr, type(class_name, (VirtualModel,), attrs)) + + return getattr(cls, attr) + + +def ClosureTable(model_class, foreign_key=None, referencing_class=None, + referencing_key=None): + """Model factory for the transitive closure extension.""" + if 
referencing_class is None: + referencing_class = model_class + + if foreign_key is None: + for field_obj in model_class._meta.refs: + if field_obj.rel_model is model_class: + foreign_key = field_obj + break + else: + raise ValueError('Unable to find self-referential foreign key.') + + source_key = model_class._meta.primary_key + if referencing_key is None: + referencing_key = source_key + + class BaseClosureTable(VirtualModel): + depth = VirtualField(IntegerField) + id = VirtualField(IntegerField) + idcolumn = VirtualField(TextField) + parentcolumn = VirtualField(TextField) + root = VirtualField(IntegerField) + tablename = VirtualField(TextField) + + class Meta: + extension_module = 'transitive_closure' + + @classmethod + def descendants(cls, node, depth=None, include_node=False): + query = (model_class + .select(model_class, cls.depth.alias('depth')) + .join(cls, on=(source_key == cls.id)) + .where(cls.root == node) + .objects()) + if depth is not None: + query = query.where(cls.depth == depth) + elif not include_node: + query = query.where(cls.depth > 0) + return query + + @classmethod + def ancestors(cls, node, depth=None, include_node=False): + query = (model_class + .select(model_class, cls.depth.alias('depth')) + .join(cls, on=(source_key == cls.root)) + .where(cls.id == node) + .objects()) + if depth: + query = query.where(cls.depth == depth) + elif not include_node: + query = query.where(cls.depth > 0) + return query + + @classmethod + def siblings(cls, node, include_node=False): + if referencing_class is model_class: + # self-join + fk_value = node.__data__.get(foreign_key.name) + query = model_class.select().where(foreign_key == fk_value) + else: + # siblings as given in reference_class + siblings = (referencing_class + .select(referencing_key) + .join(cls, on=(foreign_key == cls.root)) + .where((cls.id == node) & (cls.depth == 1))) + + # the according models + query = (model_class + .select() + .where(source_key << siblings) + .objects()) + + if not 
include_node: + query = query.where(source_key != node) + + return query + + class Meta: + database = referencing_class._meta.database + options = { + 'tablename': referencing_class._meta.table_name, + 'idcolumn': referencing_key.column_name, + 'parentcolumn': foreign_key.column_name} + primary_key = False + + name = '%sClosure' % model_class.__name__ + return type(name, (BaseClosureTable,), {'Meta': Meta}) + + +class LSMTable(VirtualModel): + class Meta: + extension_module = 'lsm1' + filename = None + + @classmethod + def clean_options(cls, options): + filename = cls._meta.filename + if not filename: + raise ValueError('LSM1 extension requires that you specify a ' + 'filename for the LSM database.') + else: + if len(filename) >= 2 and filename[0] != '"': + filename = '"%s"' % filename + if not cls._meta.primary_key: + raise ValueError('LSM1 models must specify a primary-key field.') + + key = cls._meta.primary_key + if isinstance(key, AutoField): + raise ValueError('LSM1 models must explicitly declare a primary ' + 'key field.') + if not isinstance(key, (TextField, BlobField, IntegerField)): + raise ValueError('LSM1 key must be a TextField, BlobField, or ' + 'IntegerField.') + key._hidden = True + if isinstance(key, IntegerField): + data_type = 'UINT' + elif isinstance(key, BlobField): + data_type = 'BLOB' + else: + data_type = 'TEXT' + cls._meta.prefix_arguments = [filename, '"%s"' % key.name, data_type] + + # Does the key map to a scalar value, or a tuple of values? 
+ if len(cls._meta.sorted_fields) == 2: + cls._meta._value_field = cls._meta.sorted_fields[1] + else: + cls._meta._value_field = None + + return options + + @classmethod + def load_extension(cls, path='lsm.so'): + cls._meta.database.load_extension(path) + + @staticmethod + def slice_to_expr(key, idx): + if idx.start is not None and idx.stop is not None: + return key.between(idx.start, idx.stop) + elif idx.start is not None: + return key >= idx.start + elif idx.stop is not None: + return key <= idx.stop + + @staticmethod + def _apply_lookup_to_query(query, key, lookup): + if isinstance(lookup, slice): + expr = LSMTable.slice_to_expr(key, lookup) + if expr is not None: + query = query.where(expr) + return query, False + elif isinstance(lookup, Expression): + return query.where(lookup), False + else: + return query.where(key == lookup), True + + @classmethod + def get_by_id(cls, pk): + query, is_single = cls._apply_lookup_to_query( + cls.select().namedtuples(), + cls._meta.primary_key, + pk) + + if is_single: + try: + row = query.get() + except cls.DoesNotExist: + raise KeyError(pk) + return row[1] if cls._meta._value_field is not None else row + else: + return query + + @classmethod + def set_by_id(cls, key, value): + if cls._meta._value_field is not None: + data = {cls._meta._value_field: value} + elif isinstance(value, tuple): + data = {} + for field, fval in zip(cls._meta.sorted_fields[1:], value): + data[field] = fval + elif isinstance(value, dict): + data = value + elif isinstance(value, cls): + data = value.__dict__ + data[cls._meta.primary_key] = key + cls.replace(data).execute() + + @classmethod + def delete_by_id(cls, pk): + query, is_single = cls._apply_lookup_to_query( + cls.delete(), + cls._meta.primary_key, + pk) + return query.execute() + + +OP.MATCH = 'MATCH' + +def _sqlite_regexp(regex, value): + return re.search(regex, value) is not None + + +class SqliteExtDatabase(SqliteDatabase): + def __init__(self, database, c_extensions=None, 
rank_functions=True, + hash_functions=False, regexp_function=False, + bloomfilter=False, json_contains=False, *args, **kwargs): + super(SqliteExtDatabase, self).__init__(database, *args, **kwargs) + self._row_factory = None + + if c_extensions and not CYTHON_SQLITE_EXTENSIONS: + raise ImproperlyConfigured('SqliteExtDatabase initialized with ' + 'C extensions, but shared library was ' + 'not found!') + prefer_c = CYTHON_SQLITE_EXTENSIONS and (c_extensions is not False) + if rank_functions: + if prefer_c: + register_rank_functions(self) + else: + self.register_function(bm25, 'fts_bm25') + self.register_function(rank, 'fts_rank') + self.register_function(bm25, 'fts_bm25f') # Fall back to bm25. + self.register_function(bm25, 'fts_lucene') + if hash_functions: + if not prefer_c: + raise ValueError('C extension required to register hash ' + 'functions.') + register_hash_functions(self) + if regexp_function: + self.register_function(_sqlite_regexp, 'regexp', 2) + if bloomfilter: + if not prefer_c: + raise ValueError('C extension required to use bloomfilter.') + register_bloomfilter(self) + if json_contains: + self.register_function(_json_contains, 'json_contains') + + self._c_extensions = prefer_c + + def _add_conn_hooks(self, conn): + super(SqliteExtDatabase, self)._add_conn_hooks(conn) + if self._row_factory: + conn.row_factory = self._row_factory + + def row_factory(self, fn): + self._row_factory = fn + + +if CYTHON_SQLITE_EXTENSIONS: + SQLITE_STATUS_MEMORY_USED = 0 + SQLITE_STATUS_PAGECACHE_USED = 1 + SQLITE_STATUS_PAGECACHE_OVERFLOW = 2 + SQLITE_STATUS_SCRATCH_USED = 3 + SQLITE_STATUS_SCRATCH_OVERFLOW = 4 + SQLITE_STATUS_MALLOC_SIZE = 5 + SQLITE_STATUS_PARSER_STACK = 6 + SQLITE_STATUS_PAGECACHE_SIZE = 7 + SQLITE_STATUS_SCRATCH_SIZE = 8 + SQLITE_STATUS_MALLOC_COUNT = 9 + SQLITE_DBSTATUS_LOOKASIDE_USED = 0 + SQLITE_DBSTATUS_CACHE_USED = 1 + SQLITE_DBSTATUS_SCHEMA_USED = 2 + SQLITE_DBSTATUS_STMT_USED = 3 + SQLITE_DBSTATUS_LOOKASIDE_HIT = 4 + 
SQLITE_DBSTATUS_LOOKASIDE_MISS_SIZE = 5 + SQLITE_DBSTATUS_LOOKASIDE_MISS_FULL = 6 + SQLITE_DBSTATUS_CACHE_HIT = 7 + SQLITE_DBSTATUS_CACHE_MISS = 8 + SQLITE_DBSTATUS_CACHE_WRITE = 9 + SQLITE_DBSTATUS_DEFERRED_FKS = 10 + #SQLITE_DBSTATUS_CACHE_USED_SHARED = 11 + + def __status__(flag, return_highwater=False): + """ + Expose a sqlite3_status() call for a particular flag as a property of + the Database object. + """ + def getter(self): + result = sqlite_get_status(flag) + return result[1] if return_highwater else result + return property(getter) + + def __dbstatus__(flag, return_highwater=False, return_current=False): + """ + Expose a sqlite3_dbstatus() call for a particular flag as a property of + the Database instance. Unlike sqlite3_status(), the dbstatus properties + pertain to the current connection. + """ + def getter(self): + if self._state.conn is None: + raise ImproperlyConfigured('database connection not opened.') + result = sqlite_get_db_status(self._state.conn, flag) + if return_current: + return result[0] + return result[1] if return_highwater else result + return property(getter) + + class CSqliteExtDatabase(SqliteExtDatabase): + def __init__(self, *args, **kwargs): + self._conn_helper = None + self._commit_hook = self._rollback_hook = self._update_hook = None + self._replace_busy_handler = False + super(CSqliteExtDatabase, self).__init__(*args, **kwargs) + + def init(self, database, replace_busy_handler=False, **kwargs): + super(CSqliteExtDatabase, self).init(database, **kwargs) + self._replace_busy_handler = replace_busy_handler + + def _close(self, conn): + if self._commit_hook: + self._conn_helper.set_commit_hook(None) + if self._rollback_hook: + self._conn_helper.set_rollback_hook(None) + if self._update_hook: + self._conn_helper.set_update_hook(None) + return super(CSqliteExtDatabase, self)._close(conn) + + def _add_conn_hooks(self, conn): + super(CSqliteExtDatabase, self)._add_conn_hooks(conn) + self._conn_helper = ConnectionHelper(conn) + if 
self._commit_hook is not None: + self._conn_helper.set_commit_hook(self._commit_hook) + if self._rollback_hook is not None: + self._conn_helper.set_rollback_hook(self._rollback_hook) + if self._update_hook is not None: + self._conn_helper.set_update_hook(self._update_hook) + if self._replace_busy_handler: + timeout = self._timeout or 5 + self._conn_helper.set_busy_handler(timeout * 1000) + + def on_commit(self, fn): + self._commit_hook = fn + if not self.is_closed(): + self._conn_helper.set_commit_hook(fn) + return fn + + def on_rollback(self, fn): + self._rollback_hook = fn + if not self.is_closed(): + self._conn_helper.set_rollback_hook(fn) + return fn + + def on_update(self, fn): + self._update_hook = fn + if not self.is_closed(): + self._conn_helper.set_update_hook(fn) + return fn + + def changes(self): + return self._conn_helper.changes() + + @property + def last_insert_rowid(self): + return self._conn_helper.last_insert_rowid() + + @property + def autocommit(self): + return self._conn_helper.autocommit() + + def backup(self, destination, pages=None, name=None, progress=None): + return backup(self.connection(), destination.connection(), + pages=pages, name=name, progress=progress) + + def backup_to_file(self, filename, pages=None, name=None, + progress=None): + return backup_to_file(self.connection(), filename, pages=pages, + name=name, progress=progress) + + def blob_open(self, table, column, rowid, read_only=False): + return Blob(self, table, column, rowid, read_only) + + # Status properties. 
+ memory_used = __status__(SQLITE_STATUS_MEMORY_USED) + malloc_size = __status__(SQLITE_STATUS_MALLOC_SIZE, True) + malloc_count = __status__(SQLITE_STATUS_MALLOC_COUNT) + pagecache_used = __status__(SQLITE_STATUS_PAGECACHE_USED) + pagecache_overflow = __status__(SQLITE_STATUS_PAGECACHE_OVERFLOW) + pagecache_size = __status__(SQLITE_STATUS_PAGECACHE_SIZE, True) + scratch_used = __status__(SQLITE_STATUS_SCRATCH_USED) + scratch_overflow = __status__(SQLITE_STATUS_SCRATCH_OVERFLOW) + scratch_size = __status__(SQLITE_STATUS_SCRATCH_SIZE, True) + + # Connection status properties. + lookaside_used = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_USED) + lookaside_hit = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_HIT, True) + lookaside_miss = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_MISS_SIZE, + True) + lookaside_miss_full = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_MISS_FULL, + True) + cache_used = __dbstatus__(SQLITE_DBSTATUS_CACHE_USED, False, True) + #cache_used_shared = __dbstatus__(SQLITE_DBSTATUS_CACHE_USED_SHARED, + # False, True) + schema_used = __dbstatus__(SQLITE_DBSTATUS_SCHEMA_USED, False, True) + statement_used = __dbstatus__(SQLITE_DBSTATUS_STMT_USED, False, True) + cache_hit = __dbstatus__(SQLITE_DBSTATUS_CACHE_HIT, False, True) + cache_miss = __dbstatus__(SQLITE_DBSTATUS_CACHE_MISS, False, True) + cache_write = __dbstatus__(SQLITE_DBSTATUS_CACHE_WRITE, False, True) + + +def match(lhs, rhs): + return Expression(lhs, OP.MATCH, rhs) + +def _parse_match_info(buf): + # See http://sqlite.org/fts3.html#matchinfo + bufsize = len(buf) # Length in bytes. + return [struct.unpack('@I', buf[i:i+4])[0] for i in range(0, bufsize, 4)] + +def get_weights(ncol, raw_weights): + if not raw_weights: + return [1] * ncol + else: + weights = [0] * ncol + for i, weight in enumerate(raw_weights): + weights[i] = weight + return weights + +# Ranking implementation, which parse matchinfo. 
+def rank(raw_match_info, *raw_weights): + # Handle match_info called w/default args 'pcx' - based on the example rank + # function http://sqlite.org/fts3.html#appendix_a + match_info = _parse_match_info(raw_match_info) + score = 0.0 + + p, c = match_info[:2] + weights = get_weights(c, raw_weights) + + # matchinfo X value corresponds to, for each phrase in the search query, a + # list of 3 values for each column in the search table. + # So if we have a two-phrase search query and three columns of data, the + # following would be the layout: + # p0 : c0=[0, 1, 2], c1=[3, 4, 5], c2=[6, 7, 8] + # p1 : c0=[9, 10, 11], c1=[12, 13, 14], c2=[15, 16, 17] + for phrase_num in range(p): + phrase_info_idx = 2 + (phrase_num * c * 3) + for col_num in range(c): + weight = weights[col_num] + if not weight: + continue + + col_idx = phrase_info_idx + (col_num * 3) + + # The idea is that we count the number of times the phrase appears + # in this column of the current row, compared to how many times it + # appears in this column across all rows. The ratio of these values + # provides a rough way to score based on "high value" terms. + row_hits = match_info[col_idx] + all_rows_hits = match_info[col_idx + 1] + if row_hits > 0: + score += weight * (float(row_hits) / all_rows_hits) + + return -score + +# Okapi BM25 ranking implementation (FTS4 only). +def bm25(raw_match_info, *args): + """ + Usage: + + # Format string *must* be pcnalx + # Second parameter to bm25 specifies the index of the column, on + # the table being queries. + bm25(matchinfo(document_tbl, 'pcnalx'), 1) AS rank + """ + match_info = _parse_match_info(raw_match_info) + K = 1.2 + B = 0.75 + score = 0.0 + + P_O, C_O, N_O, A_O = range(4) # Offsets into the matchinfo buffer. + term_count = match_info[P_O] # n + col_count = match_info[C_O] + total_docs = match_info[N_O] # N + L_O = A_O + col_count + X_O = L_O + col_count + + # Worked example of pcnalx for two columns and two phrases, 100 docs total. 
+ # { + # p = 2 + # c = 2 + # n = 100 + # a0 = 4 -- avg number of tokens for col0, e.g. title + # a1 = 40 -- avg number of tokens for col1, e.g. body + # l0 = 5 -- curr doc has 5 tokens in col0 + # l1 = 30 -- curr doc has 30 tokens in col1 + # + # x000 -- hits this row for phrase0, col0 + # x001 -- hits all rows for phrase0, col0 + # x002 -- rows with phrase0 in col0 at least once + # + # x010 -- hits this row for phrase0, col1 + # x011 -- hits all rows for phrase0, col1 + # x012 -- rows with phrase0 in col1 at least once + # + # x100 -- hits this row for phrase1, col0 + # x101 -- hits all rows for phrase1, col0 + # x102 -- rows with phrase1 in col0 at least once + # + # x110 -- hits this row for phrase1, col1 + # x111 -- hits all rows for phrase1, col1 + # x112 -- rows with phrase1 in col1 at least once + # } + + weights = get_weights(col_count, args) + + for i in range(term_count): + for j in range(col_count): + weight = weights[j] + if weight == 0: + continue + + x = X_O + (3 * (j + i * col_count)) + term_frequency = float(match_info[x]) # f(qi, D) + docs_with_term = float(match_info[x + 2]) # n(qi) + + # log( (N - n(qi) + 0.5) / (n(qi) + 0.5) ) + idf = math.log( + (total_docs - docs_with_term + 0.5) / + (docs_with_term + 0.5)) + if idf <= 0.0: + idf = 1e-6 + + doc_length = float(match_info[L_O + j]) # |D| + avg_length = float(match_info[A_O + j]) or 1. # avgdl + ratio = doc_length / avg_length + + num = term_frequency * (K + 1.0) + b_part = 1.0 - B + (B * ratio) + denom = term_frequency + (K * b_part) + + pc_score = idf * (num / denom) + score += (pc_score * weight) + + return -score + + +def _json_contains(src_json, obj_json): + stack = [] + try: + stack.append((json.loads(obj_json), json.loads(src_json))) + except: + # Invalid JSON! 
+ return False + + while stack: + obj, src = stack.pop() + if isinstance(src, dict): + if isinstance(obj, dict): + for key in obj: + if key not in src: + return False + stack.append((obj[key], src[key])) + elif isinstance(obj, list): + for item in obj: + if item not in src: + return False + elif obj not in src: + return False + elif isinstance(src, list): + if isinstance(obj, dict): + return False + elif isinstance(obj, list): + try: + for i in range(len(obj)): + stack.append((obj[i], src[i])) + except IndexError: + return False + elif obj not in src: + return False + elif obj != src: + return False + return True diff --git a/python2.7libs/playhouse/sqlite_udf.py b/python2.7libs/playhouse/sqlite_udf.py new file mode 100644 index 0000000..28dbd85 --- /dev/null +++ b/python2.7libs/playhouse/sqlite_udf.py @@ -0,0 +1,522 @@ +import datetime +import hashlib +import heapq +import math +import os +import random +import re +import sys +import threading +import zlib +try: + from collections import Counter +except ImportError: + Counter = None +try: + from urlparse import urlparse +except ImportError: + from urllib.parse import urlparse + +try: + from playhouse._sqlite_ext import TableFunction +except ImportError: + TableFunction = None + + +SQLITE_DATETIME_FORMATS = ( + '%Y-%m-%d %H:%M:%S', + '%Y-%m-%d %H:%M:%S.%f', + '%Y-%m-%d', + '%H:%M:%S', + '%H:%M:%S.%f', + '%H:%M') + +from peewee import format_date_time + +def format_date_time_sqlite(date_value): + return format_date_time(date_value, SQLITE_DATETIME_FORMATS) + +try: + from playhouse import _sqlite_udf as cython_udf +except ImportError: + cython_udf = None + + +# Group udf by function. 
+CONTROL_FLOW = 'control_flow' +DATE = 'date' +FILE = 'file' +HELPER = 'helpers' +MATH = 'math' +STRING = 'string' + +AGGREGATE_COLLECTION = {} +TABLE_FUNCTION_COLLECTION = {} +UDF_COLLECTION = {} + + +class synchronized_dict(dict): + def __init__(self, *args, **kwargs): + super(synchronized_dict, self).__init__(*args, **kwargs) + self._lock = threading.Lock() + + def __getitem__(self, key): + with self._lock: + return super(synchronized_dict, self).__getitem__(key) + + def __setitem__(self, key, value): + with self._lock: + return super(synchronized_dict, self).__setitem__(key, value) + + def __delitem__(self, key): + with self._lock: + return super(synchronized_dict, self).__delitem__(key) + + +STATE = synchronized_dict() +SETTINGS = synchronized_dict() + +# Class and function decorators. +def aggregate(*groups): + def decorator(klass): + for group in groups: + AGGREGATE_COLLECTION.setdefault(group, []) + AGGREGATE_COLLECTION[group].append(klass) + return klass + return decorator + +def table_function(*groups): + def decorator(klass): + for group in groups: + TABLE_FUNCTION_COLLECTION.setdefault(group, []) + TABLE_FUNCTION_COLLECTION[group].append(klass) + return klass + return decorator + +def udf(*groups): + def decorator(fn): + for group in groups: + UDF_COLLECTION.setdefault(group, []) + UDF_COLLECTION[group].append(fn) + return fn + return decorator + +# Register aggregates / functions with connection. 
+def register_aggregate_groups(db, *groups): + seen = set() + for group in groups: + klasses = AGGREGATE_COLLECTION.get(group, ()) + for klass in klasses: + name = getattr(klass, 'name', klass.__name__) + if name not in seen: + seen.add(name) + db.register_aggregate(klass, name) + +def register_table_function_groups(db, *groups): + seen = set() + for group in groups: + klasses = TABLE_FUNCTION_COLLECTION.get(group, ()) + for klass in klasses: + if klass.name not in seen: + seen.add(klass.name) + db.register_table_function(klass) + +def register_udf_groups(db, *groups): + seen = set() + for group in groups: + functions = UDF_COLLECTION.get(group, ()) + for function in functions: + name = function.__name__ + if name not in seen: + seen.add(name) + db.register_function(function, name) + +def register_groups(db, *groups): + register_aggregate_groups(db, *groups) + register_table_function_groups(db, *groups) + register_udf_groups(db, *groups) + +def register_all(db): + register_aggregate_groups(db, *AGGREGATE_COLLECTION) + register_table_function_groups(db, *TABLE_FUNCTION_COLLECTION) + register_udf_groups(db, *UDF_COLLECTION) + + +# Begin actual user-defined functions and aggregates. + +# Scalar functions. 
+@udf(CONTROL_FLOW) +def if_then_else(cond, truthy, falsey=None): + if cond: + return truthy + return falsey + +@udf(DATE) +def strip_tz(date_str): + date_str = date_str.replace('T', ' ') + tz_idx1 = date_str.find('+') + if tz_idx1 != -1: + return date_str[:tz_idx1] + tz_idx2 = date_str.find('-') + if tz_idx2 > 13: + return date_str[:tz_idx2] + return date_str + +@udf(DATE) +def human_delta(nseconds, glue=', '): + parts = ( + (86400 * 365, 'year'), + (86400 * 30, 'month'), + (86400 * 7, 'week'), + (86400, 'day'), + (3600, 'hour'), + (60, 'minute'), + (1, 'second'), + ) + accum = [] + for offset, name in parts: + val, nseconds = divmod(nseconds, offset) + if val: + suffix = val != 1 and 's' or '' + accum.append('%s %s%s' % (val, name, suffix)) + if not accum: + return '0 seconds' + return glue.join(accum) + +@udf(FILE) +def file_ext(filename): + try: + res = os.path.splitext(filename) + except ValueError: + return None + return res[1] + +@udf(FILE) +def file_read(filename): + try: + with open(filename) as fh: + return fh.read() + except: + pass + +if sys.version_info[0] == 2: + @udf(HELPER) + def gzip(data, compression=9): + return buffer(zlib.compress(data, compression)) + + @udf(HELPER) + def gunzip(data): + return zlib.decompress(data) +else: + @udf(HELPER) + def gzip(data, compression=9): + if isinstance(data, str): + data = bytes(data.encode('raw_unicode_escape')) + return zlib.compress(data, compression) + + @udf(HELPER) + def gunzip(data): + return zlib.decompress(data) + +@udf(HELPER) +def hostname(url): + parse_result = urlparse(url) + if parse_result: + return parse_result.netloc + +@udf(HELPER) +def toggle(key): + key = key.lower() + STATE[key] = ret = not STATE.get(key) + return ret + +@udf(HELPER) +def setting(key, value=None): + if value is None: + return SETTINGS.get(key) + else: + SETTINGS[key] = value + return value + +@udf(HELPER) +def clear_settings(): + SETTINGS.clear() + +@udf(HELPER) +def clear_toggles(): + STATE.clear() + +@udf(MATH) +def 
randomrange(start, end=None, step=None): + if end is None: + start, end = 0, start + elif step is None: + step = 1 + return random.randrange(start, end, step) + +@udf(MATH) +def gauss_distribution(mean, sigma): + try: + return random.gauss(mean, sigma) + except ValueError: + return None + +@udf(MATH) +def sqrt(n): + try: + return math.sqrt(n) + except ValueError: + return None + +@udf(MATH) +def tonumber(s): + try: + return int(s) + except ValueError: + try: + return float(s) + except: + return None + +@udf(STRING) +def substr_count(haystack, needle): + if not haystack or not needle: + return 0 + return haystack.count(needle) + +@udf(STRING) +def strip_chars(haystack, chars): + return haystack.strip(chars) + +def _hash(constructor, *args): + hash_obj = constructor() + for arg in args: + hash_obj.update(arg) + return hash_obj.hexdigest() + +# Aggregates. +class _heap_agg(object): + def __init__(self): + self.heap = [] + self.ct = 0 + + def process(self, value): + return value + + def step(self, value): + self.ct += 1 + heapq.heappush(self.heap, self.process(value)) + +class _datetime_heap_agg(_heap_agg): + def process(self, value): + return format_date_time_sqlite(value) + +if sys.version_info[:2] == (2, 6): + def total_seconds(td): + return (td.seconds + + (td.days * 86400) + + (td.microseconds / (10.**6))) +else: + total_seconds = lambda td: td.total_seconds() + +@aggregate(DATE) +class mintdiff(_datetime_heap_agg): + def finalize(self): + dtp = min_diff = None + while self.heap: + if min_diff is None: + if dtp is None: + dtp = heapq.heappop(self.heap) + continue + dt = heapq.heappop(self.heap) + diff = dt - dtp + if min_diff is None or min_diff > diff: + min_diff = diff + dtp = dt + if min_diff is not None: + return total_seconds(min_diff) + +@aggregate(DATE) +class avgtdiff(_datetime_heap_agg): + def finalize(self): + if self.ct < 1: + return + elif self.ct == 1: + return 0 + + total = ct = 0 + dtp = None + while self.heap: + if total == 0: + if dtp is None: + 
dtp = heapq.heappop(self.heap) + continue + + dt = heapq.heappop(self.heap) + diff = dt - dtp + ct += 1 + total += total_seconds(diff) + dtp = dt + + return float(total) / ct + +@aggregate(DATE) +class duration(object): + def __init__(self): + self._min = self._max = None + + def step(self, value): + dt = format_date_time_sqlite(value) + if self._min is None or dt < self._min: + self._min = dt + if self._max is None or dt > self._max: + self._max = dt + + def finalize(self): + if self._min and self._max: + td = (self._max - self._min) + return total_seconds(td) + return None + +@aggregate(MATH) +class mode(object): + if Counter: + def __init__(self): + self.items = Counter() + + def step(self, *args): + self.items.update(args) + + def finalize(self): + if self.items: + return self.items.most_common(1)[0][0] + else: + def __init__(self): + self.items = [] + + def step(self, item): + self.items.append(item) + + def finalize(self): + if self.items: + return max(set(self.items), key=self.items.count) + +@aggregate(MATH) +class minrange(_heap_agg): + def finalize(self): + if self.ct == 0: + return + elif self.ct == 1: + return 0 + + prev = min_diff = None + + while self.heap: + if min_diff is None: + if prev is None: + prev = heapq.heappop(self.heap) + continue + curr = heapq.heappop(self.heap) + diff = curr - prev + if min_diff is None or min_diff > diff: + min_diff = diff + prev = curr + return min_diff + +@aggregate(MATH) +class avgrange(_heap_agg): + def finalize(self): + if self.ct == 0: + return + elif self.ct == 1: + return 0 + + total = ct = 0 + prev = None + while self.heap: + if total == 0: + if prev is None: + prev = heapq.heappop(self.heap) + continue + + curr = heapq.heappop(self.heap) + diff = curr - prev + ct += 1 + total += diff + prev = curr + + return float(total) / ct + +@aggregate(MATH) +class _range(object): + name = 'range' + + def __init__(self): + self._min = self._max = None + + def step(self, value): + if self._min is None or value < self._min: 
+ self._min = value + if self._max is None or value > self._max: + self._max = value + + def finalize(self): + if self._min is not None and self._max is not None: + return self._max - self._min + return None + + +if cython_udf is not None: + damerau_levenshtein_dist = udf(STRING)(cython_udf.damerau_levenshtein_dist) + levenshtein_dist = udf(STRING)(cython_udf.levenshtein_dist) + str_dist = udf(STRING)(cython_udf.str_dist) + median = aggregate(MATH)(cython_udf.median) + + +if TableFunction is not None: + @table_function(STRING) + class RegexSearch(TableFunction): + params = ['regex', 'search_string'] + columns = ['match'] + name = 'regex_search' + + def initialize(self, regex=None, search_string=None): + self._iter = re.finditer(regex, search_string) + + def iterate(self, idx): + return (next(self._iter).group(0),) + + @table_function(DATE) + class DateSeries(TableFunction): + params = ['start', 'stop', 'step_seconds'] + columns = ['date'] + name = 'date_series' + + def initialize(self, start, stop, step_seconds=86400): + self.start = format_date_time_sqlite(start) + self.stop = format_date_time_sqlite(stop) + step_seconds = int(step_seconds) + self.step_seconds = datetime.timedelta(seconds=step_seconds) + + if (self.start.hour == 0 and + self.start.minute == 0 and + self.start.second == 0 and + step_seconds >= 86400): + self.format = '%Y-%m-%d' + elif (self.start.year == 1900 and + self.start.month == 1 and + self.start.day == 1 and + self.stop.year == 1900 and + self.stop.month == 1 and + self.stop.day == 1 and + step_seconds < 86400): + self.format = '%H:%M:%S' + else: + self.format = '%Y-%m-%d %H:%M:%S' + + def iterate(self, idx): + if self.start > self.stop: + raise StopIteration + current = self.start + self.start += self.step_seconds + return (current.strftime(self.format),) diff --git a/python2.7libs/playhouse/sqliteq.py b/python2.7libs/playhouse/sqliteq.py new file mode 100644 index 0000000..bd21354 --- /dev/null +++ b/python2.7libs/playhouse/sqliteq.py @@ 
-0,0 +1,330 @@ +import logging +import weakref +from threading import local as thread_local +from threading import Event +from threading import Thread +try: + from Queue import Queue +except ImportError: + from queue import Queue + +try: + import gevent + from gevent import Greenlet as GThread + from gevent.event import Event as GEvent + from gevent.local import local as greenlet_local + from gevent.queue import Queue as GQueue +except ImportError: + GThread = GQueue = GEvent = None + +from peewee import SENTINEL +from playhouse.sqlite_ext import SqliteExtDatabase + + +logger = logging.getLogger('peewee.sqliteq') + + +class ResultTimeout(Exception): + pass + +class WriterPaused(Exception): + pass + +class ShutdownException(Exception): + pass + + +class AsyncCursor(object): + __slots__ = ('sql', 'params', 'commit', 'timeout', + '_event', '_cursor', '_exc', '_idx', '_rows', '_ready') + + def __init__(self, event, sql, params, commit, timeout): + self._event = event + self.sql = sql + self.params = params + self.commit = commit + self.timeout = timeout + self._cursor = self._exc = self._idx = self._rows = None + self._ready = False + + def set_result(self, cursor, exc=None): + self._cursor = cursor + self._exc = exc + self._idx = 0 + self._rows = cursor.fetchall() if exc is None else [] + self._event.set() + return self + + def _wait(self, timeout=None): + timeout = timeout if timeout is not None else self.timeout + if not self._event.wait(timeout=timeout) and timeout: + raise ResultTimeout('results not ready, timed out.') + if self._exc is not None: + raise self._exc + self._ready = True + + def __iter__(self): + if not self._ready: + self._wait() + if self._exc is not None: + raise self._exc + return self + + def next(self): + if not self._ready: + self._wait() + try: + obj = self._rows[self._idx] + except IndexError: + raise StopIteration + else: + self._idx += 1 + return obj + __next__ = next + + @property + def lastrowid(self): + if not self._ready: + 
self._wait() + return self._cursor.lastrowid + + @property + def rowcount(self): + if not self._ready: + self._wait() + return self._cursor.rowcount + + @property + def description(self): + return self._cursor.description + + def close(self): + self._cursor.close() + + def fetchall(self): + return list(self) # Iterating implies waiting until populated. + + def fetchone(self): + if not self._ready: + self._wait() + try: + return next(self) + except StopIteration: + return None + +SHUTDOWN = StopIteration +PAUSE = object() +UNPAUSE = object() + + +class Writer(object): + __slots__ = ('database', 'queue') + + def __init__(self, database, queue): + self.database = database + self.queue = queue + + def run(self): + conn = self.database.connection() + try: + while True: + try: + if conn is None: # Paused. + if self.wait_unpause(): + conn = self.database.connection() + else: + conn = self.loop(conn) + except ShutdownException: + logger.info('writer received shutdown request, exiting.') + return + finally: + if conn is not None: + self.database._close(conn) + self.database._state.reset() + + def wait_unpause(self): + obj = self.queue.get() + if obj is UNPAUSE: + logger.info('writer unpaused - reconnecting to database.') + return True + elif obj is SHUTDOWN: + raise ShutdownException() + elif obj is PAUSE: + logger.error('writer received pause, but is already paused.') + else: + obj.set_result(None, WriterPaused()) + logger.warning('writer paused, not handling %s', obj) + + def loop(self, conn): + obj = self.queue.get() + if isinstance(obj, AsyncCursor): + self.execute(obj) + elif obj is PAUSE: + logger.info('writer paused - closing database connection.') + self.database._close(conn) + self.database._state.reset() + return + elif obj is UNPAUSE: + logger.error('writer received unpause, but is already running.') + elif obj is SHUTDOWN: + raise ShutdownException() + else: + logger.error('writer received unsupported object: %s', obj) + return conn + + def execute(self, obj): + 
logger.debug('received query %s', obj.sql) + try: + cursor = self.database._execute(obj.sql, obj.params, obj.commit) + except Exception as execute_err: + cursor = None + exc = execute_err # python3 is so fucking lame. + else: + exc = None + return obj.set_result(cursor, exc) + + +class SqliteQueueDatabase(SqliteExtDatabase): + WAL_MODE_ERROR_MESSAGE = ('SQLite must be configured to use the WAL ' + 'journal mode when using this feature. WAL mode ' + 'allows one or more readers to continue reading ' + 'while another connection writes to the ' + 'database.') + + def __init__(self, database, use_gevent=False, autostart=True, + queue_max_size=None, results_timeout=None, *args, **kwargs): + kwargs['check_same_thread'] = False + + # Ensure that journal_mode is WAL. This value is passed to the parent + # class constructor below. + pragmas = self._validate_journal_mode(kwargs.pop('pragmas', None)) + + # Reference to execute_sql on the parent class. Since we've overridden + # execute_sql(), this is just a handy way to reference the real + # implementation. + Parent = super(SqliteQueueDatabase, self) + self._execute = Parent.execute_sql + + # Call the parent class constructor with our modified pragmas. + Parent.__init__(database, pragmas=pragmas, *args, **kwargs) + + self._autostart = autostart + self._results_timeout = results_timeout + self._is_stopped = True + + # Get different objects depending on the threading implementation. + self._thread_helper = self.get_thread_impl(use_gevent)(queue_max_size) + + # Create the writer thread, optionally starting it. 
+ self._create_write_queue() + if self._autostart: + self.start() + + def get_thread_impl(self, use_gevent): + return GreenletHelper if use_gevent else ThreadHelper + + def _validate_journal_mode(self, pragmas=None): + if pragmas: + pdict = dict((k.lower(), v) for (k, v) in pragmas) + if pdict.get('journal_mode', 'wal').lower() != 'wal': + raise ValueError(self.WAL_MODE_ERROR_MESSAGE) + + return [(k, v) for (k, v) in pragmas + if k != 'journal_mode'] + [('journal_mode', 'wal')] + else: + return [('journal_mode', 'wal')] + + def _create_write_queue(self): + self._write_queue = self._thread_helper.queue() + + def queue_size(self): + return self._write_queue.qsize() + + def execute_sql(self, sql, params=None, commit=SENTINEL, timeout=None): + if commit is SENTINEL: + commit = not sql.lower().startswith('select') + + if not commit: + return self._execute(sql, params, commit=commit) + + cursor = AsyncCursor( + event=self._thread_helper.event(), + sql=sql, + params=params, + commit=commit, + timeout=self._results_timeout if timeout is None else timeout) + self._write_queue.put(cursor) + return cursor + + def start(self): + with self._lock: + if not self._is_stopped: + return False + def run(): + writer = Writer(self, self._write_queue) + writer.run() + + self._writer = self._thread_helper.thread(run) + self._writer.start() + self._is_stopped = False + return True + + def stop(self): + logger.debug('environment stop requested.') + with self._lock: + if self._is_stopped: + return False + self._write_queue.put(SHUTDOWN) + self._writer.join() + self._is_stopped = True + return True + + def is_stopped(self): + with self._lock: + return self._is_stopped + + def pause(self): + with self._lock: + self._write_queue.put(PAUSE) + + def unpause(self): + with self._lock: + self._write_queue.put(UNPAUSE) + + def __unsupported__(self, *args, **kwargs): + raise ValueError('This method is not supported by %r.' 
% type(self)) + atomic = transaction = savepoint = __unsupported__ + + +class ThreadHelper(object): + __slots__ = ('queue_max_size',) + + def __init__(self, queue_max_size=None): + self.queue_max_size = queue_max_size + + def event(self): return Event() + + def queue(self, max_size=None): + max_size = max_size if max_size is not None else self.queue_max_size + return Queue(maxsize=max_size or 0) + + def thread(self, fn, *args, **kwargs): + thread = Thread(target=fn, args=args, kwargs=kwargs) + thread.daemon = True + return thread + + +class GreenletHelper(ThreadHelper): + __slots__ = () + + def event(self): return GEvent() + + def queue(self, max_size=None): + max_size = max_size if max_size is not None else self.queue_max_size + return GQueue(maxsize=max_size or 0) + + def thread(self, fn, *args, **kwargs): + def wrap(*a, **k): + gevent.sleep() + return fn(*a, **k) + return GThread(wrap, *args, **kwargs) diff --git a/python2.7libs/playhouse/test_utils.py b/python2.7libs/playhouse/test_utils.py new file mode 100644 index 0000000..333dc07 --- /dev/null +++ b/python2.7libs/playhouse/test_utils.py @@ -0,0 +1,62 @@ +from functools import wraps +import logging + + +logger = logging.getLogger('peewee') + + +class _QueryLogHandler(logging.Handler): + def __init__(self, *args, **kwargs): + self.queries = [] + logging.Handler.__init__(self, *args, **kwargs) + + def emit(self, record): + self.queries.append(record) + + +class count_queries(object): + def __init__(self, only_select=False): + self.only_select = only_select + self.count = 0 + + def get_queries(self): + return self._handler.queries + + def __enter__(self): + self._handler = _QueryLogHandler() + logger.setLevel(logging.DEBUG) + logger.addHandler(self._handler) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + logger.removeHandler(self._handler) + if self.only_select: + self.count = len([q for q in self._handler.queries + if q.msg[0].startswith('SELECT ')]) + else: + self.count = 
len(self._handler.queries) + + +class assert_query_count(count_queries): + def __init__(self, expected, only_select=False): + super(assert_query_count, self).__init__(only_select=only_select) + self.expected = expected + + def __call__(self, f): + @wraps(f) + def decorated(*args, **kwds): + with self: + ret = f(*args, **kwds) + + self._assert_count() + return ret + + return decorated + + def _assert_count(self): + error_msg = '%s != %s' % (self.count, self.expected) + assert self.count == self.expected, error_msg + + def __exit__(self, exc_type, exc_val, exc_tb): + super(assert_query_count, self).__exit__(exc_type, exc_val, exc_tb) + self._assert_count() diff --git a/scripts/456.py b/scripts/456.py index e6845f6..6afc55b 100644 --- a/scripts/456.py +++ b/scripts/456.py @@ -1,9 +1,12 @@ from __future__ import print_function from searcher import searcher_data from searcher import util +from searcher import platformselect +from searcher import ptime as ptime from peewee import * -from playhouse.sqlite_ext import SqliteExtDatabase, SearchField, FTSModel +from peewee import SQL +from playhouse.sqlite_ext import SqliteExtDatabase, RowIDField, FTS5Model, SearchField # from playhouse.apsw_ext import APSWDatabase import inspect import threading @@ -23,18 +26,20 @@ inspect.getsourcefile(lambda: 0) ) +def get_db(): + return getattr(hou.session, "DATABASE", None) + scriptpath = os.path.dirname(current_file_path) dbpath = os.path.join(scriptpath, "python/searcher/db/searcher.db") -# db = SqliteExtDatabase(':memory:') -db = SqliteExtDatabase(dbpath) -dbc = None +hou.session.DATABASE = DatabaseProxy() +db = get_db() settingdata = {} isloading = True tempkey = "" -class settings(Model): +class Settings(Model): id = IntegerField(unique=True) indexvalue = IntegerField() defaulthotkey = TextField() @@ -47,10 +52,9 @@ class Meta: table_name = 'settings' database = db - -class hcontext(Model): +class HContext(Model): id = AutoField() - context = CharField(unique=True) + context = 
TextField(unique=True) title = TextField() description = TextField() @@ -58,10 +62,19 @@ class Meta: table_name = 'hcontext' database = db +class HContextIndex(FTS5Model): + # rowid = RowIDField() + context = SearchField() + title = SearchField() + description = SearchField() -class hotkeys(Model): + class Meta: + database = db + options = {'prefix': [2, 3], 'tokenize': 'porter'} + +class Hotkeys(Model): hotkey_symbol = CharField(unique=True) - label = TextField() + label = CharField() description = TextField() assignments = TextField() context = TextField() @@ -70,31 +83,31 @@ class Meta: table_name = 'hotkeys' database = db - -class hotkeyindex(FTSModel): - description = SearchField() +class HotkeysIndex(FTS5Model): + # rowid = RowIDField() + hotkey_symbol = SearchField(unindexed=True) label = SearchField() + description = SearchField() + assignments = SearchField(unindexed=True) + context = SearchField(unindexed=True) class Meta: - table_name = 'hotkeyindex' + # table_name = 'hotkeysindex' database = db - options = {'tokenize': 'porter', - 'description': hotkeys.description} - + options = {'prefix': [2, 3], 'tokenize': 'porter'} def create_tables(): with db: - db.create_tables([settings, hcontext, hotkeys]) - + db.create_tables([Settings, HContext, HContextIndex, Hotkeys, HotkeysIndex]) def worker(): hd.executeInMainThreadWithResult(updatecontext) - def py_unique(data): return list(set(data)) - +# ------------------------------------------------- getdata +# NOTE getdata -------------------------------------------- def getdata(): rval = [] contextdata = [] @@ -119,10 +132,8 @@ def getcontexts(r, context_symbol, root): getcontexts("", "", rval) return contextdata, hotkeydata - -# region --------------------------------------------------- Initial Setup - - +# -------------------------------------------- initialsetup +# NOTE initialsetup --------------------------------------- def initialsetup(cur): currentidx = hou.hotkeys.changeIndex() chindex = 
getchangeindex(cur) @@ -143,34 +154,15 @@ def initialsetup(cur): getlastusedhk(cur) updatedataasync() updatechangeindex(int(currentidx)) - if hou.isUIAvailable(): - hou.ui.setStatusMessage( - "Searcher database created and populated", severity=hou.severityType.Message) - - -def dbupdate(cur): - currentidx = hou.hotkeys.changeIndex() - chindex = getchangeindex(cur) - if int(currentidx) != chindex: - getlastusedhk(cur) - updatedataasync() - updatechangeindex(int(currentidx)) if hou.isUIAvailable(): hou.ui.setStatusMessage( - "Searcher database updated", severity=hou.severityType.Message) - - -def updatedataasync(): - thread = threading.Thread(target=worker) - thread.daemon = True - thread.start() - -# endregionc - -# region --------------------------------------------------- Retrieve - + "Searcher database created and populated", severity=hou.severityType.Message) +# --------------------------------------------------------------- Retrieve +# SECTION Retrieve ------------------------------------------------------- +# ------------------------------------------ getchangeindex +# NOTE getchangeindex ------------------------------------- def getchangeindex(cur): try: cur.execute("SELECT indexvalue FROM settings") @@ -183,7 +175,8 @@ def getchangeindex(cur): else: print("Could not get Searcher changeindex: " + str(e)) - +# ------------------------------------------- getlastusedhk +# NOTE getlastusedhk -------------------------------------- def getlastusedhk(cur): try: cur.execute("SELECT lastused FROM settings") @@ -196,8 +189,8 @@ def getlastusedhk(cur): hkcheck = hou.hotkeys.assignments(str(lasthk[0])) hou.hotkeys.saveOverrides() if len(hkcheck) is 0: - settings.update(lastused="").where( - settings.id == 1).execute() + Settings.update(lastused="").where( + Settings.id == 1).execute() currentidx = hou.hotkeys.changeIndex() updatechangeindex(int(currentidx)) else: @@ -205,8 +198,8 @@ def getlastusedhk(cur): hou.hotkeys.saveOverrides() hkcheck = 
hou.hotkeys.assignments(str(lasthk[0])) if len(hkcheck) is 0: - settings.update(lastused="").where( - settings.id == 1).execute() + Settings.update(lastused="").where( + Settings.id == 1).execute() currentidx = hou.hotkeys.changeIndex() updatechangeindex(int(currentidx)) else: @@ -229,10 +222,30 @@ def getlastusedhk(cur): ("Could not query last assigned temp hotkey:" + str(e)), severity=hou.severityType.Warning) else: print("Could not query last assigned temp hotkey: " + str(e)) +# !SECTION -# region --------------------------------------------------- Updates +# ----------------------------------------------------------------- Update +# SECTION Update --------------------------------------------------------- +# ------------------------------------------------ dbupdate +# NOTE dbupdate ------------------------------------------- +def dbupdate(cur): + currentidx = hou.hotkeys.changeIndex() + chindex = getchangeindex(cur) + if int(currentidx) != chindex: + getlastusedhk(cur) + updatedataasync() + updatechangeindex(int(currentidx)) +# ----------------------------------------- updatedataasync +# NOTE updatedataasync ------------------------------------ +def updatedataasync(): + thread = threading.Thread(target=worker) + thread.daemon = True + thread.start() + +# --------------------------------------- updatechangeindex +# NOTE updatechangeindex ---------------------------------- def updatechangeindex(indexval, new=False): try: if new is True: @@ -242,21 +255,25 @@ def updatechangeindex(indexval, new=False): if not result: defaultkey = util.HOTKEYLIST[i] - settings.insert(indexvalue=indexval, - defaulthotkey=defaultkey, searchdescription=0, searchprefix=0, searchcurrentcontext=0, lastused="", id=1).execute() + Settings.insert( + indexvalue=indexval, + defaulthotkey=defaultkey, + searchdescription=0, + searchprefix=0, + searchcurrentcontext=0, + lastused="", + id=1).execute() else: - settings.update(indexvalue=indexval).where( - settings.id == 1).execute() + 
Settings.update(indexvalue=indexval).where( + Settings.id == 1).execute() except(AttributeError, TypeError) as e: if hou.isUIAvailable(): - hou.ui.setStatusMessage( - ("Could not update Searcher context database: " + str(e)), - severity=hou.severityType.Warning - ) + hou.ui.setStatusMessage(("Could not update Searcher context database: " + str(e)), severity=hou.severityType.Warning) else: print("Could not update Searcher context database: " + str(e)) - +# ------------------------------------------- updatecontext +# NOTE updatecontext -------------------------------------- def updatecontext(debug=False): try: time1 = time.time() @@ -264,25 +281,28 @@ def updatecontext(debug=False): ctxdata, hkeydata = getdata() with db.atomic(): for data_dict in ctxdata: - hcontext.replace_many(data_dict).execute() + HContext.replace_many(data_dict).execute() + HContextIndex.replace_many(data_dict).execute() with db.atomic(): for idx in hkeydata: - hotkeys.replace_many(idx).execute() + Hotkeys.replace_many(idx).execute() + HotkeysIndex.replace_many(idx).execute() time2 = time.time() if debug: + res = ((time2 - time1) * 1000.0) if hou.isUIAvailable(): hou.ui.setStatusMessage( - ('DB update took %0.3f ms' % - ((time2 - time1) * 1000.0)), severity=hou.severityType.Message) + ('DB update took %0.4f ms' % res), severity=hou.severityType.Message) else: - print('DB update took %0.3f ms' % - ((time2 - time1) * 1000.0)) # TODO Remove this timer + print('DB update took %0.4f ms' % res) + return res except(AttributeError, TypeError) as e: hou.ui.setStatusMessage( ("Could not update Searcher context database: " + str(e)), severity=hou.severityType.Warning) # endregion - +# ------------------------------------------- cleardatabase +# NOTE cleardatabase -------------------------------------- def cleardatabase(): try: delhk = "DELETE FROM hotkeys" @@ -298,31 +318,63 @@ def cleardatabase(): ("Could not clear db for refresh: " + str(e)), severity=hou.severityType.Warning) else: print("Could not 
clear db for refresh: " + str(e)) - +# !SECTION def deferaction(action, val): hd.executeDeferred(action, val) - # hd.execute_deferred_after_waiting(action, 25) - def checklasthk(cur): getlastusedhk(cur) - def main(): if os.path.isfile(searcher_data.searcher_settings): settingdata = searcher_data.loadsettings() else: searcher_data.createdefaults() settingdata = searcher_data.loadsettings() - - if not os.path.isfile(dbpath): + isdebug = util.Dbug( + util.bc(settingdata[util.SETTINGS_KEYS[4]]), + str(settingdata[util.SETTINGS_KEYS[10]]) + ) + + inmemory = util.bc(settingdata[util.SETTINGS_KEYS[0]]) + if inmemory: + val = ':memory:' + else: + val = (dbpath) + + db.initialize( + SqliteExtDatabase( + val, + pragmas=( + ("cache_size", -1024 * 64), + ("journal_mode", "wal"), + ("synchronous", 0) + ))) + + time1 = ptime.time() + if inmemory: create_tables() cur = db.cursor() deferaction(initialsetup, cur) else: - cur = db.cursor() - deferaction(dbupdate, cur) + if not os.path.isfile(dbpath): + create_tables() + cur = db.cursor() + deferaction(initialsetup, cur) + else: + cur = db.cursor() + deferaction(dbupdate, cur) + + time2 = ptime.time() + + if isdebug and isdebug.level in {"TIMER", "ALL"}: + res = ((time2 - time1) * 1000.0) + if hou.isUIAvailable(): + hou.ui.setStatusMessage( + ('Startup took %0.4f ms' % res), severity=hou.severityType.Message) + else: + print('Startup took %0.4f ms' % res) if __name__ == '__main__': diff --git a/scripts/456.py.bak b/scripts/456.py.bak new file mode 100644 index 0000000..d98e72a --- /dev/null +++ b/scripts/456.py.bak @@ -0,0 +1,329 @@ +from __future__ import print_function +from searcher import searcher_data +from searcher import util + +from peewee import * +from peewee import SQL +from playhouse.sqlite_ext import SqliteExtDatabase, FTSModel, SearchField +# from playhouse.apsw_ext import APSWDatabase +import inspect +import threading +import time +import hou +import hdefereval as hd +import os +import sys + +# info +__author__ = 
"instance.id" +__copyright__ = "2020 All rights reserved. See LICENSE for more details." +__status__ = "Prototype" + + +current_file_path = os.path.abspath( + inspect.getsourcefile(lambda: 0) +) + +scriptpath = os.path.dirname(current_file_path) +dbpath = os.path.join(scriptpath, "python/searcher/db/searcher.db") + +# db = SqliteExtDatabase(':memory:') +db = SqliteExtDatabase(dbpath) +dbc = None +settingdata = {} +isloading = True +tempkey = "" + +class settings(Model): + id = IntegerField(unique=True) + indexvalue = IntegerField() + defaulthotkey = TextField() + searchdescription = IntegerField() + searchprefix = IntegerField() + searchcurrentcontext = IntegerField() + lastused = TextField() + + class Meta: + table_name = 'settings' + database = db + + +class hcontext(Model): + id = AutoField() + context = CharField(unique=True) + title = TextField() + description = TextField() + + class Meta: + table_name = 'hcontext' + database = db + + +class hotkeys(Model): + hotkey_symbol = CharField(unique=True) + label = TextField() + description = TextField() + assignments = TextField() + context = TextField() + + class Meta: + table_name = 'hotkeys' + database = db + + +class hotkeyindex(FTSModel): + description = SearchField() + label = SearchField() + + class Meta: + table_name = 'hotkeyindex' + database = db + extension_options = {'tokenize': 'porter', + 'description': hotkeys.description} + + +def create_tables(): + with db: + db.create_tables([settings, hcontext, hotkeys]) + + +def worker(): + hd.executeInMainThreadWithResult(updatecontext) + + +def py_unique(data): + return list(set(data)) + + +def getdata(): + rval = [] + contextdata = [] + hotkeydata = [] + + def getcontexts(r, context_symbol, root): + keys = None + branches = hou.hotkeys.contextsInContext(context_symbol) + for branch in branches: + branch_path = "%s/%s" % (r, branch['label']) + contextdata.append( + {'context': branch['symbol'], 'title': branch['label'], 'description': branch['help']}) + commands 
= hou.hotkeys.commandsInContext(branch['symbol']) + for command in commands: + keys = hou.hotkeys.assignments(command['symbol']) + ctx = command['symbol'].rsplit('.', 1) + hotkeydata.append( + {'hotkey_symbol': command['symbol'], 'label': command['label'], 'description': command['help'], + 'assignments': " ".join(keys), 'context': ctx[0]}) + getcontexts(branch_path, branch['symbol'], root) + + getcontexts("", "", rval) + return contextdata, hotkeydata + + +# ---------------------------------------------------------- Initial Setup + + +def initialsetup(cur): + currentidx = hou.hotkeys.changeIndex() + chindex = getchangeindex(cur) + + if len(chindex) == 0: + chindex = int(currentidx) + updatechangeindex(chindex, True) + updatedataasync() + if hou.isUIAvailable(): + hou.ui.setStatusMessage( + "Searcher database created", severity=hou.severityType.Message) + else: + print("Searcher database created") + else: + chindex = int(chindex[0][0]) + + if int(currentidx) != chindex: + getlastusedhk(cur) + updatedataasync() + updatechangeindex(int(currentidx)) + if hou.isUIAvailable(): + hou.ui.setStatusMessage( + "Searcher database created and populated", severity=hou.severityType.Message) + + +def dbupdate(cur): + currentidx = hou.hotkeys.changeIndex() + chindex = getchangeindex(cur) + + if int(currentidx) != chindex: + getlastusedhk(cur) + updatedataasync() + updatechangeindex(int(currentidx)) + if hou.isUIAvailable(): + hou.ui.setStatusMessage( + "Searcher database updated", severity=hou.severityType.Message) + + +def updatedataasync(): + thread = threading.Thread(target=worker) + thread.daemon = True + thread.start() + +# endregionc + +# ---------------------------------------------------------- Retrieve + + +def getchangeindex(cur): + try: + cur.execute("SELECT indexvalue FROM settings") + result = cur.fetchall() + return result + except(AttributeError, TypeError) as e: + if hou.isUIAvailable(): + hou.ui.setStatusMessage( + ("Could not get Searcher changeindex: " + str(e)), 
severity=hou.severityType.Warning) + else: + print("Could not get Searcher changeindex: " + str(e)) + + +def getlastusedhk(cur): + try: + cur.execute("SELECT lastused FROM settings") + result = cur.fetchall() + if str(result[0][0]) != "": + lasthk = str(result[0][0]).split(' ') + rmresult = hou.hotkeys.removeAssignment( + str(lasthk[0]).strip(), str(lasthk[1]).strip()) + if rmresult: + hkcheck = hou.hotkeys.assignments(str(lasthk[0])) + hou.hotkeys.saveOverrides() + if len(hkcheck) is 0: + settings.update(lastused="").where( + settings.id == 1).execute() + currentidx = hou.hotkeys.changeIndex() + updatechangeindex(int(currentidx)) + else: + hou.hotkeys.clearAssignments(str(lasthk[0])) + hou.hotkeys.saveOverrides() + hkcheck = hou.hotkeys.assignments(str(lasthk[0])) + if len(hkcheck) is 0: + settings.update(lastused="").where( + settings.id == 1).execute() + currentidx = hou.hotkeys.changeIndex() + updatechangeindex(int(currentidx)) + else: + if hou.isUIAvailable(): + hou.ui.setStatusMessage( + ("Could not clear last assigned temp hotkey on last attempt:"), severity=hou.severityType.Warning) + else: + print( + "Could not clear last assigned temp hotkey on last attempt:") + else: + if hou.isUIAvailable(): + hou.ui.setStatusMessage( + ("Could not clear last assigned temp hotkey:"), severity=hou.severityType.Warning) + else: + print("Could not clear last assigned temp hotkey:") + + except(AttributeError, TypeError) as e: + if hou.isUIAvailable(): + hou.ui.setStatusMessage( + ("Could not query last assigned temp hotkey:" + str(e)), severity=hou.severityType.Warning) + else: + print("Could not query last assigned temp hotkey: " + str(e)) + +# ---------------------------------------------------------- Updates + + +def updatechangeindex(indexval, new=False): + try: + if new is True: + defaultkey = "" + for i in range(len(util.HOTKEYLIST)): + result = hou.hotkeys.findConflicts("h", util.HOTKEYLIST[i]) + if not result: + defaultkey = util.HOTKEYLIST[i] + + 
settings.insert(indexvalue=indexval, + defaulthotkey=defaultkey, searchdescription=0, searchprefix=0, searchcurrentcontext=0, lastused="", id=1).execute() + else: + settings.update(indexvalue=indexval).where( + settings.id == 1).execute() + except(AttributeError, TypeError) as e: + if hou.isUIAvailable(): + hou.ui.setStatusMessage( + ("Could not update Searcher context database: " + str(e)), + severity=hou.severityType.Warning + ) + else: + print("Could not update Searcher context database: " + str(e)) + + +def updatecontext(debug=False): + try: + time1 = time.time() + cleardatabase() + ctxdata, hkeydata = getdata() + with db.atomic(): + for data_dict in ctxdata: + hcontext.replace_many(data_dict).execute() + with db.atomic(): + for idx in hkeydata: + hotkeys.replace_many(idx).execute() + time2 = time.time() + if debug: + if hou.isUIAvailable(): + hou.ui.setStatusMessage( + ('DB update took %0.3f ms' % + ((time2 - time1) * 1000.0)), severity=hou.severityType.Message) + else: + print('DB update took %0.3f ms' % + ((time2 - time1) * 1000.0)) # TODO Remove this timer + except(AttributeError, TypeError) as e: + hou.ui.setStatusMessage( + ("Could not update Searcher context database: " + str(e)), severity=hou.severityType.Warning) +# endregion + + +def cleardatabase(): + try: + delhk = "DELETE FROM hotkeys" + delctx = "DELETE FROM hcontext" + db.cursor().execute(delhk) + db.cursor().execute(delctx) + result = db.cursor().fetchall() + + return result + except(AttributeError, TypeError) as e: + if hou.isUIAvailable(): + hou.ui.setStatusMessage( + ("Could not clear db for refresh: " + str(e)), severity=hou.severityType.Warning) + else: + print("Could not clear db for refresh: " + str(e)) + + +def deferaction(action, val): + hd.executeDeferred(action, val) + # hd.execute_deferred_after_waiting(action, 25) + + +def checklasthk(cur): + getlastusedhk(cur) + + +def main(): + if os.path.isfile(searcher_data.searcher_settings): + settingdata = searcher_data.loadsettings() + else: 
+ searcher_data.createdefaults() + settingdata = searcher_data.loadsettings() + + if not os.path.isfile(dbpath): + create_tables() + cur = db.cursor() + deferaction(initialsetup, cur) + else: + cur = db.cursor() + deferaction(dbupdate, cur) + + +if __name__ == '__main__': + main() diff --git a/scripts/python/searcher/HelpButton.py b/scripts/python/searcher/HelpButton.py index 9b20f35..e8b88df 100644 --- a/scripts/python/searcher/HelpButton.py +++ b/scripts/python/searcher/HelpButton.py @@ -21,22 +21,21 @@ class HelpButton(QtWidgets.QPushButton): """Generic Help button.""" - def __init__(self, name, parent=None): + def __init__(self, name, tooltip, size, parent=None): super(HelpButton, self).__init__( hou.qt.createIcon("BUTTONS_help"), "", parent=parent ) self._name = name - - self.setToolTip("Open Help Page.") - self.setIconSize(QtCore.QSize(16, 16)) - self.setMaximumSize(QtCore.QSize(16, 16)) + self.setToolTip(tooltip) + self.setIconSize(QtCore.QSize(size, size)) + self.setMaximumSize(QtCore.QSize(size, size)) self.setFlat(True) self.clicked.connect(self.display_help) def display_help(self): - """Display help page.""" + """Display help panel.""" # Look for an existing, float help browser. for pane_tab in hou.ui.paneTabs(): if isinstance(pane_tab, hou.HelpBrowser): @@ -51,4 +50,3 @@ def display_help(self): hou.paneTabType.HelpBrowser) browser.displayHelpPath("/searcher/{}".format(self._name)) -# endregion diff --git a/scripts/python/searcher/SearcherSettings.py b/scripts/python/searcher/SearcherSettings.py deleted file mode 100644 index faac6f3..0000000 --- a/scripts/python/searcher/SearcherSettings.py +++ /dev/null @@ -1,204 +0,0 @@ -# -*- coding: utf-8 -*- - -################################################################################ -## Form generated from reading UI file 'SearcherSettings.ui' -## -## Created by: Qt User Interface Compiler version 5.14.1 -## -## WARNING! All changes made in this file will be lost when recompiling UI file! 
-################################################################################ - -from PySide2.QtCore import (QCoreApplication, QMetaObject, QObject, QPoint, - QRect, QSize, QUrl, Qt) -from PySide2.QtGui import (QBrush, QColor, QConicalGradient, QCursor, QFont, - QFontDatabase, QIcon, QLinearGradient, QPalette, QPainter, QPixmap, - QRadialGradient) -from PySide2.QtWidgets import * - - -class Ui_SearcherSettings(object): - def setupUi(self, SearcherSettings): - if SearcherSettings.objectName(): - SearcherSettings.setObjectName(u"SearcherSettings") - SearcherSettings.setWindowModality(Qt.NonModal) - SearcherSettings.resize(600, 185) - sizePolicy = QSizePolicy(QSizePolicy.Preferred, QSizePolicy.Preferred) - sizePolicy.setHorizontalStretch(0) - sizePolicy.setVerticalStretch(0) - sizePolicy.setHeightForWidth(SearcherSettings.sizePolicy().hasHeightForWidth()) - SearcherSettings.setSizePolicy(sizePolicy) - SearcherSettings.setMinimumSize(QSize(600, 0)) - SearcherSettings.setBaseSize(QSize(0, 0)) - self.gridLayout = QGridLayout(SearcherSettings) - self.gridLayout.setObjectName(u"gridLayout") - self.verticalLayout_4 = QVBoxLayout() - self.verticalLayout_4.setObjectName(u"verticalLayout_4") - self.horizontalLayout_4 = QHBoxLayout() - self.horizontalLayout_4.setObjectName(u"horizontalLayout_4") - self.projectTitle = QLabel(SearcherSettings) - self.projectTitle.setObjectName(u"projectTitle") - font = QFont() - font.setPointSize(15) - self.projectTitle.setFont(font) - self.projectTitle.setAlignment(Qt.AlignCenter) - - self.horizontalLayout_4.addWidget(self.projectTitle) - - self.horizontalSpacer_3 = QSpacerItem(40, 20, QSizePolicy.Expanding, QSizePolicy.Minimum) - - self.horizontalLayout_4.addItem(self.horizontalSpacer_3) - - self.inmemory_chk = QCheckBox(SearcherSettings) - self.inmemory_chk.setObjectName(u"inmemory_chk") - self.inmemory_chk.setLayoutDirection(Qt.RightToLeft) - self.inmemory_chk.setTristate(False) - - self.horizontalLayout_4.addWidget(self.inmemory_chk) - - 
self.windowsize_chk = QCheckBox(SearcherSettings) - self.windowsize_chk.setObjectName(u"windowsize_chk") - self.windowsize_chk.setLayoutDirection(Qt.RightToLeft) - - self.horizontalLayout_4.addWidget(self.windowsize_chk) - - - self.verticalLayout_4.addLayout(self.horizontalLayout_4) - - self.horizontalLayout_5 = QHBoxLayout() - self.horizontalLayout_5.setObjectName(u"horizontalLayout_5") - self.horizontalSpacer_2 = QSpacerItem(40, 20, QSizePolicy.Expanding, QSizePolicy.Minimum) - - self.horizontalLayout_5.addItem(self.horizontalSpacer_2) - - self.label_2 = QLabel(SearcherSettings) - self.label_2.setObjectName(u"label_2") - - self.horizontalLayout_5.addWidget(self.label_2) - - self.hkinput_txt = QLineEdit(SearcherSettings) - self.hkinput_txt.setObjectName(u"hkinput_txt") - self.hkinput_txt.setReadOnly(True) - - self.horizontalLayout_5.addWidget(self.hkinput_txt) - - self.hotkey_icon = QToolButton(SearcherSettings) - self.hotkey_icon.setObjectName(u"hotkey_icon") - self.hotkey_icon.setPopupMode(QToolButton.InstantPopup) - - self.horizontalLayout_5.addWidget(self.hotkey_icon) - - - self.verticalLayout_4.addLayout(self.horizontalLayout_5) - - self.horizontalLayout_7 = QHBoxLayout() - self.horizontalLayout_7.setObjectName(u"horizontalLayout_7") - self.label = QLabel(SearcherSettings) - self.label.setObjectName(u"label") - - self.horizontalLayout_7.addWidget(self.label) - - self.databasepath_txt = QLineEdit(SearcherSettings) - self.databasepath_txt.setObjectName(u"databasepath_txt") - - self.horizontalLayout_7.addWidget(self.databasepath_txt) - - self.dbpath_icon = QToolButton(SearcherSettings) - self.dbpath_icon.setObjectName(u"dbpath_icon") - - self.horizontalLayout_7.addWidget(self.dbpath_icon) - - - self.verticalLayout_4.addLayout(self.horizontalLayout_7) - - self.horizontalLayout = QHBoxLayout() - self.horizontalLayout.setObjectName(u"horizontalLayout") - self.label_4 = QLabel(SearcherSettings) - self.label_4.setObjectName(u"label_4") - - 
self.horizontalLayout.addWidget(self.label_4) - - self.test1_btn = QPushButton(SearcherSettings) - self.test1_btn.setObjectName(u"test1_btn") - - self.horizontalLayout.addWidget(self.test1_btn) - - self.test_context_btn = QPushButton(SearcherSettings) - self.test_context_btn.setObjectName(u"test_context_btn") - - self.horizontalLayout.addWidget(self.test_context_btn) - - self.cleardata_btn = QPushButton(SearcherSettings) - self.cleardata_btn.setObjectName(u"cleardata_btn") - - self.horizontalLayout.addWidget(self.cleardata_btn) - - - self.verticalLayout_4.addLayout(self.horizontalLayout) - - self.horizontalLayout_2 = QHBoxLayout() - self.horizontalLayout_2.setObjectName(u"horizontalLayout_2") - self.label_3 = QLabel(SearcherSettings) - self.label_3.setObjectName(u"label_3") - - self.horizontalLayout_2.addWidget(self.label_3) - - - self.verticalLayout_4.addLayout(self.horizontalLayout_2) - - self.horizontalLayout_3 = QHBoxLayout() - self.horizontalLayout_3.setObjectName(u"horizontalLayout_3") - self.horizontalSpacer = QSpacerItem(40, 20, QSizePolicy.Expanding, QSizePolicy.Minimum) - - self.horizontalLayout_3.addItem(self.horizontalSpacer) - - self.debugflag_chk = QCheckBox(SearcherSettings) - self.debugflag_chk.setObjectName(u"debugflag_chk") - self.debugflag_chk.setLayoutDirection(Qt.RightToLeft) - - self.horizontalLayout_3.addWidget(self.debugflag_chk) - - self.discard_btn = QPushButton(SearcherSettings) - self.discard_btn.setObjectName(u"discard_btn") - - self.horizontalLayout_3.addWidget(self.discard_btn) - - self.save_btn = QPushButton(SearcherSettings) - self.save_btn.setObjectName(u"save_btn") - - self.horizontalLayout_3.addWidget(self.save_btn) - - - self.verticalLayout_4.addLayout(self.horizontalLayout_3) - - - self.gridLayout.addLayout(self.verticalLayout_4, 1, 0, 1, 1) - - - self.retranslateUi(SearcherSettings) - - QMetaObject.connectSlotsByName(SearcherSettings) - # setupUi - - def retranslateUi(self, SearcherSettings): - 
SearcherSettings.setWindowTitle(QCoreApplication.translate("SearcherSettings", u"Form", None)) - self.projectTitle.setText(QCoreApplication.translate("SearcherSettings", u"Searcher Settings", None)) - self.inmemory_chk.setText(QCoreApplication.translate("SearcherSettings", u"Use In-Memory Database", None)) - self.windowsize_chk.setText(QCoreApplication.translate("SearcherSettings", u"Remember Search Window Size", None)) - self.label_2.setText(QCoreApplication.translate("SearcherSettings", u"Hotkey to use for opening unassigned items: ", None)) -#if QT_CONFIG(tooltip) - self.hkinput_txt.setToolTip("") -#endif // QT_CONFIG(tooltip) - self.hkinput_txt.setPlaceholderText(QCoreApplication.translate("SearcherSettings", u"Double Click", None)) - self.hotkey_icon.setText(QCoreApplication.translate("SearcherSettings", u"...", None)) - self.label.setText(QCoreApplication.translate("SearcherSettings", u"Database location: ", None)) - self.dbpath_icon.setText(QCoreApplication.translate("SearcherSettings", u"...", None)) - self.label_4.setText(QCoreApplication.translate("SearcherSettings", u"Maintenance utilities:", None)) - self.test1_btn.setText(QCoreApplication.translate("SearcherSettings", u"Test Button 1", None)) - self.test_context_btn.setText(QCoreApplication.translate("SearcherSettings", u"Test HContext", None)) - self.cleardata_btn.setText(QCoreApplication.translate("SearcherSettings", u"Clear Data", None)) - self.label_3.setText("") - self.debugflag_chk.setText(QCoreApplication.translate("SearcherSettings", u"Debug Mode", None)) - self.discard_btn.setText(QCoreApplication.translate("SearcherSettings", u"Discard", None)) - self.save_btn.setText(QCoreApplication.translate("SearcherSettings", u"Save", None)) - # retranslateUi - diff --git a/scripts/python/searcher/_conversion/searcher.py b/scripts/python/searcher/_conversion/searcher.py deleted file mode 100644 index 9896c2a..0000000 --- a/scripts/python/searcher/_conversion/searcher.py +++ /dev/null @@ -1,831 +0,0 @@ -# 
region Imports -from __future__ import print_function -from __future__ import absolute_import -import weakref - -from searcher import util -from searcher import database -from searcher import datahandler -from searcher import searcher_ui -from searcher import searcher_data -from searcher import searcher_settings - -from pyautogui import press, typewrite, hotkey -import hou -from husdui.common import error_print, debug_print -import toolutils -import drivertoolutils -import platform -import objecttoolutils -import os -import sys -import hdefereval as hd -import stateutils -hver = 0 -if os.environ["HFS"] != "": - ver = os.environ["HFS"] - hver = int(ver[ver.rindex('.')+1:]) - from hutil.Qt import QtGui - from hutil.Qt import QtCore - from hutil.Qt import QtWidgets - if int(hver) >= 391: - from hutil.Qt import _QtUiTools - elif int(hver) < 391: - from hutil.Qt import QtUiTools - -reload(searcher_settings) -reload(searcher_data) -reload(searcher_ui) -reload(datahandler) -reload(database) -reload(util) -# endregion - -# region ------------------------------------------------------------- App Info -__package__ = "Searcher" -__version__ = "0.1b" -__author__ = "instance.id" -__copyright__ = "2020 All rights reserved. See LICENSE for more details." 
-__status__ = "Prototype" -# endregion - -# region ------------------------------------------------------------- Variables / Constants -kwargs = {} -settings = {} -hasran = False -isdebug = False -mousePos = None -cur_screen = QtWidgets.QDesktopWidget().screenNumber( - QtWidgets.QDesktopWidget().cursor().pos() -) -screensize = QtWidgets.QDesktopWidget().screenGeometry(cur_screen) -centerPoint = QtWidgets.QDesktopWidget().availableGeometry(cur_screen).center() - -sys.path.append(os.path.join(os.path.dirname(__file__))) -script_path = os.path.dirname(os.path.realpath(__file__)) -name = "Searcher" - -parent_widget = hou.qt.mainWindow() -searcher_window = QtWidgets.QMainWindow() -# endregion - -# region ------------------------------------------------------------- Class Functions - - -def keyconversion(key): - for i in range(len(key)): - if key[i] in util.KEYCONVERSIONS: - key[i] = util.KEYCONVERSIONS[key[i]] - return key -# endregion - -# region ------------------------------------------------------------- Searcher Class - - -class Searcher(QtWidgets.QWidget): - """instance.id Searcher for Houdini""" - - def __init__(self, kwargs, settings, windowsettings): - super(Searcher, self).__init__(hou.qt.mainWindow()) - mainui = searcher_ui.Ui_Searcher() - mainui.setupUi(self) - mainui.retranslateUi(self) - - self._drag_active = False - - # Setting vars - kwargs = kwargs - self.settingdata = settings - self.windowsettings = windowsettings - self.isdebug = util.bc(self.settingdata[util.SETTINGS_KEYS[4]]) - self.menuopened = False - self.windowispin = util.bc(self.settingdata[util.SETTINGS_KEYS[5]]) - self.originalsize = self.settingdata[util.SETTINGS_KEYS[3]] - - # if hver >= 391: - self.app = QtWidgets.QApplication.instance() - - # UI Vars - self.handler, self.tmpkey = self.initialsetup() - self.ui = searcher_settings.SearcherSettings(self.handler, self.tmpkey) - - # Functional Vars - self.lastused = {} - self.tmpsymbol = None - self.searching = False - self.ctxsearch = False 
- self.showglobal = True - self.previous_pos = None - self.searchprefix = False - self.keys_changed = False - self.searchdescription = False - self.searchcurrentcontext = False - - # Functionals - hou.hotkeys._createBackupTables() - self.uisetup(mainui) - - # Event System Initialization - self.installEventFilter(self) - self.searchbox.installEventFilter(self) - self.pinwindow.installEventFilter(self) - self.searchfilter.installEventFilter(self) - self.opensettingstool.installEventFilter(self) - self.searchresultstree.installEventFilter(self) - - # region ------------------------------------------------------------- Settings - def open_settings(self): - self.ui.setWindowTitle('Searcher - Settings') - self.ui.show() - self.ui.setFocus() - # endregion - - # region ------------------------------------------------------------- UI - def setupContext(self): - cols = 4 - self.searchresultstree.setColumnCount(cols) - self.searchresultstree.setColumnWidth(0, 250) - if self.isdebug: - self.searchresultstree.setColumnWidth(1, 350) - else: - self.searchresultstree.setColumnWidth(1, 450) - self.searchresultstree.setColumnWidth(2, 100) - self.searchresultstree.setColumnWidth(3, 150) - if self.isdebug: - self.searchresultstree.setColumnWidth(4, 150) - self.searchresultstree.setHeaderLabels([ - "Label", - "Description", - "Assignments", - "Symbol", - "Context" - ]) - else: - self.searchresultstree.setHeaderLabels([ - "Label", - "Description", - "Assignments", - "Symbol" - ]) - - def uisetup(self, mainui): - self.main_widget = QtWidgets.QWidget(self) - - # Load UI File - loader = None - if int(hver) >= 391: - loader = _QtUiTools.QUiLoader() - else: - loader = QtUiTools.QUiLoader() - - # mainui = loader.load(script_path + "/searcher_ui.py") - # mainui = loader.load(script_path + "/searcher_ui.py") - - names = ["open", "save", "hotkey", "perference"] - self.completer = QtWidgets.QCompleter(names) - - # Get UI Elements - # self.searchfilter = mainui.findChild( - # 
QtWidgets.QToolButton, - # "searchfilter_btn" - # ) - - # self.searchfilter = mainui( - # QtWidgets.QToolButton, - # "searchfilter_btn" - # ) - - # self.pinwindow = mainui.findChild( - # QtWidgets.QToolButton, - # "pinwindow_btn" - # ) - # self.opensettingstool = mainui.findChild( - # QtWidgets.QToolButton, - # "opensettings_btn" - # ) - # self.searchresultstree = mainui.findChild( - # QtWidgets.QTreeWidget, - # "searchresults_tree" - # ) - # self.searchbox = mainui.findChild( - # QtWidgets.QLineEdit, - # "searchbox_txt" - # ) - # self.infolbl = mainui.findChild( - # QtWidgets.QLabel, - # "info_lbl" - # ) - # - - self.searchfilter = mainui.searchfilter_btn - self.pinwindow = mainui.pinwindow_btn - self.opensettingstool = mainui.opensettings_btn - self.searchresultstree = mainui.searchresults_tree - self.searchbox = mainui.searchbox_txt - self.infolbl = mainui.info_lbl - - self.searchbox.setPlaceholderText(" Search..") - self.searchbox.setFocusPolicy(QtCore.Qt.StrongFocus) - self.searchbox.setContextMenuPolicy(QtCore.Qt.CustomContextMenu) - self.searchbox.setClearButtonEnabled(True) - - self.searchfilter.clicked.connect(self.searchfilter_cb) - searchfilter_button_size = hou.ui.scaledSize(12) - self.searchfilter.setProperty("flat", True) - self.searchfilter.setIcon(util.SEARCH_ICON) - self.searchfilter.setIconSize(QtCore.QSize( - searchfilter_button_size, - searchfilter_button_size - )) - - self.pinwindow.clicked.connect(self.pinwindow_cb) - self.setpinicon() - pinwindow_button_size = hou.ui.scaledSize(16) - self.pinwindow.setProperty("flat", True) - self.pinwindow.setIconSize(QtCore.QSize( - pinwindow_button_size, - pinwindow_button_size - )) - - self.opensettingstool.clicked.connect(self.opensettings_cb) - opensettingstool_button_size = hou.ui.scaledSize(16) - self.opensettingstool.setProperty("flat", True) - self.opensettingstool.setIcon(util.SETTINGS_ICON) - self.opensettingstool.setIconSize(QtCore.QSize( - opensettingstool_button_size, - 
opensettingstool_button_size - )) - - self.searchbox.textChanged.connect(self.textchange_cb) - self.searchbox.customContextMenuRequested.connect(self.openmenu) - self.searchresultstree.itemActivated.connect(self.searchclick_cb) - - # Layout - mainlayout = QtWidgets.QVBoxLayout() - mainlayout.setAlignment(QtCore.Qt.AlignBottom) - mainlayout.setContentsMargins(0, 0, 0, 0) - mainlayout.setGeometry(QtCore.QRect(0, 0, 1400, 1200)) - - mainlayout.addWidget(HelpButton("main")) - mainlayout.addWidget(self.searchfilter) - mainlayout.addWidget(self.pinwindow) - mainlayout.addWidget(self.opensettingstool) - mainlayout.addWidget(self.searchresultstree) - mainlayout.addWidget(self.searchbox) - mainlayout.addWidget(self.infolbl) - self.setLayout(mainlayout) - - self.searchbox.setToolTip( - 'Begin typing to search or click magnifying glass icon to display options') - self.pinwindow.setToolTip( - 'Pin the search window to keep it from closing automatically when losing focus') - self.searchfilter.setToolTip( - 'Select a predefined filter') - self.opensettingstool.setToolTip( - 'General application settings') - self.searchresultstree.setToolTip( - 'Double click an action to attempt to perform it. 
Some actions only work in specific contexts') - - self.setupContext() - self.searchbox.setFocus() - self.searchbox.grabKeyboard() - - # endregion - - # region ------------------------------------------------------------- Initial Setup - - def initialsetup(self): - self.handler = datahandler.DataHandler(self.isdebug) - currentidx = hou.hotkeys.changeIndex() - chindex = self.handler.getchangeindex() - - if len(chindex) == 0: - chindex = int(currentidx) - self.handler.updatechangeindex(chindex, True) - self.handler.updatedataasync() - hou.ui.setStatusMessage( - "Searcher database created", - severity=hou.severityType.Message - ) - else: - chindex = int(chindex[0][0]) - - if int(currentidx) != chindex: - self.handler.updatedataasync() - self.handler.updatechangeindex(int(currentidx)) - - tmpkey = self.handler.getdefaulthotkey() - self.tmpkey = tmpkey[0][0] - return self.handler, self.tmpkey - - def getnode(self): - nodeSelect = hou.selectedNodes() - for node in nodeSelect: - getName = node.name() - print(getName) - - def getpane(self): - try: - return hou.ui.paneTabUnderCursor().type() - except (AttributeError, TypeError) as e: - hou.ui.setStatusMessage( - ("No context options to display" + str(e)), - severity=hou.severityType.Message - ) - - # endregion - # region ------------------------------------------------------------- Callbacks - def searchfilter_cb(self): - self.openmenu() - - def pinwindow_cb(self): - self.windowispin = not self.windowispin - self.settingdata[util.SETTINGS_KEYS[5]] = self.windowispin - searcher_data.savesettings(self.settingdata) - self.setpinicon() - - def setpinicon(self): - if self.windowispin: - self.pinwindow.setIcon(util.PIN_IN_ICON) - else: - self.pinwindow.setIcon(util.PIN_OUT_ICON) - - def opensettings_cb(self): - self.open_settings() - - def globalkeysearch(self): - self.ctxsearch = True - ctx = [] - ctx.append("h") - results = self.handler.searchctx(ctx) - self.searchtablepopulate(results) - self.ctxsearch = False - - def 
# NOTE(review): the following are methods of the Searcher widget class (the
# class header and __init__ live outside this hunk); reconstructed from a
# whitespace-mangled diff.

def ctxsearcher(self, ctx=None):
    """Populate the results tree with hotkeys scoped to a Houdini context.

    Args:
        ctx: None -> context of the pane under the cursor,
             ":v" -> view/pane contexts ("h.pane"),
             ":c" -> current pane context (same lookup as None),
             ":g" -> global context ("h").

    Side effects: fills the results tree, moves focus from the search box
    to the tree, and selects the first child item.
    """
    results = None
    ctxresult = []

    if ctx is None:
        self.ctxsearch = True
        if self.isdebug:
            print(self.getpane())
        ctxresult = util.PANETYPES[self.getpane()]
        results = self.handler.searchctx(ctxresult)
    elif ctx == ":v":
        self.ctxsearch = True
        ctxresult.append("h.pane")
        results = self.handler.searchctx(ctxresult)
    elif ctx == ":c":
        self.ctxsearch = True
        ctxresult = util.PANETYPES[self.getpane()]
        if self.isdebug:
            print(self.getpane())
        results = self.handler.searchctx(ctxresult)
    elif ctx == ":g":
        self.ctxsearch = True
        ctxresult.append("h")
        results = self.handler.searchctx(ctxresult)

    # FIX: results stays None for an unrecognized ctx and the tree can end
    # up empty; guard both lookups instead of raising TypeError/AttributeError.
    if results:
        self.searchtablepopulate(results)
    self.ctxsearch = False
    self.searchbox.clearFocus()
    self.searchresultstree.setFocus()
    top = self.searchresultstree.topLevelItem(0)
    if top is not None:
        self.searchresultstree.setCurrentItem(top.child(0))


def textchange_cb(self, text):
    """React to search-box edits.

    A context shortcut (e.g. ":g") triggers a context search, two or more
    ordinary characters trigger a text search, and an empty box resets the
    tree and the hint label.
    """
    if len(text) > 0:
        self.infolbl.setText(self.searchresultstree.toolTip())
        if text in util.CTXSHOTCUTS:
            self.ctxsearcher(text)
        elif len(text) > 1 and text not in util.CTXSHOTCUTS:
            self.searching = True
            self.searchtablepopulate(self.handler.searchtext(text))
    else:
        self.searching = False
        self.searchresultstree.clear()
        self.infolbl.setText(
            "Begin typing to search or click magnifying glass icon to display options")


def searchclick_cb(self, item, column):
    """Trigger the hotkey of an activated (double-clicked) tree item.

    Items with no assignment are given a temporary hotkey first (reverted
    on window close via removetemphotkey).
    """
    hk = item.text(2)
    self.tmpsymbol = item.text(3)

    if hk == "":
        self.chindex = hou.hotkeys.changeIndex()
        # FIX: idiomatic truthiness test instead of `result is True`.
        if self.createtemphotkey(self.tmpsymbol):
            self.chindex = hou.hotkeys.changeIndex()
            hk = hou.hotkeys.assignments(self.tmpsymbol)
            self.processkey(hk, True)
    else:
        hk = hou.hotkeys.assignments(self.tmpsymbol)
        self.processkey(hk)
        self.tmpsymbol = None
    return


def openmenu(self):
    """Show the search-prefix menu under the search box and apply the
    chosen prefix (":g", ":c" or ":v") to the search box."""
    self.menuopened = True
    self.searchmenu = QtWidgets.QMenu()
    self.searchmenu.setProperty('flat', True)
    self.searchmenu.setStyleSheet(util.MENUSTYLE)
    self.searchmenu.setWindowFlags(
        self.searchmenu.windowFlags() |
        QtCore.Qt.NoDropShadowWindowHint
    )
    self.globalprefix = self.searchmenu.addAction("Global items")
    self.contextprefix = self.searchmenu.addAction("Context items")
    self.viewprefix = self.searchmenu.addAction("View items")

    self.globalprefix.setToolTip(
        "View application-wide actions")
    self.contextprefix.setToolTip(
        "Shows possible actions for the view in which the mouse was in when the window was opened")
    self.viewprefix.setToolTip(
        "Shows the available view panes (ex. Scene View, Render View, Composit View, etc")

    self.searchmenu.hovered.connect(self.handlemenuhovered)

    # exec_ blocks until an action is picked or the menu is dismissed.
    self.action = self.searchmenu.exec_(
        self.searchbox.mapToGlobal(QtCore.QPoint(0, 20)))
    if self.action == self.globalprefix:
        self.searchbox.setText(":g")
    if self.action == self.contextprefix:
        self.searchbox.setText(":c")
    if self.action == self.viewprefix:
        self.searchbox.setText(":v")

    self.searchmenu.installEventFilter(self)


def handlemenuhovered(self, action):
    """Mirror the hovered menu action's tooltip into the info label."""
    self.infolbl.setText(action.toolTip())


def getContext(self, ctx):
    """Return the Searcher context string for *ctx* (a Houdini node path).

    Returns None when the child-type category cannot be resolved.
    """
    try:
        hou_context = ctx.pwd().childTypeCategory().name()
    # FIX: narrowed the bare `except:`; only attribute resolution on the
    # pwd()/childTypeCategory() chain is expected to fail here.
    except AttributeError:
        return None

    print("Hou Context: ", hou_context)
    return util.CONTEXTTYPE[hou_context]


def searchtablepopulate(self, data):
    """Rebuild the results tree from *data* rows.

    Each row is (label, description, assignments, symbol, context); rows
    are grouped under one top-level item per distinct context. In debug
    mode the context is shown as a fifth column.
    """
    # FIX: also tolerates data=None (previously len(None) raised).
    if not data:
        return
    self.searchresultstree.clear()
    hotkeys = []
    context_list = []
    hcontext_tli = {}

    # FIX: original if/else had identical branches; a plain membership
    # test collects the distinct contexts.
    for row in data:
        if row[4] not in context_list:
            context_list.append(row[4])

    result = self.handler.gethcontextod(context_list)

    for hc in result:
        item = QtWidgets.QTreeWidgetItem(
            self.searchresultstree, [hc[0], hc[1]])
        hcontext_tli[hc[2]] = item
        self.searchresultstree.expandItem(item)

    # FIX: dict.keys() is a view on Python 3 and cannot be indexed
    # (base_keys[j] raised TypeError); materialize it once.
    base_keys = list(hcontext_tli.keys())
    for row in data:
        for key in base_keys:
            if key in row[4]:
                columns = [row[0], row[1], row[2], row[3]]
                if self.isdebug:
                    columns.append(row[4])
                hotkeys.append(QtWidgets.QTreeWidgetItem(
                    hcontext_tli[key], columns))


def processkey(self, key, tmphk=False):
    """Replay hotkey assignment *key* as a synthetic QKeyEvent on the
    main Houdini window, then close the Searcher.

    Args:
        key: sequence whose first element is an assignment string
             such as "Ctrl+Shift+G".
        tmphk: True when the assignment is a temporary one that must be
               recorded for later revert.
    """
    hk = key
    if tmphk:
        lastkey = (str(self.tmpsymbol) + " " + str(hk[0]))
        self.handler.updatelasthk(lastkey)

    key = key[0].split('+')

    skey = None
    ikey = None
    key = keyconversion(key)
    modifiers = util.MODIFIERS
    mod_flag = QtCore.Qt.KeyboardModifiers()
    # Fold modifier parts into mod_flag; the remaining part is the key itself.
    for part in key:
        if str(part) in modifiers:
            mod_flag = mod_flag | modifiers[str(part)]
        else:
            skey = part
            ikey = util.KEY_DICT[str(part)]

    keypress = QtGui.QKeyEvent(
        QtGui.QKeyEvent.KeyPress,  # Keypress event identifier
        ikey,                      # Qt key identifier
        mod_flag,                  # Qt key modifiers
        skey                       # String of Qt key identifier
    )

    hou.ui.mainQtWindow().setFocus()
    try:
        hd.executeDeferred(self.app.sendEvent,
                           hou.ui.mainQtWindow(), keypress)
        self.close()
    except (AttributeError, TypeError) as e:
        hou.ui.setStatusMessage(
            ("Could not trigger hotkey event: " + str(e)),
            severity=hou.severityType.Warning
        )


def setKeysChanged(self, changed):
    """Persist hotkey overrides when a pending change is finalized and
    record the new hotkey change index in the database."""
    if self.keys_changed and not changed:
        if not hou.hotkeys.saveOverrides():
            print("ERROR: Couldn't save hotkey override file.")
    self.keys_changed = changed
    self.chindex = hou.hotkeys.changeIndex()
    self.handler.updatechangeindex(self.chindex)
# NOTE(review): the first four defs are Searcher methods (class header lives
# outside this hunk); HelpButton, center and CreateSearcherPanel are
# module-level. Reconstructed from a whitespace-mangled diff.

def createtemphotkey(self, symbol):
    """Back up the hotkey tables, then bind the default temporary key
    (self.tmpkey) to *symbol*.

    Returns the success flag from hou.hotkeys.addAssignment.
    """
    hou.hotkeys._createBackupTables()
    result = hou.hotkeys.addAssignment(symbol, self.tmpkey)
    self.keys_changed = True
    self.setKeysChanged(False)
    return result


def removetemphotkey(self, symbol, tmpkey):
    """Restore the backed-up hotkey tables and revert *symbol* to its
    default assignments (undo of createtemphotkey)."""
    hou.hotkeys._restoreBackupTables()
    hou.hotkeys.revertToDefaults(symbol, True)
    self.keys_changed = True
    self.setKeysChanged(False)


def checktooltip(self, obj):
    """Mirror *obj*'s tooltip into the info label; while not searching,
    the results tree shows the search box's hint instead of its own."""
    if obj == self.searchresultstree:
        if self.searching:
            self.infolbl.setText(obj.toolTip())
        else:
            self.infolbl.setText(self.searchbox.toolTip())
    else:
        self.infolbl.setText(obj.toolTip())


def eventFilter(self, obj, event):
    """Central event filter: tooltips, frameless-window dragging, Tab /
    Escape / ':' keyboard handling, focus bookkeeping, and save/cleanup
    on close."""
    # ---------------------------------------------------- Mouse
    if event.type() == QtCore.QEvent.Enter:
        self.checktooltip(obj)
    elif event.type() == QtCore.QEvent.Leave:
        self.infolbl.setText("")
    if event.type() == QtCore.QEvent.ToolTip:
        return True  # suppress popup tooltips; the info label shows them

    if event.type() == QtCore.QEvent.MouseButtonPress:
        if obj == self.searchbox:
            return QtCore.QObject.eventFilter(self, obj, event)
        else:
            self.previous_pos = event.globalPos()
            return QtCore.QObject.eventFilter(self, obj, event)

    if event.type() == QtCore.QEvent.MouseMove:
        # FIX: previous_pos is initialized to None; a move event arriving
        # before any press would have raised TypeError on the subtraction.
        if obj == self and self.previous_pos is not None:
            delta = event.globalPos() - self.previous_pos
            self.move(self.x() + delta.x(), self.y() + delta.y())
            self.previous_pos = event.globalPos()
            self._drag_active = True
        else:
            return QtCore.QObject.eventFilter(self, obj, event)

    if event.type() == QtCore.QEvent.MouseButtonRelease:
        if self._drag_active:
            self._drag_active = False

    # ------------------------------------------------- Keypress
    if event.type() == QtCore.QEvent.KeyPress:
        if event.key() == QtCore.Qt.Key_Tab:
            if self.searching:
                self.searchbox.releaseKeyboard()
                self.searchbox.clearFocus()
                self.searchresultstree.setFocus()
                top = self.searchresultstree.topLevelItem(0)
                # FIX: guard against an empty tree (topLevelItem(0) is None).
                if top is not None:
                    self.searchresultstree.setCurrentItem(top.child(0))
                return True
            else:
                if self.menuopened:
                    self.searchmenu.setFocus()
                else:
                    self.searchbox.setText(":c")
                    self.ctxsearcher()
                    self.searchresultstree.setFocus()
                    top = self.searchresultstree.topLevelItem(0)
                    if top is not None:
                        self.searchresultstree.setCurrentItem(top.child(0))
                return True
        if event.key() == QtCore.Qt.Key_Escape:
            if self.ui.isVisible():
                pass  # settings window open: Escape is left to it
            else:
                if self.menuopened:
                    if self.searchmenu.isVisible():
                        self.searchmenu.setVisible(False)
                        return QtCore.QObject.eventFilter(self, obj, event)
                    else:
                        self.menuopened = False
                else:
                    self.close()
        if event.key() == QtCore.Qt.Key_Colon:
            # ':' in an empty box opens the prefix menu directly.
            if self.searchbox.text() == "":
                self.searchbox.releaseKeyboard()
                self.searchbox.clearFocus()
                self.openmenu()
                return True

    # ------------------------------------------------- Window
    if event.type() == QtCore.QEvent.WindowActivate:
        self.searchbox.grabKeyboard()
    elif event.type() == QtCore.QEvent.WindowDeactivate:
        if self.ui.isVisible():
            self.searchbox.releaseKeyboard()
            return QtCore.QObject.eventFilter(self, obj, event)
        if self.windowispin:
            return QtCore.QObject.eventFilter(self, obj, event)
        else:
            self.close()  # unpinned window auto-closes on focus loss
    elif event.type() == QtCore.QEvent.FocusIn:
        if obj == self.window:
            self.searchbox.grabKeyboard()
    elif event.type() == QtCore.QEvent.FocusOut:
        pass

    # ------------------------------------------------- Close
    if event.type() == QtCore.QEvent.Close:
        try:
            # SETTINGS_KEYS[2] = savewindowsize
            if util.bc(self.settingdata[util.SETTINGS_KEYS[2]]):
                self.windowsettings.setValue(
                    "geometry",
                    self.saveGeometry()
                )
        except (AttributeError, TypeError) as e:
            if hou.isUIAvailable():
                hou.ui.setStatusMessage(
                    ("Could not save window dimensions: " + str(e)),
                    severity=hou.severityType.Warning
                )
            else:
                print("Could not save window dimensions: " + str(e))

        if self.menuopened:
            self.searchmenu.setVisible(False)

        # Revert any temporary hotkey created during this session.
        if self.tmpsymbol is not None:
            hd.executeDeferred(
                self.removetemphotkey,
                self.tmpsymbol,
                self.tmpkey
            )
        self.searchbox.releaseKeyboard()
        try:
            self.parent().setFocus()
            self.setParent(None)
            self.deleteLater()
        # FIX: was a bare except; still best-effort (parent() may be None
        # for a top-level window), but no longer hides KeyboardInterrupt.
        except Exception:
            self.setParent(None)
            self.deleteLater()
    return QtCore.QObject.eventFilter(self, obj, event)

# endregion


# region ----------------------------------------------------------------- Help
class HelpButton(QtWidgets.QPushButton):
    """Generic Help button opening /searcher/<name> in a help browser."""

    def __init__(self, name, parent=None):
        super(HelpButton, self).__init__(
            hou.qt.createIcon("BUTTONS_help"), "", parent=parent
        )
        self._name = name  # help page name under /searcher/
        self.setToolTip("Show Help.")
        self.setIconSize(QtCore.QSize(14, 14))
        self.setMaximumSize(QtCore.QSize(14, 14))
        self.setFlat(True)
        self.clicked.connect(self.display_help)

    def display_help(self):
        """Display help page, reusing an existing floating help browser
        when one is open."""
        for pane_tab in hou.ui.paneTabs():
            if isinstance(pane_tab, hou.HelpBrowser):
                if pane_tab.isFloating():
                    browser = pane_tab
                    break
        # for/else: no floating browser found, create one.
        else:
            desktop = hou.ui.curDesktop()
            browser = desktop.createFloatingPaneTab(
                hou.paneTabType.HelpBrowser)

        browser.displayHelpPath("/searcher/{}".format(self._name))
# endregion


# region ----------------------------------------------------------------- Setup Functions
def center():
    """Return the global-coordinate center of the parent widget."""
    return parent_widget.mapToGlobal(parent_widget.rect().center())


def CreateSearcherPanel(kwargs, searcher_window=None):
    """(Re)create the floating Searcher window.

    Closes any previous instance, restores saved geometry when enabled in
    the settings, otherwise sizes and centers the window over the parent.
    """
    try:
        searcher_window.close()
        searcher_window.deleteLater()
    # FIX: was a bare except; first launch passes searcher_window=None.
    except Exception:
        pass

    settings = searcher_data.loadsettings()
    windowsettings = QtCore.QSettings("instance.id", "Searcher")

    searcher_window = Searcher(kwargs, settings, windowsettings)
    searcher_window.setWindowFlags(
        QtCore.Qt.Tool |
        QtCore.Qt.CustomizeWindowHint |
        QtCore.Qt.WindowStaysOnTopHint
    )

    # SETTINGS_KEYS[2] = savewindowsize, SETTINGS_KEYS[3] = windowsize
    if util.bc(settings[util.SETTINGS_KEYS[2]]) and windowsettings.value("geometry") is not None:
        searcher_window.restoreGeometry(windowsettings.value("geometry"))
    else:
        searcher_window.resize(
            int(settings[util.SETTINGS_KEYS[3]][0]),
            int(settings[util.SETTINGS_KEYS[3]][1])
        )
        pos = center()
        # FIX: use floor division -- QWidget.setGeometry requires ints and
        # `/ 2` yields a float on Python 3.
        searcher_window.setGeometry(
            pos.x() - searcher_window.width() // 2,
            pos.y() - searcher_window.height() // 2,
            searcher_window.width(),
            searcher_window.height()
        )
    searcher_window.searchbox.setFocus()
    searcher_window.setWindowTitle('Searcher')
    searcher_window.show()
    searcher_window.activateWindow()
# endregion
-1,196 +0,0 @@ -# -*- coding: utf-8 -*- - -################################################################################ -# Form generated from reading UI file 'searcher_ui.ui' -## -# Created by: Qt User Interface Compiler version 5.14.1 -## -# WARNING! All changes made in this file will be lost when recompiling UI file! -################################################################################ - -# from PySide2.QtCore import (QCoreApplication, QMetaObject, QObject, QPoint, -# QRect, QSize, QUrl, Qt) -# from PySide2.QtGui import (QBrush, QColor, QConicalGradient, QCursor, QFont, -# QFontDatabase, QIcon, QLinearGradient, QPalette, QPainter, QPixmap, -# QRadialGradient) -# from PySide2.QtWidgets import * - -from hutil.Qt.QtCore import (QCoreApplication, QMetaObject, QObject, QPoint, - QRect, QSize, QUrl, Qt) -from hutil.Qt.QtGui import (QBrush, QColor, QConicalGradient, QCursor, QFont, - QFontDatabase, QIcon, QLinearGradient, QPalette, QPainter, QPixmap, - QRadialGradient) -from hutil.Qt.QtWidgets import * - - -class Ui_Searcher(object): - def setupUi(self, Searcher): - if Searcher.objectName(): - Searcher.setObjectName(u"Searcher") - Searcher.setWindowModality(Qt.WindowModal) - Searcher.resize(1000, 329) - sizePolicy = QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) - sizePolicy.setHorizontalStretch(0) - sizePolicy.setVerticalStretch(0) - sizePolicy.setHeightForWidth(Searcher.sizePolicy().hasHeightForWidth()) - Searcher.setSizePolicy(sizePolicy) - Searcher.setMinimumSize(QSize(0, 0)) - Searcher.setBaseSize(QSize(1000, 350)) - Searcher.setStyleSheet(u"QTreeWidget QHeaderView::section {\n" - " font-size: 9pt;\n" - "}") - self.gridLayout = QGridLayout(Searcher) - self.gridLayout.setSpacing(0) - self.gridLayout.setObjectName(u"gridLayout") - self.gridLayout.setContentsMargins(0, 0, 0, 0) - self.verticalLayout = QVBoxLayout() - self.verticalLayout.setSpacing(0) - self.verticalLayout.setObjectName(u"verticalLayout") - self.horizontalLayout = 
QHBoxLayout() - self.horizontalLayout.setSpacing(0) - self.horizontalLayout.setObjectName(u"horizontalLayout") - self.horizontalSpacer_2 = QSpacerItem( - 8, 2, QSizePolicy.Fixed, QSizePolicy.Minimum) - - self.horizontalLayout.addItem(self.horizontalSpacer_2) - - self.projectTitle = QLabel(Searcher) - self.projectTitle.setObjectName(u"projectTitle") - sizePolicy1 = QSizePolicy(QSizePolicy.Fixed, QSizePolicy.Preferred) - sizePolicy1.setHorizontalStretch(0) - sizePolicy1.setVerticalStretch(0) - sizePolicy1.setHeightForWidth( - self.projectTitle.sizePolicy().hasHeightForWidth()) - self.projectTitle.setSizePolicy(sizePolicy1) - font = QFont() - font.setPointSize(15) - self.projectTitle.setFont(font) - self.projectTitle.setAlignment(Qt.AlignCenter) - - self.horizontalLayout.addWidget(self.projectTitle) - - self.horizontalSpacer = QSpacerItem( - 40, 5, QSizePolicy.Expanding, QSizePolicy.Minimum) - - self.horizontalLayout.addItem(self.horizontalSpacer) - - self.pinwindow_btn = QToolButton(Searcher) - self.pinwindow_btn.setObjectName(u"pinwindow_btn") - - self.horizontalLayout.addWidget(self.pinwindow_btn) - - self.opensettings_btn = QToolButton(Searcher) - self.opensettings_btn.setObjectName(u"opensettings_btn") - - self.horizontalLayout.addWidget(self.opensettings_btn) - - self.horizontalSpacer_3 = QSpacerItem( - 8, 2, QSizePolicy.Fixed, QSizePolicy.Minimum) - - self.horizontalLayout.addItem(self.horizontalSpacer_3) - - self.verticalLayout.addLayout(self.horizontalLayout) - - self.horizontalLayout_3 = QHBoxLayout() - self.horizontalLayout_3.setSpacing(0) - self.horizontalLayout_3.setObjectName(u"horizontalLayout_3") - self.frame = QFrame(Searcher) - self.frame.setObjectName(u"frame") - sizePolicy2 = QSizePolicy(QSizePolicy.Preferred, QSizePolicy.Preferred) - sizePolicy2.setHorizontalStretch(2) - sizePolicy2.setVerticalStretch(0) - sizePolicy2.setHeightForWidth( - self.frame.sizePolicy().hasHeightForWidth()) - self.frame.setSizePolicy(sizePolicy2) - 
self.frame.setMinimumSize(QSize(0, 20)) - self.frame.setFrameShape(QFrame.StyledPanel) - self.frame.setFrameShadow(QFrame.Raised) - self.searchfilter_btn = QToolButton(self.frame) - self.searchfilter_btn.setObjectName(u"searchfilter_btn") - self.searchfilter_btn.setGeometry(QRect(0, 0, 26, 20)) - self.searchfilter_btn.setBaseSize(QSize(16, 16)) - self.searchfilter_btn.setStyleSheet( - u"background-color: rgb(19, 19, 19);") - self.searchfilter_btn.setArrowType(Qt.NoArrow) - - self.horizontalLayout_3.addWidget(self.frame) - - self.searchbox_txt = QLineEdit(Searcher) - self.searchbox_txt.setObjectName(u"searchbox_txt") - sizePolicy3 = QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Minimum) - sizePolicy3.setHorizontalStretch(99) - sizePolicy3.setVerticalStretch(0) - sizePolicy3.setHeightForWidth( - self.searchbox_txt.sizePolicy().hasHeightForWidth()) - self.searchbox_txt.setSizePolicy(sizePolicy3) - self.searchbox_txt.setMinimumSize(QSize(50, 0)) - self.searchbox_txt.setMouseTracking(False) - self.searchbox_txt.setStyleSheet(u"background-color: rgb(19, 19, 19);") - self.searchbox_txt.setFrame(False) - - self.horizontalLayout_3.addWidget(self.searchbox_txt) - - self.verticalLayout.addLayout(self.horizontalLayout_3) - - self.searchresults_tree = QTreeWidget(Searcher) - __qtreewidgetitem = QTreeWidgetItem() - __qtreewidgetitem.setText(0, u"1") - self.searchresults_tree.setHeaderItem(__qtreewidgetitem) - self.searchresults_tree.setObjectName(u"searchresults_tree") - sizePolicy4 = QSizePolicy(QSizePolicy.Preferred, QSizePolicy.Expanding) - sizePolicy4.setHorizontalStretch(0) - sizePolicy4.setVerticalStretch(0) - sizePolicy4.setHeightForWidth( - self.searchresults_tree.sizePolicy().hasHeightForWidth()) - self.searchresults_tree.setSizePolicy(sizePolicy4) - font1 = QFont() - font1.setPointSize(9) - self.searchresults_tree.setFont(font1) - self.searchresults_tree.setMouseTracking(False) - self.searchresults_tree.setFocusPolicy(Qt.NoFocus) - 
self.searchresults_tree.setFrameShadow(QFrame.Sunken) - self.searchresults_tree.setLineWidth(0) - self.searchresults_tree.setSizeAdjustPolicy( - QAbstractScrollArea.AdjustToContents) - self.searchresults_tree.setAlternatingRowColors(True) - self.searchresults_tree.setSelectionMode( - QAbstractItemView.SingleSelection) - self.searchresults_tree.setSelectionBehavior( - QAbstractItemView.SelectRows) - - self.verticalLayout.addWidget(self.searchresults_tree) - - self.gridLayout.addLayout(self.verticalLayout, 1, 0, 1, 1) - - self.info_lbl = QLabel(Searcher) - self.info_lbl.setObjectName(u"info_lbl") - font2 = QFont() - font2.setPointSize(8) - font2.setBold(False) - font2.setWeight(50) - self.info_lbl.setFont(font2) - self.info_lbl.setStyleSheet(u"background-color: rgb(26, 26, 26);") - self.info_lbl.setMargin(2) - self.info_lbl.setIndent(5) - - self.gridLayout.addWidget(self.info_lbl, 2, 0, 1, 1) - - self.retranslateUi(Searcher) - - QMetaObject.connectSlotsByName(Searcher) - # setupUi - - def retranslateUi(self, Searcher): - Searcher.setWindowTitle(QCoreApplication.translate( - "Searcher", u"Searcher", None)) - self.projectTitle.setText( - QCoreApplication.translate("Searcher", u"Searcher", None)) - self.pinwindow_btn.setText( - QCoreApplication.translate("Searcher", u"...", None)) - self.opensettings_btn.setText( - QCoreApplication.translate("Searcher", u"...", None)) - self.searchfilter_btn.setText( - QCoreApplication.translate("Searcher", u"...", None)) - self.info_lbl.setText(QCoreApplication.translate( - "Searcher", u"Begin typing to search or click magnifying glass icon to display options", None)) - # retranslateUi diff --git a/scripts/python/searcher/about.py b/scripts/python/searcher/about.py new file mode 100644 index 0000000..784a620 --- /dev/null +++ b/scripts/python/searcher/about.py @@ -0,0 +1,34 @@ +from __future__ import absolute_import +from searcher import about_ui +from searcher import util +import os +import sys + +import hou +hver = 0 +if 
# NOTE(review): this span covers three new modules added by the patch
# (tail of about.py, about_ui.py, head of animator.py); reconstructed from
# whitespace-mangled diff text. The guarding `if` keyword falls on the
# previous physical line of the patch.

if os.environ["HFS"] != "":
    ver = os.environ["HFS"]
    # Build number = digits after the last '.' of $HFS.
    hver = int(ver[ver.rindex('.') + 1:])
    from hutil.Qt import QtGui
    from hutil.Qt import QtCore
    from hutil.Qt import QtWidgets
    # UiTools module name varies by Houdini build.
    if hver >= 395:
        from hutil.Qt import QtUiTools
    elif 391 <= hver <= 394:
        from hutil.Qt import _QtUiTools
    elif 348 <= hver < 391:
        from hutil.Qt import QtUiTools

scriptpath = os.path.dirname(os.path.realpath(__file__))


class About(QtWidgets.QWidget):
    """Searcher About window (logo plus project links)."""
    # FIX: docstring was copy-pasted from the settings widget
    # ("Searcher Settings and Debug Menu"), which this class is not.

    def __init__(self, parent=None):
        super(About, self).__init__(parent=parent)
        self.setParent(parent)
        self.ui = about_ui.Ui_About()
        self.ui.setupUi(self)
        self.ui.retranslateUi(self)


class Ui_About(object):
    """Hand-written Qt form for the About window."""

    def setupUi(self, About):
        About.setObjectName("About")
        About.setWindowModality(QtCore.Qt.NonModal)
        About.resize(185, 251)
        sizePolicy = QtWidgets.QSizePolicy(
            QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.Preferred)
        sizePolicy.setHorizontalStretch(0)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(About.sizePolicy().hasHeightForWidth())
        About.setSizePolicy(sizePolicy)
        About.setMinimumSize(QtCore.QSize(100, 0))
        About.setBaseSize(QtCore.QSize(0, 0))
        About.setStyleSheet("")
        self.gridLayout = QtWidgets.QGridLayout(About)
        self.gridLayout.setContentsMargins(-1, -1, -1, 6)
        self.gridLayout.setSpacing(0)
        self.gridLayout.setObjectName("gridLayout")
        self.verticalLayout_4 = QtWidgets.QVBoxLayout()
        self.verticalLayout_4.setObjectName("verticalLayout_4")

        # ------------------------------------------------------ logo
        self.logo = QtWidgets.QLabel(About)
        sizePolicy = QtWidgets.QSizePolicy(
            QtWidgets.QSizePolicy.Fixed, QtWidgets.QSizePolicy.Fixed)
        sizePolicy.setHorizontalStretch(0)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(
            self.logo.sizePolicy().hasHeightForWidth())
        self.logo.setSizePolicy(sizePolicy)
        self.logo.setMaximumSize(QtCore.QSize(170, 170))
        self.logo.setText("")
        self.logo.setPixmap(QtGui.QPixmap(scriptpath + "/images/logo.png"))
        self.logo.setScaledContents(True)
        self.logo.setObjectName("logo")
        self.verticalLayout_4.addWidget(self.logo)

        # ------------------------------------------------- headerrow
        self.headerrow = QtWidgets.QHBoxLayout()
        self.headerrow.setObjectName("headerrow")
        self.github = LinkLabel(About, 'github/instance-id')
        self.github.setObjectName("github")
        self.headerrow.addWidget(self.github)
        self.verticalLayout_4.addLayout(self.headerrow)

        # ------------------------------------------------- secondrow
        self.secondrow = QtWidgets.QHBoxLayout()
        self.secondrow.setObjectName("secondrow")
        self.web = LinkLabel(About, "instance.id")
        self.web.setObjectName("web")
        self.secondrow.addWidget(self.web)
        self.verticalLayout_4.addLayout(self.secondrow)

        # -------------------------------------------------- thirdrow
        self.thirdrow = QtWidgets.QHBoxLayout()
        # FIX: objectName said "fifthrow" (copy/paste leftover).
        self.thirdrow.setObjectName("thirdrow")
        spacerItem2 = QtWidgets.QSpacerItem(
            40, 2, QtWidgets.QSizePolicy.Expanding,
            QtWidgets.QSizePolicy.Minimum)
        self.thirdrow.addItem(spacerItem2)
        spacerItem3 = QtWidgets.QSpacerItem(
            40, 2, QtWidgets.QSizePolicy.Expanding,
            QtWidgets.QSizePolicy.Minimum)
        self.thirdrow.addItem(spacerItem3)
        self.verticalLayout_4.addLayout(self.thirdrow)
        self.gridLayout.addLayout(self.verticalLayout_4, 0, 0, 1, 1)

        self.retranslateUi(About)
        QtCore.QMetaObject.connectSlotsByName(About)

    def retranslateUi(self, About):
        """Apply translated texts to the form's widgets."""
        _translate = QtCore.QCoreApplication.translate
        About.setWindowTitle(_translate("About", "Form"))
        self.github.setText(_translate("About", 'github/instance-id'))
        self.web.setText(_translate("About", 'instance.id'))


class LinkLabel(QtWidgets.QLabel):
    """QLabel preconfigured as a clickable rich-text external link."""

    def __init__(self, parent, text):
        super(LinkLabel, self).__init__(parent)
        self.setText(text)
        self.setTextFormat(QtCore.Qt.RichText)
        self.setTextInteractionFlags(QtCore.Qt.TextBrowserInteraction)
        self.setOpenExternalLinks(True)


class Animator(QtWidgets.QWidget):
    """Collapsible container that animates its min/max height (and the
    height of its scroll-area content) open and closed."""

    def __init__(self, parent=None, close_cb=None, animationDuration=200):
        super(Animator, self).__init__(parent)

        self.animationDuration = animationDuration

        self.toggleAnimation = QtCore.QParallelAnimationGroup()
        if close_cb is not None:
            # Invoked when the collapse/expand animation finishes.
            self.toggleAnimation.finished.connect(close_cb)

        self.contentArea = QtWidgets.QScrollArea(
            maximumHeight=0, minimumHeight=0, minimumWidth=500)
        # FIX: stylesheet said rgba(58 58, 58, 1) -- missing comma made the
        # color declaration invalid CSS, so the background was not applied.
        self.contentArea.setStyleSheet(
            "QScrollArea { background-color: rgba(58, 58, 58, 1); border: none;}")
        self.contentArea.setSizePolicy(
            QtWidgets.QSizePolicy.Expanding,
            QtWidgets.QSizePolicy.Fixed)

        toggleAnimation = self.toggleAnimation
        toggleAnimation.addAnimation(
            QtCore.QPropertyAnimation(self, b"minimumHeight"))
        toggleAnimation.addAnimation(
            QtCore.QPropertyAnimation(self, b"maximumHeight"))
        toggleAnimation.addAnimation(QtCore.QPropertyAnimation(
            self.contentArea, b"maximumHeight"))

        mainLayout = QtWidgets.QVBoxLayout(self)
        mainLayout.setSpacing(0)
        mainLayout.setContentsMargins(0, 0, 0, 0)
        mainLayout.addWidget(self.contentArea)
+ + def start_animation(self, checked): + direction = QtCore.QAbstractAnimation.Forward if checked else QtCore.QAbstractAnimation.Backward + self.toggleAnimation.setDirection(direction) + self.toggleAnimation.start() + + def setContentLayout(self, contentLayout): + # Not sure if this is equivalent to self.contentArea.destroy() + lay = self.contentArea.layout() + del lay + self.contentArea.setLayout(contentLayout) + collapsedHeight = self.sizeHint().height() - self.contentArea.maximumHeight() + + contentHeight = contentLayout.sizeHint().height() + for i in range(self.toggleAnimation.animationCount()-1): + expandAnimation = self.toggleAnimation.animationAt(i) + expandAnimation.setDuration(self.animationDuration) + expandAnimation.setStartValue(collapsedHeight) + expandAnimation.setEndValue(collapsedHeight + contentHeight) + + contentAnimation = self.toggleAnimation.animationAt( + self.toggleAnimation.animationCount() - 1) + contentAnimation.setDuration(self.animationDuration) + contentAnimation.setStartValue(0) + contentAnimation.setEndValue(contentHeight) diff --git a/scripts/python/searcher/bugreport.py b/scripts/python/searcher/bugreport.py new file mode 100644 index 0000000..2b1aea9 --- /dev/null +++ b/scripts/python/searcher/bugreport.py @@ -0,0 +1,34 @@ +from __future__ import absolute_import +from searcher import bugreport_ui +from searcher import util +import os +import sys + +import hou +hver = 0 +if os.environ["HFS"] != "": + ver = os.environ["HFS"] + hver = int(ver[ver.rindex('.')+1:]) + from hutil.Qt import QtGui + from hutil.Qt import QtCore + from hutil.Qt import QtWidgets + if hver >= 395: + from hutil.Qt import QtUiTools + elif hver <= 394 and hver >= 391: + from hutil.Qt import _QtUiTools + elif hver < 391 and hver >= 348: + from hutil.Qt import QtUiTools + +scriptpath = os.path.dirname(os.path.realpath(__file__)) + + +class BugReport(QtWidgets.QWidget): + """ Searcher Settings and Debug Menu""" + + def __init__(self, parent=None): + super(BugReport, 
self).__init__(parent=parent) + self.setParent(parent) + self.ui = bugreport_ui.Ui_BugReport() + self.ui.setupUi(self) + self.ui.retranslateUi(self) + diff --git a/scripts/python/searcher/bugreport_ui.py b/scripts/python/searcher/bugreport_ui.py new file mode 100644 index 0000000..0d98314 --- /dev/null +++ b/scripts/python/searcher/bugreport_ui.py @@ -0,0 +1,89 @@ +from hutil.Qt import QtCore, QtGui, QtWidgets +import os + +scriptpath = os.path.dirname(os.path.realpath(__file__)) + + +class Ui_BugReport(object): + def setupUi(self, BugReport): + BugReport.setObjectName("BugReport") + BugReport.setWindowModality(QtCore.Qt.NonModal) + BugReport.resize(450, 300) + sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.Preferred) + sizePolicy.setHorizontalStretch(0) + sizePolicy.setVerticalStretch(0) + sizePolicy.setHeightForWidth(BugReport.sizePolicy().hasHeightForWidth()) + BugReport.setSizePolicy(sizePolicy) + BugReport.setMinimumSize(QtCore.QSize(100, 0)) + BugReport.setBaseSize(QtCore.QSize(0, 0)) + BugReport.setStyleSheet("") + self.gridLayout = QtWidgets.QGridLayout(BugReport) + self.gridLayout.setContentsMargins(-1, -1, -1, 6) + self.gridLayout.setSpacing(0) + self.gridLayout.setObjectName("gridLayout") + self.verticalLayout_4 = QtWidgets.QVBoxLayout() + self.verticalLayout_4.setObjectName("verticalLayout_4") + + # ------------------------------------------------------ logo + # NOTE logo ------------------------------------------------- + self.logo = QtWidgets.QLabel(BugReport) + sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Fixed, QtWidgets.QSizePolicy.Fixed) + sizePolicy.setHorizontalStretch(0) + sizePolicy.setVerticalStretch(0) + sizePolicy.setHeightForWidth(self.logo.sizePolicy().hasHeightForWidth()) + self.logo.setSizePolicy(sizePolicy) + self.logo.setMaximumSize(QtCore.QSize(170, 170)) + self.logo.setText("") + self.logo.setPixmap(QtGui.QPixmap(scriptpath + "/images/logo.png")) + 
self.logo.setScaledContents(True) + self.logo.setObjectName("logo") + self.verticalLayout_4.addWidget(self.logo) + + # ------------------------------------------------- headerrow + # NOTE headerrow -------------------------------------------- + self.headerrow = QtWidgets.QHBoxLayout() + self.headerrow.setObjectName("headerrow") + + self.github = LinkLabel(BugReport, 'github/instance-id') + self.github.setObjectName("github") + self.headerrow.addWidget(self.github) + self.verticalLayout_4.addLayout(self.headerrow) + + # ------------------------------------------------- secondrow + # NOTE Second Row ------------------------------------------- + self.secondrow = QtWidgets.QHBoxLayout() + self.secondrow.setObjectName("secondrow") + + self.web = LinkLabel(BugReport, "instance.id") + self.web.setObjectName("web") + self.secondrow.addWidget(self.web) + self.verticalLayout_4.addLayout(self.secondrow) + + # -------------------------------------------------- thirdrow + # NOTE Third Row -------------------------------------------- + self.thirdrow = QtWidgets.QHBoxLayout() + self.thirdrow.setObjectName("fifthrow") + spacerItem2 = QtWidgets.QSpacerItem(40, 2, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum) + self.thirdrow.addItem(spacerItem2) + spacerItem3 = QtWidgets.QSpacerItem(40, 2, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum) + self.thirdrow.addItem(spacerItem3) + self.verticalLayout_4.addLayout(self.thirdrow) + self.gridLayout.addLayout(self.verticalLayout_4, 0, 0, 1, 1) + + self.retranslateUi(BugReport) + QtCore.QMetaObject.connectSlotsByName(BugReport) + + def retranslateUi(self, BugReport): + _translate = QtCore.QCoreApplication.translate + BugReport.setWindowTitle(_translate("BugReport", "Form")) + self.github.setText(_translate("BugReport", 'github/instance-id')) + self.web.setText(_translate("BugReport", 'instance.id')) + +class LinkLabel(QtWidgets.QLabel): + def __init__(self, parent, text): + super(LinkLabel, 
self).__init__(parent) + + self.setText(text) + self.setTextFormat(QtCore.Qt.RichText) + self.setTextInteractionFlags(QtCore.Qt.TextBrowserInteraction) + self.setOpenExternalLinks(True) \ No newline at end of file diff --git a/scripts/python/searcher/database-bak.py b/scripts/python/searcher/database-bak.py new file mode 100644 index 0000000..f4befc6 --- /dev/null +++ b/scripts/python/searcher/database-bak.py @@ -0,0 +1,261 @@ +# region Imports +from __future__ import print_function +from __future__ import absolute_import +import weakref + +import hou +import os + +from searcher import util + + +from peewee import * +from playhouse.sqlite_ext import SqliteExtDatabase, SearchField, FTSModel +import time + +scriptpath = os.path.dirname(os.path.realpath(__file__)) +db = SqliteExtDatabase(scriptpath + "/db/searcher.db") +cur = db.cursor() + + +class settings(Model): + id = IntegerField(unique=True) + indexvalue = IntegerField() + defaulthotkey = TextField() + searchdescription = IntegerField() + searchprefix = IntegerField() + searchcurrentcontext = IntegerField() + lastused = TextField() + + class Meta: + table_name = 'settings' + database = db + + +class hcontext(Model): + id = AutoField() + context = CharField(unique=True) + title = TextField() + description = TextField() + + class Meta: + table_name = 'hcontext' + database = db + + +class hotkeys(Model): + hotkey_symbol = CharField(unique=True) + label = TextField() + description = TextField() + assignments = TextField() + context = TextField() + + class Meta: + table_name = 'hotkeys' + database = db + + +class hotkeyindex(FTSModel): + description = SearchField() + label = SearchField() + + class Meta: + table_name = 'hotkeyindex' + database = db + options = {'tokenize': 'porter', + 'description': hotkeys.description} + + +db.create_tables([settings, hcontext, hotkeys]) + + +def py_unique(data): + return list(set(data)) + + +def getdata(): + rval = [] + contextdata = [] + hotkeydata = [] + + def getcontexts(r, 
context_symbol, root): + keys = None + branches = hou.hotkeys.contextsInContext(context_symbol) + for branch in branches: + branch_path = "%s/%s" % (r, branch['label']) + contextdata.append( + {'context': branch['symbol'], 'title': branch['label'], 'description': branch['help']}) + commands = hou.hotkeys.commandsInContext(branch['symbol']) + for command in commands: + keys = hou.hotkeys.assignments(command['symbol']) + ctx = command['symbol'].rsplit('.', 1) + hotkeydata.append( + {'hotkey_symbol': command['symbol'], 'label': command['label'], 'description': command['help'], + 'assignments': " ".join(keys), 'context': ctx[0]}) + getcontexts(branch_path, branch['symbol'], root) + + getcontexts("", "", rval) + return contextdata, hotkeydata + + +class Databases(object): + def __init__(self): + self.a = 1 + # self.settingdata = settings + # if self.settingdata[0]: + # db = SqliteExtDatabase(':memory:') + + # ---------------------------------------------------------- Retrieve + def getchangeindex(self): + try: + cur.execute("SELECT indexvalue FROM settings") + result = cur.fetchall() + return result + except(AttributeError, TypeError) as e: + hou.ui.setStatusMessage( + ("Could not get Searcher changeindex: " + str(e)), severity=hou.severityType.Error) + + def getdefhotkey(self): + try: + cur.execute("SELECT defaulthotkey FROM settings") + result = cur.fetchall() + return result + except(AttributeError, TypeError) as e: + hou.ui.setStatusMessage( + ("Could not get Searcher default hotkey: " + str(e)), severity=hou.severityType.Error) + + def gethcontexts(self): + try: + cur.execute("SELECT * FROM hcontext") + result = cur.fetchall() + return result + except(AttributeError, TypeError) as e: + hou.ui.setStatusMessage( + ("Could not get Searcher hcontext: " + str(e)), severity=hou.severityType.Error) + + def gethcontextod(self, inputlist): + try: + result = [] + query = (hcontext + .select() + .where(hcontext.context.in_(inputlist))).execute() + for hctx in query: + 
result.append((hctx.title, hctx.description, hctx.context)) + uniqueresult = py_unique(result) + return uniqueresult + except(AttributeError, TypeError) as e: + hou.ui.setStatusMessage( + ("Could not update Searcher context database: " + str(e)), severity=hou.severityType.Error) + + def ctxfilterresults(self, inputTerm): + try: + result = [] + query = (hotkeys + .select() + .where(hotkeys.context.in_(inputTerm))).execute() + for hctx in query: + result.append((hctx.label, hctx.description, + hctx.assignments, hctx.hotkey_symbol, hctx.context)) + uniqueresult = py_unique(result) + return uniqueresult + except(AttributeError, TypeError) as e: + hou.ui.setStatusMessage( + ("Could not get Searcher context results: " + str(e)), severity=hou.severityType.Error) + + def searchresults(self, inputTerm): + try: + cur.execute( + "SELECT label, description, assignments, hotkey_symbol, context FROM hotkeys WHERE label LIKE '%" + + str(inputTerm) + + "%' OR description LIKE '%" + + str(inputTerm) + + "%' LIMIT 25" + ) + result = cur.fetchall() + uniqueresult = py_unique(result) + return uniqueresult + except(AttributeError, TypeError) as e: + hou.ui.setStatusMessage( + ("Could not get Searcher results: " + str(e)), severity=hou.severityType.Error) + # endregion + + # ---------------------------------------------------------- Updates + + def updatechangeindex(self, indexval, new=False): + try: + if new is True: + defaultkey = (u"Ctrl+Alt+Shift+F7") + settings.insert(indexvalue=indexval, + defaulthotkey=defaultkey, searchdescription=0, searchprefix=0, searchcurrentcontext=0, id=1).execute() + else: + settings.update(indexvalue=indexval).where( + settings.id == 1).execute() + except(AttributeError, TypeError) as e: + hou.ui.setStatusMessage( + ("Could not update Searcher context database: " + str(e)), + severity=hou.severityType.Error + ) + + def updatetmphk(self, tmpkey): + try: + _ = settings.update( + defaulthotkey=tmpkey).where(id == 1).execute() + return + 
except(AttributeError, TypeError) as e: + hou.ui.setStatusMessage( + ("Could not update Searcher temp hotkey: " + str(e)), severity=hou.severityType.Error) + + def updatelastkey(self, lastkey): + try: + _ = settings.update( + lastused=lastkey).where(id == 1).execute() + return + except(AttributeError, TypeError) as e: + hou.ui.setStatusMessage( + ("Could not update Searcher temp hotkey: " + str(e)), severity=hou.severityType.Error) + + def updatecontext(self, debug=None): + try: + time1 = time.time() + self.cleardatabase() + ctxdata, hkeydata = getdata() + with db.atomic(): + for data_dict in ctxdata: + hcontext.replace_many(data_dict).execute() + with db.atomic(): + for idx in hkeydata: + hotkeys.replace_many(idx).execute() + time2 = time.time() + if debug: + hou.ui.setStatusMessage( + ('DB update took %0.3f ms' % + ((time2 - time1) * 1000.0)), severity=hou.severityType.Message) + print('DB update took %0.3f ms' % + ((time2 - time1) * 1000.0)) # TODO Remove this timer + except(AttributeError, TypeError) as e: + hou.ui.setStatusMessage( + ("Could not update Searcher context database: " + str(e)), severity=hou.severityType.Error) + + # with db.atomic(): + # for idx in range(0, len(ctxdata), 100): + # hcontext.replace_many(ctxdata[idx:idx+100]).execute() + # with db.atomic(): + # for idx in range(0, len(hkeydata), 100): + # hotkeys.replace_many(hkeydata[idx:idx+100]).execute() + + # endregion + + def cleardatabase(self): + try: + delhk = "DELETE FROM hotkeys" + delctx = "DELETE FROM hcontext" + cur.execute(delhk) + cur.execute(delctx) + result = cur.fetchall() + + return result + except(AttributeError, TypeError) as e: + hou.ui.setStatusMessage( + ("Could not update Searcher temp hotkey: " + str(e)), + severity=hou.severityType.Error + ) diff --git a/scripts/python/searcher/database.py b/scripts/python/searcher/database.py index 48de46a..ccf8b88 100644 --- a/scripts/python/searcher/database.py +++ b/scripts/python/searcher/database.py @@ -1,4 +1,3 @@ -# region 
Imports from __future__ import print_function from __future__ import absolute_import import weakref @@ -7,18 +6,25 @@ import os from searcher import util +from searcher import searcher_data +from searcher import ptime as ptime + from peewee import * -from playhouse.sqlite_ext import SqliteExtDatabase, SearchField, FTSModel +from playhouse.sqlite_ext import SqliteExtDatabase, RowIDField, FTS5Model, SearchField import time -scriptpath = os.path.dirname(os.path.realpath(__file__)) -db = SqliteExtDatabase(scriptpath + "/db/searcher.db") -cur = db.cursor() +def get_db(): + return getattr(hou.session, "DATABASE", None) +scriptpath = os.path.dirname(os.path.realpath(__file__)) +hou.session.DATABASE = DatabaseProxy() +db = get_db() -class settings(Model): +# -------------------------------------- DatabaseModels +# SECTION DatabaseModels ------------------------------ +class Settings(Model): id = IntegerField(unique=True) indexvalue = IntegerField() defaulthotkey = TextField() @@ -32,9 +38,9 @@ class Meta: database = db -class hcontext(Model): +class HContext(Model): id = AutoField() - context = CharField(unique=True) + context = TextField(unique=True) title = TextField() description = TextField() @@ -42,10 +48,19 @@ class Meta: table_name = 'hcontext' database = db +class HContextIndex(FTS5Model): + # rowid = RowIDField() + context = SearchField() + title = SearchField() + description = SearchField() + + class Meta: + database = db + options = {'prefix': [2, 3], 'tokenize': 'porter'} -class hotkeys(Model): +class Hotkeys(Model): hotkey_symbol = CharField(unique=True) - label = TextField() + label = CharField() description = TextField() assignments = TextField() context = TextField() @@ -55,24 +70,32 @@ class Meta: database = db -class hotkeyindex(FTSModel): - description = SearchField() +class HotkeysIndex(FTS5Model): + # rowid = RowIDField() + hotkey_symbol = SearchField(unindexed=True) label = SearchField() + description = SearchField() + assignments = 
SearchField(unindexed=True) + context = SearchField(unindexed=True) + + def clear_index(self): + HotkeysIndex.delete().where(HotkeysIndex.rowid == self.id).execute() class Meta: - table_name = 'hotkeyindex' + # table_name = 'hotkeysindex' database = db - options = {'tokenize': 'porter', - 'description': hotkeys.description} - - -db.create_tables([settings, hcontext, hotkeys]) - + options = {'prefix': [2, 3], 'tokenize': 'porter'} +#!SECTION +# --------------------------------------------------------- DatabaseModels +# SECTION DatabaseModels ------------------------------------------------- +# ----------------------------------------------- py_unique +# NOTE py_unique ------------------------------------------ def py_unique(data): return list(set(data)) - +# ------------------------------------------------- getdata +# NOTE getdata -------------------------------------------- def getdata(): rval = [] contextdata = [] @@ -84,178 +107,301 @@ def getcontexts(r, context_symbol, root): for branch in branches: branch_path = "%s/%s" % (r, branch['label']) contextdata.append( - {'context': branch['symbol'], 'title': branch['label'], 'description': branch['help']}) + {'context': branch['symbol'], + 'title': branch['label'], + 'description': branch['help']} + ) commands = hou.hotkeys.commandsInContext(branch['symbol']) for command in commands: keys = hou.hotkeys.assignments(command['symbol']) ctx = command['symbol'].rsplit('.', 1) hotkeydata.append( - {'hotkey_symbol': command['symbol'], 'label': command['label'], 'description': command['help'], - 'assignments': " ".join(keys), 'context': ctx[0]}) + {'hotkey_symbol': command['symbol'], + 'label': command['label'], + 'description': command['help'], + 'assignments': " ".join(keys), + 'context': ctx[0]} + ) getcontexts(branch_path, branch['symbol'], root) getcontexts("", "", rval) return contextdata, hotkeydata - +# !SECTION class Databases(object): def __init__(self): - self.a = 1 - # self.settingdata = settings - # if 
self.settingdata[0]: - # db = SqliteExtDatabase(':memory:') + self.settings = searcher_data.loadsettings() + self.isdebug = util.bc(self.settings[util.SETTINGS_KEYS[4]]) + inmemory = util.bc(self.settings[util.SETTINGS_KEYS[0]]) + if inmemory: + val = ':memory:' + else: + val = (scriptpath + "/db/searcher.db") + + db.initialize( + SqliteExtDatabase( + val, + pragmas=( + ("cache_size", -1024 * 64), + ("journal_mode", "wal"), + ("synchronous", 0) + ))) + + self.cur = db.cursor() + if inmemory: + db.create_tables([ + Settings, + HContext, + HContextIndex, + Hotkeys, + HotkeysIndex] + ) + self.initialsetup(self.cur) - # region --------------------------------------------------- Retrieve + self.a = 1 + self.isdebug = None + self.contexttime = 0 + self.hotkeystime = 0 + + # ----------------------------------------------------------- Retrieve + # SECTION Retrieve --------------------------------------------------- + # -------------------------------------- getchangeindex + # NOTE getchangeindex --------------------------------- def getchangeindex(self): try: - cur.execute("SELECT indexvalue FROM settings") - result = cur.fetchall() + self.cur.execute("SELECT indexvalue FROM settings") + result = self.cur.fetchall() return result except(AttributeError, TypeError) as e: - hou.ui.setStatusMessage( - ("Could not get Searcher changeindex: " + str(e)), severity=hou.severityType.Error) + hou.ui.setStatusMessage(("Could not get Searcher changeindex: " + str(e)), severity=hou.severityType.Error) - def getdefhotkey(self): + # ------------------------------------------- getlastusedhk + # NOTE getlastusedhk -------------------------------------- + def getlastusedhk(self, cur): try: - cur.execute("SELECT defaulthotkey FROM settings") + cur.execute("SELECT lastused FROM settings") result = cur.fetchall() + if str(result[0][0]) != "": + lasthk = str(result[0][0]).split(' ') + rmresult = hou.hotkeys.removeAssignment( + str(lasthk[0]).strip(), str(lasthk[1]).strip()) + if rmresult: + 
hkcheck = hou.hotkeys.assignments(str(lasthk[0])) +                        hou.hotkeys.saveOverrides() +                        if len(hkcheck) == 0: +                            Settings.update(lastused="").where(Settings.id == 1).execute() +                            currentidx = hou.hotkeys.changeIndex() +                            self.updatechangeindex(int(currentidx)) +                    else: +                        hou.hotkeys.clearAssignments(str(lasthk[0])) +                        hou.hotkeys.saveOverrides() +                        hkcheck = hou.hotkeys.assignments(str(lasthk[0])) +                        if len(hkcheck) == 0: +                            Settings.update(lastused="").where(Settings.id == 1).execute() +                            currentidx = hou.hotkeys.changeIndex() +                            self.updatechangeindex(int(currentidx)) +                        else: +                            if hou.isUIAvailable(): +                                hou.ui.setStatusMessage(("Could not clear last assigned temp hotkey on last attempt:"), severity=hou.severityType.Warning) +                            else: +                                print("Could not clear last assigned temp hotkey on last attempt:") +                else: +                    if hou.isUIAvailable(): +                        hou.ui.setStatusMessage(("Could not clear last assigned temp hotkey:"), severity=hou.severityType.Warning) +                    else: +                        print("Could not clear last assigned temp hotkey:") + +        except(AttributeError, TypeError) as e: +            if hou.isUIAvailable(): +                hou.ui.setStatusMessage(("Could not query last assigned temp hotkey:" + str(e)), severity=hou.severityType.Warning) +            else: +                print("Could not query last assigned temp hotkey: " + str(e)) + +    def getdefhotkey(self): +        try: +            self.cur.execute("SELECT defaulthotkey FROM settings") +            result = self.cur.fetchall() return result except(AttributeError, TypeError) as e: -            hou.ui.setStatusMessage( -                ("Could not get Searcher default hotkey: " + str(e)), severity=hou.severityType.Error) +            hou.ui.setStatusMessage(("Could not get Searcher default hotkey: " + str(e)), severity=hou.severityType.Error) def gethcontexts(self): try: -            cur.execute("SELECT * FROM hcontext") -            result = cur.fetchall() +            self.cur.execute("SELECT * FROM hcontext") +            result = self.cur.fetchall() return result except(AttributeError, TypeError) as e: -            hou.ui.setStatusMessage( -                ("Could not get Searcher hcontext: " + str(e)), severity=hou.severityType.Error) + 
hou.ui.setStatusMessage(("Could not get Searcher hcontext: " + str(e)), severity=hou.severityType.Error) def gethcontextod(self, inputlist): try: + time1 = ptime.time() result = [] - query = (hcontext + # query = (HContextIndex + # .select(HContextIndex) + # .where(HContextIndex.match(inputlist))) + query = (HContext .select() - .where(hcontext.context.in_(inputlist))).execute() + .where(HContext.context.in_(inputlist))).execute() for hctx in query: result.append((hctx.title, hctx.description, hctx.context)) uniqueresult = py_unique(result) - return uniqueresult + time2 = ptime.time() + self.contexttime = ((time2 - time1) * 1000.0) + return uniqueresult, self.contexttime except(AttributeError, TypeError) as e: - hou.ui.setStatusMessage( - ("Could not update Searcher context database: " + str(e)), severity=hou.severityType.Error) + hou.ui.setStatusMessage(("Could not update Searcher context database: " + str(e)), severity=hou.severityType.Error) + + # def gethcontextod(self, inputlist): + # try: + # result = [] + # query = (HContext + # .select() + # .where(HContext.context.in_(inputlist))).execute() + # for hctx in query: + # result.append((hctx.title, hctx.description, hctx.context)) + # uniqueresult = py_unique(result) + # return uniqueresult + # except(AttributeError, TypeError) as e: + # hou.ui.setStatusMessage(("Could not update Searcher context database: " + str(e)), severity=hou.severityType.Error) def ctxfilterresults(self, inputTerm): try: result = [] - query = (hotkeys + query = (Hotkeys .select() - .where(hotkeys.context.in_(inputTerm))).execute() + .where(Hotkeys.context.in_(inputTerm))).execute() for hctx in query: - result.append((hctx.label, hctx.description, - hctx.assignments, hctx.hotkey_symbol, hctx.context)) + result.append((hctx.label, hctx.description, hctx.assignments, hctx.hotkey_symbol, hctx.context)) uniqueresult = py_unique(result) return uniqueresult except(AttributeError, TypeError) as e: - hou.ui.setStatusMessage( - ("Could not get 
Searcher context results: " + str(e)), severity=hou.severityType.Error) - - def searchresults(self, inputTerm): + hou.ui.setStatusMessage(("Could not get Searcher context results: " + str(e)), severity=hou.severityType.Error) + + + def searchresults(self, inputTerm, debug, limit=0): + self.isdebug = debug try: - cur.execute( - "SELECT label, description, assignments, hotkey_symbol, context FROM hotkeys WHERE label LIKE '%" + time1 = ptime.time() + self.cur.execute( + "SELECT label, description, assignments, hotkey_symbol, context FROM hotkeysindex WHERE hotkeysindex MATCH '" + str(inputTerm) - + "%' OR description LIKE '%" - + str(inputTerm) - + "%'" + + "' ORDER BY rank" + + " LIMIT " + + str(limit) ) - result = cur.fetchall() + result = self.cur.fetchall() uniqueresult = py_unique(result) - return uniqueresult - except(AttributeError, TypeError) as e: - hou.ui.setStatusMessage( - ("Could not get Searcher results: " + str(e)), severity=hou.severityType.Error) - # endregion + + time2 = ptime.time() + self.hotkeystime = ((time2 - time1) * 1000.0) - # region --------------------------------------------------- Updates + return uniqueresult, self.hotkeystime + except(AttributeError, TypeError) as e: + hou.ui.setStatusMessage(("Could not get Searcher results: " + str(e)), severity=hou.severityType.Error) + # ---------------------------------------------------------- Updates def updatechangeindex(self, indexval, new=False): try: if new is True: - defaultkey = (u"Ctrl+Alt+Shift+F7") - settings.insert(indexvalue=indexval, - defaulthotkey=defaultkey, searchdescription=0, searchprefix=0, searchcurrentcontext=0, id=1).execute() + defaultkey = "" + for i in range(len(util.HOTKEYLIST)): + result = hou.hotkeys.findConflicts("h", util.HOTKEYLIST[i]) + if not result: + defaultkey = util.HOTKEYLIST[i] + + Settings.insert(indexvalue=indexval, + defaulthotkey=defaultkey, searchdescription=0, searchprefix=0, searchcurrentcontext=0, lastused="", id=1).execute() else: - 
settings.update(indexvalue=indexval).where( - settings.id == 1).execute() + Settings.update(indexvalue=indexval).where( + Settings.id == 1).execute() except(AttributeError, TypeError) as e: - hou.ui.setStatusMessage( - ("Could not update Searcher context database: " + str(e)), - severity=hou.severityType.Error - ) + if hou.isUIAvailable(): + hou.ui.setStatusMessage( + ("Could not update Searcher context database: " + str(e)), + severity=hou.severityType.Warning + ) + else: + print("Could not update Searcher context database: " + str(e)) def updatetmphk(self, tmpkey): try: - _ = settings.update( + _ = Settings.update( defaulthotkey=tmpkey).where(id == 1).execute() return except(AttributeError, TypeError) as e: - hou.ui.setStatusMessage( - ("Could not update Searcher temp hotkey: " + str(e)), severity=hou.severityType.Error) + hou.ui.setStatusMessage(("Could not update Searcher temp hotkey: " + str(e)), severity=hou.severityType.Error) def updatelastkey(self, lastkey): try: - _ = settings.update( - lastused=lastkey).where(id == 1).execute() + _ = Settings.update(lastused=lastkey).where(id == 1).execute() return except(AttributeError, TypeError) as e: - hou.ui.setStatusMessage( - ("Could not update Searcher temp hotkey: " + str(e)), severity=hou.severityType.Error) + hou.ui.setStatusMessage(("Could not update Searcher temp hotkey: " + str(e)), severity=hou.severityType.Error) - def updatecontext(self, debug=None): + def updatecontext(self, debug): + self.isdebug = debug try: - time1 = time.time() + time1 = ptime.time() self.cleardatabase() ctxdata, hkeydata = getdata() with db.atomic(): for data_dict in ctxdata: - hcontext.replace_many(data_dict).execute() + HContext.replace_many(data_dict).execute() with db.atomic(): for idx in hkeydata: - hotkeys.replace_many(idx).execute() - time2 = time.time() - if debug: - hou.ui.setStatusMessage( - ('DB update took %0.3f ms' % - ((time2 - time1) * 1000.0)), severity=hou.severityType.Message) - print('DB update took %0.3f ms' % - 
((time2 - time1) * 1000.0)) # TODO Remove this timer - except(AttributeError, TypeError) as e: - hou.ui.setStatusMessage( - ("Could not update Searcher context database: " + str(e)), severity=hou.severityType.Error) - - # with db.atomic(): - # for idx in range(0, len(ctxdata), 100): - # hcontext.replace_many(ctxdata[idx:idx+100]).execute() - # with db.atomic(): - # for idx in range(0, len(hkeydata), 100): - # hotkeys.replace_many(hkeydata[idx:idx+100]).execute() + Hotkeys.replace_many(idx).execute() + HotkeysIndex.replace_many(idx).execute() + time2 = ptime.time() + if self.isdebug and self.isdebug.level in {"TIMER", "ALL"}: + res = ((time2 - time1) * 1000.0) + if hou.isUIAvailable(): + hou.ui.setStatusMessage( + ('DB update took %0.4f ms' % res), severity=hou.severityType.Message) + else: + print('DB update took %0.4f ms' % res) # TODO Remove this timer + return res - # endregion + except(AttributeError, TypeError) as e: + hou.ui.setStatusMessage(("Could not update Searcher context database: " + str(e)), severity=hou.severityType.Error) def cleardatabase(self): try: delhk = "DELETE FROM hotkeys" delctx = "DELETE FROM hcontext" - cur.execute(delhk) - cur.execute(delctx) - result = cur.fetchall() + delhkindex = "DELETE FROM hotkeysindex" + delhcindex = "DELETE FROM hcontextindex" + self.cur.execute(delhk) + self.cur.execute(delctx) + result = self.cur.fetchall() return result except(AttributeError, TypeError) as e: - hou.ui.setStatusMessage( - ("Could not update Searcher temp hotkey: " + str(e)), - severity=hou.severityType.Error - ) + hou.ui.setStatusMessage(("Could not update Searcher temp hotkey: " + str(e)),severity=hou.severityType.Error) + + def initialsetup(self, cur): + currentidx = hou.hotkeys.changeIndex() + chindex = self.getchangeindex() + + if len(chindex) == 0: + chindex = int(currentidx) + self.updatechangeindex(chindex, True) + self.updatecontext(self.isdebug) + if hou.isUIAvailable(): + hou.ui.setStatusMessage( + "Searcher database created", 
severity=hou.severityType.Message) +            else: +                print("Searcher database created") +        else: +            chindex = int(chindex[0][0]) + +        if int(currentidx) != chindex: +            self.getlastusedhk(cur) +            self.updatecontext(self.isdebug) +            self.updatechangeindex(int(currentidx)) + +            if hou.isUIAvailable(): +                hou.ui.setStatusMessage( +                    "Searcher database created and populated", severity=hou.severityType.Message) + +     \ No newline at end of file diff --git a/scripts/python/searcher/datahandler.py b/scripts/python/searcher/datahandler.py index 892f9cf..d7d7938 100644 --- a/scripts/python/searcher/datahandler.py +++ b/scripts/python/searcher/datahandler.py @@ -1,10 +1,19 @@ +from __future__ import print_function +from __future__ import absolute_import + +import hou +import os + +from searcher import util +from searcher import searcher_data +from searcher import ptime as ptime +from searcher import database + import os import threading import hdefereval as hd -from . import database - -db = database.Databases() +reload(database) def worker(): @@ -15,58 +24,64 @@ class DataHandler(object): """Searcher data and communication handler""" def __init__(self, debug=None): +        self.db = database.Databases() self.isdebug = debug self.scriptpath = os.path.dirname(os.path.realpath(__file__)) -    # ----------------------------------------------------------------------------------- Function calls -    # ----------------------------------------------------- Retrieve +    # SECTION Function calls ------------------------------ Function calls +    # -------------------------------------------- Retrieve +    # NOTE Retrieve --------------------------------------- def getchangeindex(self): -        index = db.getchangeindex() +        index = self.db.getchangeindex() return index def getdefaulthotkey(self): -        index = db.getdefhotkey() +        index = self.db.getdefhotkey() return index -    # ----------------------------------------------------- Updates +    # --------------------------------------------- Updates +    # NOTE Updates 
---------------------------------------- def updatechangeindex(self, indexval, new=False): - db.updatechangeindex(indexval, new) + self.db.updatechangeindex(indexval, new) return - def updatedataasync(self): + def updatedataasync(self, debug): + self.isdebug = debug thread = threading.Thread(target=worker) thread.daemon = True thread.start() def updatedata(self): - db.updatecontext(self.isdebug) + self.db.updatecontext(self.isdebug) return def updatetmphotkey(self, tmpkey): - db.updatetmphk(tmpkey) + self.db.updatetmphk(tmpkey) return def updatelasthk(self, lastkey): - db.updatelastkey(lastkey) + self.db.updatelastkey(lastkey) return @staticmethod def gethcontext(): - results = db.gethcontexts() + results = self.db.gethcontexts() return results def gethcontextod(self, inputtext): - results = db.gethcontextod(inputtext) - return results + results, timer = self.db.gethcontextod(inputtext) + return results, timer def searchctx(self, txt): - results = db.ctxfilterresults(txt) + results = self.db.ctxfilterresults(txt) return results - def searchtext(self, txt): - results = db.searchresults(txt) - return results + def searchtext(self, txt, debug, limit=0): + self.isdebug = debug + results, timer = self.db.searchresults(txt, self.isdebug, limit) + return results, timer def cleardb(self): - results = db.cleardatabase() + results = self.db.cleardatabase() return results + # !SECTION \ No newline at end of file diff --git a/scripts/python/searcher/enum.py b/scripts/python/searcher/enum.py new file mode 100644 index 0000000..83d8868 --- /dev/null +++ b/scripts/python/searcher/enum.py @@ -0,0 +1,38 @@ +# -------------------------------------------- +# http://code.activestate.com/recipes/413486/ +# -------------------------------------------- +def Enum(*names): + ##assert names, "Empty enums are not supported" # <- Don't like empty enums? Uncomment! 
# SECTION Language US
# English (US) UI strings for Searcher: main-window tooltips plus the
# settings-dialog tooltip table. Other translations follow this module's
# shape (language code + ln_* constants + TT_SETTINGS dict).
language = "en"

# NOTE Tooltips
ln_searchbox = "Begin typing to search or click magnifying glass icon to display options"
ln_contexttoggle = "Toggle to enable or disable the 'context' column in the search results"
ln_pinwindow = "Pin the search window to keep it from closing automatically when losing focus"
ln_searchfilter = "Select a predefined filter"
ln_opensettingstool = "General application settings"
ln_searchresultstree = "Press tab to highlight or double click an action to attempt to perform it. Some actions only work in specific contexts"
ln_helppanel = "Open help panel"

# NOTE Tooltips Settings
# Keys are the objectNames of the widgets in the settings dialog.
# Typo fixes vs. the original: "im-memory" -> "in-memory",
# "it's database file" -> "its database file", "Disgard" -> "Discard".
TT_SETTINGS = {
    "about_btn": "Thanks for using Searcher!",
    "projectTitle": "Thanks for using Searcher!",
    "lang_cbox": "When translations become available they can be selected here",
    "inmemory_chk": "Enable to use an in-memory database instead of SQLite file",
    "windowsize_chk": "Enable to save the size and location of the main window upon closing. Defaults to center (1000px, 600px)",
    "maxresults_lbl": "Maximum results to load per query as you type your search term",
    "maxresults_txt": "Maximum results to load per query as you type your search term",
    "animatedsettings_chk": "Enables animated menus",
    "dbpath_icon": "",
    "dbpath_lbl": "The location in which Searcher stores its database file",
    "databasepath_txt": "The location in which Searcher stores its database file",
    "defaulthotkey_lbl": ("If left to the default value of (Ctrl+Alt+Shift+F7), "
                          "in the event that Searcher detects a conflict it will "
                          "automatically attempt to try different key combinations."),
    "defaulthotkey_txt": ("If left to the default value of (Ctrl+Alt+Shift+F7), "
                          "in the event that Searcher detects a conflict it will "
                          "automatically attempt to try different key combinations."),
    "cleardata_btn": "If, for some reason, Searcher is having issues this function will clear out the database and start fresh",
    "save_btn": "Save your settings",
    "discard_btn": "Discard settings changes",
    "debugflag_chk": "Toggle debug messages",
    "debuglevel_cbx": "Select level of debugging",
}
# !SECTION
import os
import sys
import platform
import ctypes

# Cached platform tag, filled lazily by get_platform().
PLATFORM = None
# Root of the Searcher install. Using .get() keeps the module importable
# when the SEARCHER env var is absent (original raised KeyError at import).
SEARCHER_PATH = os.environ.get("SEARCHER", "")

# All platforms currently ship the DLL at the same relative location.
_SQLITE_RELPATH = 'python27/dlls/sqlite3.dll'


def get_platform():
    """Return a cached tag for the current platform.

    Returns:
        str: one of 'darwin', 'windows' or 'unix' (all lowercase).
    """
    global PLATFORM

    if PLATFORM is None:
        p = sys.platform
        if p == 'darwin':
            PLATFORM = 'darwin'
        elif p.startswith('win'):
            PLATFORM = 'windows'
        else:
            PLATFORM = 'unix'

    return PLATFORM


def get_sqlite():
    """Locate the bundled sqlite3 library, loading it on Windows.

    BUG FIX: the original compared get_platform() against "Windows" /
    "Darwin", but get_platform() returns lowercase tags, so the Windows
    branch (the only one that actually loads the DLL) could never run.

    Returns:
        str: path to the bundled sqlite3 DLL (all platforms computed the
        same path; it is now returned instead of being discarded).
    """
    path_sqlite_dll = os.path.join(SEARCHER_PATH, _SQLITE_RELPATH)
    if get_platform() == "windows":
        # Pre-load the bundled sqlite3 so Python's sqlite module binds to it.
        ctypes.cdll.LoadLibrary(path_sqlite_dll)
    return path_sqlite_dll
+""" + + +import sys + +if sys.version_info[0] < 3: + from time import clock + from time import time as system_time +else: + from time import perf_counter as clock + from time import time as system_time + +START_TIME = None +time = None + +def winTime(): + """Return the current time in seconds with high precision (windows version, use Manager.time() to stay platform independent).""" + return clock() - START_TIME + #return systime.time() + +def unixTime(): + """Return the current time in seconds with high precision (unix version, use Manager.time() to stay platform independent).""" + return system_time() + +if sys.platform.startswith('win'): + cstart = clock() ### Required to start the clock in windows + START_TIME = system_time() - cstart + + time = winTime +else: + time = unixTime \ No newline at end of file diff --git a/scripts/python/searcher/scratch b/scripts/python/searcher/scratch index 31f6181..285b0dd 100644 --- a/scripts/python/searcher/scratch +++ b/scripts/python/searcher/scratch @@ -1,3 +1,5 @@ +# C:\Users\mosthated\AppData\Roaming\Python\Python37\Scripts\pyuic5.exe .\SearcherSettings.ui -o .\SearcherSettings.py + # panetab = None # for pane in hou.ui.floatingPaneTabs(): # if pane.type() == hou.paneTabType.PythonPanel: @@ -20,4 +22,47 @@ # from qtpy import QtGui # from qtpy import QtCore # from qtpy import QtWidgets -# endregion \ No newline at end of file +# endregion + + + + # ------------------------------------- checkforchanges + def checkforchanges(self): + for i in range(len(util.SETTINGS_KEYS)): + if util.SETTINGS_TYPES[util.SETTINGS_KEYS[i]] == "bool": + if self.isdebug and self.isdebug.level in {"ALL"}: + print("Get attribute: ", getattr(self, util.SETTINGS_KEYS[i])) + print("Get settings: ", bc(self.currentsettings[util.SETTINGS_KEYS[i]])) + if getattr(self, util.SETTINGS_KEYS[i]).isChecked() != bc(self.currentsettings[util.SETTINGS_KEYS[i]]): + if self.isdebug and self.isdebug.level in {"ALL"}: + print("{} value 
{}".format(util.SETTINGS_KEYS[i], getattr(self, util.SETTINGS_KEYS[i]).isChecked())) + print("{} value {}".format(util.SETTINGS_KEYS[i], bc(self.currentsettings[util.SETTINGS_KEYS[i]]))) + return True + elif util.SETTINGS_TYPES[util.SETTINGS_KEYS[i]] == "text": + if self.isdebug and self.isdebug.level in {"ALL"}: + print("Get attribute: ", getattr(self, util.SETTINGS_KEYS[i])) + print("Get settings: ", self.currentsettings[util.SETTINGS_KEYS[i]]) + if getattr(self, util.SETTINGS_KEYS[i]).text() != self.currentsettings[util.SETTINGS_KEYS[i]]: + if self.isdebug and self.isdebug.level in {"ALL"}: + print("{} value {}".format(util.SETTINGS_KEYS[i], getattr(self, util.SETTINGS_KEYS[i]).text())) + print("{} value {}".format(util.SETTINGS_KEYS[i], self.currentsettings[util.SETTINGS_KEYS[i]])) + return True + elif util.SETTINGS_TYPES[util.SETTINGS_KEYS[i]] == "intval": + if self.isdebug and self.isdebug.level in {"ALL"}: + print("Get attribute: ", getattr(self, util.SETTINGS_KEYS[i])) + print("Get settings: ", self.currentsettings[util.SETTINGS_KEYS[i]]) + if getattr(self, util.SETTINGS_KEYS[i]).value() != self.currentsettings[util.SETTINGS_KEYS[i]]: + if self.isdebug and self.isdebug.level in {"ALL"}: + print("{} value {}".format(util.SETTINGS_KEYS[i], getattr(self, util.SETTINGS_KEYS[i]).value())) + print("{} value {}".format(util.SETTINGS_KEYS[i], int(self.currentsettings[util.SETTINGS_KEYS[i]]))) + return True + elif util.SETTINGS_TYPES[util.SETTINGS_KEYS[i]] == "cbx": + if self.isdebug and self.isdebug.level in {"ALL"}: + print("Get attribute: ", getattr(self, util.SETTINGS_KEYS[i])) + print("Get settings: ", self.currentsettings[util.SETTINGS_KEYS[i]]) + if getattr(self, util.SETTINGS_KEYS[i]).currentText() != self.currentsettings[util.SETTINGS_KEYS[i]]: + if self.isdebug and self.isdebug.level in {"ALL"}: + print("{} value {}".format(util.SETTINGS_KEYS[i], getattr(self, util.SETTINGS_KEYS[i]).currentText())) + print("{} value {}".format(util.SETTINGS_KEYS[i], 
str(self.currentsettings[util.SETTINGS_KEYS[i]]))) + return True + return False \ No newline at end of file diff --git a/scripts/python/searcher/searcher.py b/scripts/python/searcher/searcher.py index 05c51c7..6f93100 100644 --- a/scripts/python/searcher/searcher.py +++ b/scripts/python/searcher/searcher.py @@ -4,16 +4,21 @@ import weakref from searcher import util +from searcher import ptime +from searcher import animator from searcher import database +from searcher import HelpButton from searcher import datahandler from searcher import searcher_data from searcher import searcher_settings -from searcher import HelpButton +from searcher import language_en as la import hou import platform import os import sys +import re +from string import ascii_letters import hdefereval as hd hver = 0 if os.environ["HFS"] != "": @@ -27,11 +32,14 @@ reload(searcher_data) reload(datahandler) reload(HelpButton) +reload(animator) reload(database) +reload(ptime) reload(util) +reload(la) # endregion -# region ------------------------------------------------------------- App Info +# -------------------------------------------------------------------- App Info __package__ = "Searcher" __version__ = "0.1b" __author__ = "instance.id" @@ -39,11 +47,10 @@ __status__ = "Prototype" # endregion -# region ------------------------------------------------------------- Variables / Constants +# -------------------------------------------------------------------- Variables / Constants kwargs = {} settings = {} hasran = False -isdebug = False mousePos = None cur_screen = QtWidgets.QDesktopWidget().screenNumber( QtWidgets.QDesktopWidget().cursor().pos() @@ -59,7 +66,7 @@ searcher_window = QtWidgets.QMainWindow() # endregion -# region ------------------------------------------------------------- Class Functions +# -------------------------------------------------------------------- Class Functions def keyconversion(key): @@ -69,7 +76,7 @@ def keyconversion(key): return key # endregion -# region 
------------------------------------------------------------- Searcher Class +# -------------------------------------------------------------------- Searcher Class class Searcher(QtWidgets.QWidget): @@ -79,25 +86,46 @@ class Searcher(QtWidgets.QWidget): def __init__(self, kwargs, settings, windowsettings): super(Searcher, self).__init__(hou.qt.mainWindow()) self._drag_active = False + self.animationDuration = 200 + self.uiwidth = int(520) + self.uiheight = int(300) # Setting vars kwargs = kwargs self.settingdata = settings self.windowsettings = windowsettings - self.isdebug = util.bc(self.settingdata[util.SETTINGS_KEYS[4]]) + self.isdebug = util.Dbug( + util.bc(self.settingdata[util.SETTINGS_KEYS[4]]), + str(self.settingdata[util.SETTINGS_KEYS[10]]) + ) self.menuopened = False self.windowispin = util.bc(self.settingdata[util.SETTINGS_KEYS[5]]) + self.showctx = util.bc(self.settingdata[util.SETTINGS_KEYS[7]]) self.originalsize = self.settingdata[util.SETTINGS_KEYS[3]] - - # if hver >= 391: + self.animatedsettings = util.bc( + self.settingdata[util.SETTINGS_KEYS[8]]) + self.mainlayout = QtWidgets.QVBoxLayout() + self.settingslayout = QtWidgets.QVBoxLayout() self.app = QtWidgets.QApplication.instance() # UI Vars self.handler, self.tmpkey = self.initialsetup() - self.ui = searcher_settings.SearcherSettings(self.handler, self.tmpkey) + self.ui = searcher_settings.SearcherSettings( + self.handler, + self.tmpkey, + self + ) + if self.animatedsettings: + self.anim = animator.Animator(self.ui, self.anim_complete) # Functional Vars + self.endtime = 0 self.lastused = {} + self.starttime = 0 + self.treecatnum = 0 + self.hotkeystime = 0 + self.hcontexttime = 0 + self.treeitemsnum = 0 self.tmpsymbol = None self.searching = False self.ctxsearch = False @@ -118,49 +146,100 @@ def __init__(self, kwargs, settings, windowsettings): self.pinwindow.installEventFilter(self) self.helpButton.installEventFilter(self) self.searchfilter.installEventFilter(self) + 
self.contexttoggle.installEventFilter(self) self.opensettingstool.installEventFilter(self) self.searchresultstree.installEventFilter(self) + + # ---------------------------------- Build Settings + # NOTE Build Settings ----------------------------- + self.buildsettingsmenu() + # !SECTION - # region ------------------------------------------------------------- Settings - def open_settings(self): - self.ui.setWindowTitle('Searcher - Settings') - self.ui.show() - self.ui.setFocus() - # endregion + # ------------------------------------------------------ Settings Menu + # SECTION Settings Menu ---------------------------------------------- + # ----------------------------------- buildsettingsmenu + # NOTE buildsettingsmenu ------------------------------ + def buildsettingsmenu(self): + self.ui.setWindowFlags( + QtCore.Qt.Tool | + QtCore.Qt.CustomizeWindowHint | + QtCore.Qt.FramelessWindowHint | + QtCore.Qt.WindowStaysOnTopHint + ) + self.ui.setAttribute(QtCore.Qt.WA_StyledBackground, True) + self.ui.setAttribute(QtCore.Qt.WA_TranslucentBackground, True) + self.ui.setStyleSheet("QWidget { background: rgb(58, 58, 58); }" + "QWidget#SearcherSettings { border: 2px solid rgb(35, 35, 35); } ") + + self.settingslayout = self.ui.settingslayout + if self.animatedsettings: + self.anim.setContentLayout(self.settingslayout) + self.anim.resize( + self.uiwidth, + self.uiheight + ) + self.ui.resize( + self.uiwidth, + self.uiheight + ) + + + # !SECTION - # SECTION uisetup - # region ------------------------------------------------------------- UI - def setupContext(self): + # ----------------------------------------------------------------- UI + # SECTION UI --------------------------------------------------------- + # ------------------------------------- setupresulttree + # NOTE setupresulttree -------------------------------- + def setupresulttree(self): cols = 4 self.searchresultstree.setColumnCount(cols) self.searchresultstree.setColumnWidth(0, 250) - if self.isdebug: + if 
self.isdebug and self.isdebug.level in {"ALL"}: self.searchresultstree.setColumnWidth(1, 350) else: self.searchresultstree.setColumnWidth(1, 450) self.searchresultstree.setColumnWidth(2, 100) self.searchresultstree.setColumnWidth(3, 150) - if self.isdebug: - self.searchresultstree.setColumnWidth(4, 150) - self.searchresultstree.setHeaderLabels([ - "Label", - "Description", - "Assignments", - "Symbol", - "Context" - ]) - else: - self.searchresultstree.setHeaderLabels([ - "Label", - "Description", - "Assignments", - "Symbol" - ]) - + self.searchresultstree.setColumnWidth(4, 150) + self.searchresultstree.setHeaderLabels([ + "Label", + "Description", + "Assignments", + "Symbol", + "Context" + ]) + self.searchresultstree.setColumnHidden(3, self.showctx) + if not self.isdebug.level in {"ALL"}: + self.searchresultstree.hideColumn(4) + + self.searchresultstree.header().setMinimumSectionSize(85) + self.searchresultstree.header().setSectionResizeMode( + 0, QtWidgets.QHeaderView.ResizeToContents) + self.searchresultstree.header().setSectionResizeMode( + 1, QtWidgets.QHeaderView.ResizeToContents) + self.searchresultstree.header().setSectionResizeMode( + 2, QtWidgets.QHeaderView.ResizeToContents) + self.searchresultstree.header().setSectionResizeMode( + 3, QtWidgets.QHeaderView.Stretch) + self.searchresultstree.header().setSectionResizeMode( + 4, QtWidgets.QHeaderView.ResizeToContents) + self.searchresultstree.setStyleSheet("""QHeaderView::section{ + color: rgb(200, 200, 200); + resize:both; + overflow:auto; + padding: 4px; + height:20px; + border: 0px solid rgb(150, 150, 150); + border-bottom: 1px solid rgb(150, 150, 150); + border-left:0px solid rgb(50, 50, 50); + border-right:1px solid rgb(60, 60, 60); + background: rgb(36, 36, 36); + }""") + + # -------------------------------------- Build Settings + # NOTE Build Settings --------------------------------- def uisetup(self): - self.main_widget = QtWidgets.QWidget(self) - names = ["open", "save", "hotkey", "perference"] 
self.completer = QtWidgets.QCompleter(names) @@ -171,8 +250,6 @@ def uisetup(self): self.verticalLayout = QtWidgets.QVBoxLayout() self.verticalLayout.setSpacing(0) - mainlayout = QtWidgets.QVBoxLayout() - self.titlerow = QtWidgets.QHBoxLayout() self.titlerow.setSpacing(5) @@ -194,7 +271,8 @@ def uisetup(self): QtWidgets.QSizePolicy.Minimum ) - self.helpButton = HelpButton.HelpButton("main") + self.contexttoggle = QtWidgets.QPushButton() + self.helpButton = HelpButton.HelpButton("main", la.ln_helppanel, 16) self.pinwindow_btn = QtWidgets.QToolButton() self.opensettings_btn = QtWidgets.QToolButton() @@ -204,8 +282,8 @@ def uisetup(self): QtWidgets.QSizePolicy.Minimum ) - # ------------------------------------------------------- Search Filter - # NOTE Search Filter -------------------------------------------------- + # ----------------------------------- Search Filter + # NOTE Search Filter ------------------------------ self.searchrow = QtWidgets.QHBoxLayout() self.searchrow.setSpacing(0) self.frame = QtWidgets.QFrame() @@ -228,8 +306,8 @@ def uisetup(self): u"background-color: rgb(19, 19, 19);") self.searchfilter_btn.setArrowType(QtCore.Qt.NoArrow) - # ---------------------------------------------------------- Search Box - # NOTE Search Box ----------------------------------------------------- + # -------------------------------------- Search Box + # NOTE Search Box --------------------------------- self.searchbox_txt = QtWidgets.QLineEdit() searchbox_details = QtWidgets.QSizePolicy( QtWidgets.QSizePolicy.Expanding, @@ -245,11 +323,14 @@ def uisetup(self): self.searchbox_txt.setStyleSheet(u"background-color: rgb(19, 19, 19);") self.searchbox_txt.setFrame(False) - # -------------------------------------------------------- Results Tree - # NOTE Results Tree --------------------------------------------------- + # ------------------------------------ Results Tree + # NOTE Results Tree ------------------------------- self.searchresults_tree = QtWidgets.QTreeWidget() 
__qtreewidgetitem = QtWidgets.QTreeWidgetItem() __qtreewidgetitem.setText(0, u"1") + resultstree_header = QtGui.QFont() + resultstree_header.setPointSize(9) + __qtreewidgetitem.setFont(0, resultstree_header) self.searchresults_tree.setHeaderItem(__qtreewidgetitem) resultstree_details = QtWidgets.QSizePolicy( QtWidgets.QSizePolicy.Preferred, @@ -257,8 +338,8 @@ def uisetup(self): ) resultstree_details.setHorizontalStretch(0) resultstree_details.setVerticalStretch(0) - resultstree_details.setHeightForWidth( - self.searchresults_tree.sizePolicy().hasHeightForWidth()) + # resultstree_details.setHeightForWidth( + # self.searchresults_tree.sizePolicy().hasHeightForWidth()) self.searchresults_tree.setSizePolicy(resultstree_details) resultstree_font = QtGui.QFont() resultstree_font.setPointSize(9) @@ -275,7 +356,12 @@ def uisetup(self): self.searchresults_tree.setSelectionBehavior( QtWidgets.QAbstractItemView.SelectRows) - # NOTE Info Panel --------------------------------------------------- Info Panel + # -------------------------------------- Info Panel + # NOTE Info Panel --------------------------------- + self.infobar = QtWidgets.QHBoxLayout() + self.infobar.setObjectName("infobar") + self.infobargrid = QtWidgets.QGridLayout() + self.infobargrid.setObjectName("infobargrid") self.info_lbl = QtWidgets.QLabel() self.infolbl_font = QtGui.QFont() self.infolbl_font.setPointSize(8) @@ -290,11 +376,24 @@ def uisetup(self): self.overlay.setStyleSheet(u"background-color: rgb(26, 26, 26);") self.overlay.setMargin(2) self.overlay.setIndent(5) - - # NOTE Layout Implementation ------------------------------------------ Layout Implementation + self.treetotal_lbl = QtWidgets.QLabel() + treetotal_size = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Maximum, QtWidgets.QSizePolicy.Preferred) + treetotal_size.setHorizontalStretch(0) + treetotal_size.setVerticalStretch(0) + treetotal_size.setHeightForWidth(self.treetotal_lbl.sizePolicy().hasHeightForWidth()) + 
self.treetotal_lbl.setSizePolicy(treetotal_size) + self.treetotal_lbl.setMinimumSize(QtCore.QSize(150, 0)) + self.treetotal_lbl.setMaximumSize(QtCore.QSize(150, 16777215)) + self.treetotal_lbl.setObjectName("treetotal_lbl") + self.treetotal_lbl.setStyleSheet(u"background-color: rgb(26, 26, 26);") + self.treetotal_lbl.setAlignment(QtCore.Qt.AlignRight|QtCore.Qt.AlignTrailing|QtCore.Qt.AlignVCenter) + + # --------------------------- Layout Implementation + # NOTE Layout Implementation ---------------------- self.titlerow.addItem(self.titlespacer1) self.titlerow.addWidget(self.searcherlbl) self.titlerow.addItem(self.titlespacer2) + self.titlerow.addWidget(self.contexttoggle) self.titlerow.addWidget(self.helpButton) self.titlerow.addWidget(self.pinwindow_btn) self.titlerow.addWidget(self.opensettings_btn) @@ -305,10 +404,15 @@ def uisetup(self): self.verticalLayout.addLayout(self.searchrow) self.verticalLayout.addWidget(self.searchresults_tree) self.gridLayout.addLayout(self.verticalLayout, 1, 0, 1, 1) - self.gridLayout.addWidget(self.overlay, 2, 0, 1, 1) - self.gridLayout.addWidget(self.info_lbl, 2, 0, 1, 1) + + self.infobargrid.addWidget(self.overlay, 1, 0, 1, 1) + self.infobargrid.addWidget(self.info_lbl, 1, 0, 1, 1) + self.infobargrid.addWidget(self.treetotal_lbl, 1, 1, 1, 1) + + self.infobar.addLayout(self.infobargrid) + self.gridLayout.addLayout(self.infobar, 3, 0, 1, 1) - # NOTE Layout to functionality connection ----------------------------- + # NOTE Layout to functionality connection --------- self.searchfilter = self.searchfilter_btn self.pinwindow = self.pinwindow_btn self.opensettingstool = self.opensettings_btn @@ -316,7 +420,8 @@ def uisetup(self): self.searchbox = self.searchbox_txt self.infolbl = self.info_lbl - # NOTE Settings and details ------------------------------------------- + # ---------------------------- Settings and details + # NOTE Settings and details ----------------------- self.searchbox.setPlaceholderText(" Begin typing to search..") 
self.searchbox.setFocusPolicy(QtCore.Qt.StrongFocus) self.searchbox.setContextMenuPolicy(QtCore.Qt.CustomContextMenu) @@ -331,6 +436,21 @@ def uisetup(self): searchfilter_button_size )) + self.setctxicon() + self.contexttoggle.setCheckable(True) + self.contexttoggle.setChecked(self.showctx) + self.contexttoggle.setFixedWidth(20) + self.contexttoggle.setFixedHeight(20) + contexttoggle_button_size = hou.ui.scaledSize(16) + self.contexttoggle.setProperty("flat", True) + self.contexttoggle.setIconSize(QtCore.QSize( + contexttoggle_button_size, + contexttoggle_button_size + )) + self.contexttoggle.clicked[bool].connect(self.showctx_cb) + self.contexttoggle.setStyleSheet("QPushButton { width: 8px; border: none; }" + "QPushButton:checked { width: 8px; border: none;}") + self.pinwindow.clicked.connect(self.pinwindow_cb) self.setpinicon() pinwindow_button_size = hou.ui.scaledSize(16) @@ -340,6 +460,8 @@ def uisetup(self): pinwindow_button_size )) + self.opensettingstool.setCheckable(True) + self.opensettingstool.setChecked(False) self.opensettingstool.clicked.connect(self.opensettings_cb) opensettingstool_button_size = hou.ui.scaledSize(16) self.opensettingstool.setProperty("flat", True) @@ -353,31 +475,38 @@ def uisetup(self): self.searchbox.customContextMenuRequested.connect(self.openmenu) self.searchresultstree.itemActivated.connect(self.searchclick_cb) - mainlayout.setAlignment(QtCore.Qt.AlignBottom) - mainlayout.setContentsMargins(0, 0, 0, 0) - mainlayout.setGeometry(QtCore.QRect(0, 0, 1400, 1200)) - - mainlayout.addLayout(self.gridLayout) - self.setLayout(mainlayout) - - self.searchbox.setToolTip( - 'Begin typing to search or click magnifying glass icon to display options') - self.pinwindow.setToolTip( - 'Pin the search window to keep it from closing automatically when losing focus') - self.searchfilter.setToolTip( - 'Select a predefined filter') - self.opensettingstool.setToolTip( - 'General application settings') - self.searchresultstree.setToolTip( - 'Double click 
an action to attempt to perform it. Some actions only work in specific contexts') - - self.setupContext() + self.mainlayout.setAlignment(QtCore.Qt.AlignBottom) + self.mainlayout.setContentsMargins(0, 0, 0, 0) + self.mainlayout.setGeometry(QtCore.QRect(0, 0, 1400, 1200)) + + self.mainlayout.addLayout(self.gridLayout) + self.setLayout(self.mainlayout) + + self.searchbox.setToolTip(la.ln_searchbox) + self.contexttoggle.setToolTip(la.ln_contexttoggle) + self.pinwindow.setToolTip(la.ln_pinwindow) + self.searchfilter.setToolTip(la.ln_searchfilter) + self.opensettingstool.setToolTip(la.ln_opensettingstool) + self.searchresultstree.setToolTip(la.ln_searchresultstree) + + self.setupresulttree() self.searchbox.setFocus() self.searchbox.grabKeyboard() + # !SECTION - # region ------------------------------------------------------------- Initial Setup + # ---------------------------------------------------------- Functions + # SECTION Functions -------------------------------------------------- + # ----------------------------------- count_chars Setup + # NOTE count_chars ------------------------------------ + def count_chars(self, txt): + result = 0 + for char in txt: + result += 1 # same as result = result + 1 + return result + # --------------------------------------- Initial Setup + # NOTE Initial Setup ---------------------------------- def initialsetup(self): self.handler = datahandler.DataHandler(self.isdebug) currentidx = hou.hotkeys.changeIndex() @@ -386,7 +515,7 @@ def initialsetup(self): if len(chindex) == 0: chindex = int(currentidx) self.handler.updatechangeindex(chindex, True) - self.handler.updatedataasync() + self.handler.updatedataasync(self.isdebug) hou.ui.setStatusMessage( "Searcher database created", severity=hou.severityType.Message @@ -395,13 +524,15 @@ def initialsetup(self): chindex = int(chindex[0][0]) if int(currentidx) != chindex: - self.handler.updatedataasync() + self.handler.updatedataasync(self.isdebug) 
self.handler.updatechangeindex(int(currentidx)) tmpkey = self.handler.getdefaulthotkey() self.tmpkey = tmpkey[0][0] return self.handler, self.tmpkey + # ------------------------------------------- Panel/Node + # NOTE Panel/Node ------------------------------------- def getnode(self): nodeSelect = hou.selectedNodes() for node in nodeSelect: @@ -416,27 +547,93 @@ def getpane(self): ("No context options to display" + str(e)), severity=hou.severityType.Message ) + # !SECTION - # endregion - # region ------------------------------------------------------------- Callbacks + # ---------------------------------------------------------- Callbacks + # SECTION Callbacks -------------------------------------------------- + # ------------------------------------- searchfilter_cb + # NOTE searchfilter_cb -------------------------------- def searchfilter_cb(self): self.openmenu() + # ------------------------------------------ setctxicon + # NOTE setctxicon ------------------------------------- + def setctxicon(self): + if self.showctx: + self.contexttoggle.setIcon(util.COLLAPSE_ICON) + else: + self.contexttoggle.setIcon(util.EXPAND_ICON) + self.searchresultstree.setColumnHidden(3, self.showctx) + + # ------------------------------------------ showctx_cb + # NOTE showctx_cb ------------------------------------- + def showctx_cb(self, pressed): + self.showctx = True if pressed else False + self.settingdata[util.SETTINGS_KEYS[7]] = self.showctx + searcher_data.savesettings(self.settingdata) + self.setctxicon() + + # ---------------------------------------- pinwindow_cb + # NOTE pinwindow_cb ----------------------------------- def pinwindow_cb(self): self.windowispin = not self.windowispin self.settingdata[util.SETTINGS_KEYS[5]] = self.windowispin searcher_data.savesettings(self.settingdata) self.setpinicon() + # ------------------------------------------ setpinicon + # NOTE setpinicon ------------------------------------- def setpinicon(self): if self.windowispin: 
self.pinwindow.setIcon(util.PIN_IN_ICON) else: self.pinwindow.setIcon(util.PIN_OUT_ICON) - def opensettings_cb(self): - self.open_settings() + # ------------------------------------- opensettings_cb + # NOTE opensettings_cb -------------------------------- + def opensettings_cb(self, doopen): + self.ui.isopened = self.ui.isVisible() + + if self.animatedsettings: + self.open_settings(doopen) + elif self.ui.isopened: + self.open_settings(False) + else: + self.open_settings(True) + + # --------------------------------------- open_settings + # NOTE open_settings ---------------------------------- + def open_settings(self, doopen): + if doopen: + pos = self.opensettingstool.mapToGlobal( + QtCore.QPoint(-self.ui.width() + 31, 26)) + self.ui.setGeometry( + pos.x(), + pos.y(), + self.ui.width(), + self.ui.height() + ) + self.ui.show() + self.ui.activateWindow() + self.ui.setFocus() + if self.animatedsettings: + self.anim.start_animation(True) + else: + if self.animatedsettings: + _ = self.anim.start_animation(False) + else: + self.ui.isopened = True + self.ui.close() + + def anim_complete(self): + if self.ui.isopened: + self.ui.close() + self.ui.isopened = False + self.opensettingstool.setChecked(False) + + # ------------------------------------- globalkeysearch + # NOTE globalkeysearch -------------------------------- def globalkeysearch(self): self.ctxsearch = True ctx = [] @@ -445,6 +642,8 @@ def globalkeysearch(self): self.searchtablepopulate(results) self.ctxsearch = False + # ----------------------------------------- ctxsearcher + # NOTE ctxsearcher ------------------------------------ def ctxsearcher(self, ctx=None): results = None ctxresult = [] @@ -481,21 +680,36 @@ def ctxsearcher(self, ctx=None): self.searchresultstree.topLevelItem(0).child(0) ) + # --------------------------------------- textchange_cb + # NOTE textchange_cb ---------------------------------- def textchange_cb(self, text): + self.starttime = ptime.time() if len(text) > 0: 
self.infolbl.setText(self.searchresultstree.toolTip()) if text in util.CTXSHOTCUTS: self.ctxsearcher(text) elif len(text) > 1 and text not in util.CTXSHOTCUTS: self.searching = True - txt = self.handler.searchtext(text) + allowed = re.compile(r'[^a-zA-Z ]+') + text = re.sub(allowed, '', text) + str = text.split() + searchstring = ['%s*' % (x,) for x in str] + txt, timer = self.handler.searchtext( + ' '.join(searchstring), + self.isdebug, + self.settingdata[util.SETTINGS_KEYS[9]] + ) + self.hotkeystime = timer self.searchtablepopulate(txt) else: self.searching = False + self.treetotal_lbl.setText("") self.searchresultstree.clear() self.infolbl.setText( "Begin typing to search or click magnifying glass icon to display options") - + + # -------------------------------------- searchclick_cb + # NOTE searchclick_cb --------------------------------- def searchclick_cb(self, item, column): hk = item.text(2) self.tmpsymbol = item.text(3) @@ -512,9 +726,12 @@ def searchclick_cb(self, item, column): self.processkey(hk) self.tmpsymbol = None return - # endregion + # !SECTION - # region ------------------------------------------------------------- Searchbar Menu + # ------------------------------------------------------------- Search + # SECTION Search ----------------------------------------------------- + # -------------------------------------------- openmenu + # NOTE openmenu --------------------------------------- def openmenu(self): self.menuopened = True self.searchmenu = QtWidgets.QMenu() @@ -565,9 +782,13 @@ def getContext(self, ctx): # endregion - # region ------------------------------------------------------------- Search Functionality + # --------------------------------- searchtablepopulate + # NOTE searchtablepopulate ---------------------------- def searchtablepopulate(self, data): if len(data) > 0: + goalnum = 15 + self.treecatnum = 0 + self.treeitemsnum = 0 self.searchresultstree.clear() hotkeys = [] context_list = [] @@ -580,8 +801,10 @@ def 
searchtablepopulate(self, data): else: context_list.append(data[i][4]) - result = self.handler.gethcontextod(context_list) + result, hctimer = self.handler.gethcontextod(context_list) + self.hcontexttime = hctimer + treebuildtimer = ptime.time() for hc in range(len(result)): hcontext_tli[result[hc][2]] = (QtWidgets.QTreeWidgetItem( self.searchresultstree, [ @@ -590,12 +813,13 @@ def searchtablepopulate(self, data): ] )) self.searchresultstree.expandItem(hcontext_tli[result[hc][2]]) + self.treecatnum += 1 base_keys = hcontext_tli.keys() for i in range(len(data)): for j in range(len(base_keys)): if base_keys[j] in data[i][4]: - if self.isdebug: + if self.isdebug and self.isdebug.level in {"ALL"}: hotkeys.append(QtWidgets.QTreeWidgetItem( hcontext_tli[base_keys[j]], [ data[i][0], @@ -605,6 +829,7 @@ def searchtablepopulate(self, data): data[i][4] ] )) + self.treeitemsnum += 1 else: hotkeys.append(QtWidgets.QTreeWidgetItem( hcontext_tli[base_keys[j]], [ @@ -614,10 +839,36 @@ def searchtablepopulate(self, data): data[i][3] ] )) - # endregion + self.treeitemsnum += 1 + treebuildtimerend = ptime.time() + treebuildtotal = ((treebuildtimerend - treebuildtimer) * 1000.0) + + # Display the number of added results by iteration + catval = ("%d : Contexts | " % self.treecatnum) + itmval = ("%d : Results " % self.treeitemsnum) + catval = catval.rjust(goalnum - self.count_chars(str(self.treecatnum)), " ") + itmval = itmval.rjust((goalnum + 2) - self.count_chars(str(self.treeitemsnum)), " ") + self.treetotal_lbl.setText((catval + itmval)) + + # Performance monitors to check how long different aspects take to run ---------- + self.endtime = ptime.time() + totaltime = ((self.endtime - self.starttime) * 1000.0) + if self.isdebug and self.isdebug.level in {"TIMER", "ALL"}: + if hou.isUIAvailable(): + hou.ui.setStatusMessage( + (('Context Search %0.4f ms | ' % self.hcontexttime) + + ('Hotkey Search %0.4f ms | ' % self.hotkeystime) + + ('Tree build %0.4f ms | ' % treebuildtotal) + + 
('Total : %0.4f ms' % (totaltime))) , severity=hou.severityType.Message) + else: + print('Search took %0.4f ms' % self.hotkeystime) - # region ------------------------------------------------------------- Hotkey Processing + # !SECTION + # -------------------------------------------------- Hotkey Processing + # SECTION Hotkey Processing ------------------------------------------ + # -------------------------------------- processkey + # NOTE processkey --------------------------------- def processkey(self, key, tmphk=False): hk = key if tmphk: @@ -647,8 +898,7 @@ def processkey(self, key, tmphk=False): hou.ui.mainQtWindow().setFocus() try: - hd.executeDeferred(self.app.sendEvent, - hou.ui.mainQtWindow(), keypress) + hd.executeDeferred(self.app.sendEvent, hou.ui.mainQtWindow(), keypress) self.close() except(AttributeError, TypeError) as e: @@ -657,6 +907,8 @@ def processkey(self, key, tmphk=False): severity=hou.severityType.Warning ) + # ---------------------------------- setKeysChanged + # NOTE setKeysChanged ----------------------------- def setKeysChanged(self, changed): if self.keys_changed and not changed: if not hou.hotkeys.saveOverrides(): @@ -665,6 +917,8 @@ def setKeysChanged(self, changed): self.chindex = hou.hotkeys.changeIndex() self.handler.updatechangeindex(self.chindex) + # -------------------------------- createtemphotkey + # NOTE createtemphotkey --------------------------- def createtemphotkey(self, symbol): hou.hotkeys._createBackupTables() result = hou.hotkeys.addAssignment(symbol, self.tmpkey) @@ -672,15 +926,20 @@ def createtemphotkey(self, symbol): self.setKeysChanged(False) return result + # -------------------------------- removetemphotkey + # NOTE removetemphotkey --------------------------- def removetemphotkey(self, symbol, tmpkey): hou.hotkeys._restoreBackupTables() hou.hotkeys.revertToDefaults(symbol, True) self.keys_changed = True self.setKeysChanged(False) - # endregion - # region 
------------------------------------------------------------- Events - # SECTION Events + # !SECTION + + # --------------------------------------------------------- Animations + # SECTION Animations ------------------------------------------------- + # ----------------------------------------- fade_in + # NOTE fade_in ------------------------------------ def fade_in(self, target, duration): self.effect = QtWidgets.QGraphicsOpacityEffect() self.tar = target @@ -691,6 +950,8 @@ def fade_in(self, target, duration): self.an.setEndValue(1) self.an.start() + # ---------------------------------------- fade_out + # NOTE fade_out ----------------------------------- def fade_out(self, target, duration): self.effect = QtWidgets.QGraphicsOpacityEffect() self.tar = target @@ -700,37 +961,38 @@ def fade_out(self, target, duration): self.an.setStartValue(1) self.an.setEndValue(0) self.an.start() + # !SECTION - def info_fade(self, fadein): - self.animation = QtCore.QPropertyAnimation( - self.infolbl_font, b'opacity') - self.animation.setDuration(200) - if fadein: - self.animation.setStartValue(0.0) - self.animation.setEndValue(1.0) - if not fadein: - self.animation.setStartValue(1.0) - self.animation.setEndValue(0.0) - self.animation.setEasingCurve(QtGui.QEasingCurve.OutQuad) - self.animation.start() - - def checktooltip(self, obj): - if obj == self.searchresultstree: - if self.searching: - self.infolbl.setText(obj.toolTip()) - else: + # ------------------------------------------------------------- Events + # SECTION Events ----------------------------------------------------- + def checktooltip(self, obj, hasleft=False): + if hasleft: + # self.fade_out(self.infolbl, 200) + if self.searching and self.infolbl.text() != self.searchresultstree.toolTip(): + self.infolbl.setText(self.searchresultstree.toolTip()) + self.fade_in(self.infolbl, 200) + elif not self.searching and self.infolbl.text() != self.searchbox.toolTip(): self.infolbl.setText(self.searchbox.toolTip()) + 
self.fade_in(self.infolbl, 200) else: - self.fade_in(self.infolbl, 200) - self.infolbl.setText(obj.toolTip()) + if obj == self.searchresultstree or obj == self.searchbox: + if self.searching and self.infolbl.text() != self.searchresultstree.toolTip(): + self.infolbl.setText(self.searchresultstree.toolTip()) + self.fade_in(self.infolbl, 200) + elif not self.searching and self.infolbl.text() != self.searchbox.toolTip(): + self.infolbl.setText(self.searchbox.toolTip()) + self.fade_in(self.infolbl, 200) + elif self.infolbl.text() != obj.toolTip(): + self.infolbl.setText(obj.toolTip()) + self.fade_in(self.infolbl, 200) def eventFilter(self, obj, event): - # ---------------------------------------------------- Mouse + # ------------------------------------------- Mouse + # NOTE Mouse -------------------------------------- if event.type() == QtCore.QEvent.Enter: self.checktooltip(obj) if event.type() == QtCore.QEvent.Leave: - self.fade_out(self.infolbl, 200) - self.infolbl.setText("") + self.checktooltip(obj, True) if event.type() == QtCore.QEvent.ToolTip: return True @@ -744,7 +1006,10 @@ def eventFilter(self, obj, event): if event.type() == QtCore.QEvent.MouseMove: if obj == self: delta = event.globalPos() - self.previous_pos - self.move(self.x() + delta.x(), self.y()+delta.y()) + self.move(self.x() + delta.x(), self.y() + delta.y()) + if self.ui.isVisible(): + self.ui.move(self.ui.x() + delta.x(), + self.ui.y() + delta.y()) self.previous_pos = event.globalPos() self._drag_active = True else: @@ -754,7 +1019,8 @@ def eventFilter(self, obj, event): if self._drag_active: self._drag_active = False - # ------------------------------------------------- Keypress + # ---------------------------------------- Keypress + # NOTE Keypress ----------------------------------- if event.type() == QtCore.QEvent.KeyPress: if event.key() == QtCore.Qt.Key_Tab: if self.searching: @@ -795,7 +1061,8 @@ def eventFilter(self, obj, event): self.openmenu() return True - # 
------------------------------------------------- Window + # ------------------------------------------ Window + # NOTE Window ------------------------------------- if event.type() == QtCore.QEvent.WindowActivate: self.searchbox.grabKeyboard() elif event.type() == QtCore.QEvent.WindowDeactivate: @@ -812,20 +1079,16 @@ def eventFilter(self, obj, event): elif event.type() == QtCore.QEvent.FocusOut: pass - # ------------------------------------------------- Close + # ------------------------------------------- Close + # NOTE Close -------------------------------------- if event.type() == QtCore.QEvent.Close: try: if util.bc(self.settingdata[util.SETTINGS_KEYS[2]]): - self.windowsettings.setValue( - "geometry", - self.saveGeometry() - ) + self.windowsettings.setValue("geometry", self.saveGeometry()) except (AttributeError, TypeError) as e: if hou.isUIAvailable(): hou.ui.setStatusMessage( - ("Could not save window dimensions: " + str(e)), - severity=hou.severityType.Warning - ) + ("Could not save window dimensions: " + str(e)), severity=hou.severityType.Warning) else: print("Could not save window dimensions: " + str(e)) @@ -857,7 +1120,7 @@ def __init__(self, parent=None): super(overlayLabel, self).__init__(parent) self.setAlignment(QtCore.Qt.AlignHCenter | QtCore.Qt.AlignVCenter) -# region ----------------------------------------------------------------- Setup Functions +# ------------------------------------------------------------------------ Setup Functions def center(): @@ -882,12 +1145,8 @@ def CreateSearcherPanel(kwargs, searcher_window=None): searcher_window = Searcher(kwargs, settings, windowsettings) searcher_window.setWindowFlags( - # searcher_window.windowFlags() | QtCore.Qt.Tool | - # QtCore.Qt.WindowSystemMenuHint | - # QtCore.Qt.WindowTitleHint | QtCore.Qt.CustomizeWindowHint | - # QtCore.Qt.FramelessWindowHint QtCore.Qt.WindowStaysOnTopHint ) diff --git a/scripts/python/searcher/searcher_data.py b/scripts/python/searcher/searcher_data.py index 
750b817..80062c7 100644 --- a/scripts/python/searcher/searcher_data.py +++ b/scripts/python/searcher/searcher_data.py @@ -24,21 +24,14 @@ settingsdata = QtCore.QSettings(searcher_settings, QtCore.QSettings.IniFormat) -DEFAULT_SETTINGS = { - util.SETTINGS_KEYS[0]: "False", # in_memory_db - util.SETTINGS_KEYS[1]: defaultdbpath, # database_path - util.SETTINGS_KEYS[2]: "False", # savewindowsize - util.SETTINGS_KEYS[3]: [1000, 600], # windowsize - util.SETTINGS_KEYS[4]: "False", # debugflag - util.SETTINGS_KEYS[5]: "False", # pinwindow -} - def createdefaults(): + def_set = util.DEFAULT_SETTINGS + def_set[util.SETTINGS_KEYS[1]] = str(defaultdbpath) settingsdata.beginGroup('Searcher') for i in range(len(util.SETTINGS_KEYS)): settingsdata.setValue( - util.SETTINGS_KEYS[i], DEFAULT_SETTINGS[util.SETTINGS_KEYS[i]]) + util.SETTINGS_KEYS[i], def_set[util.SETTINGS_KEYS[i]]) settingsdata.endGroup() diff --git a/scripts/python/searcher/searcher_settings.py b/scripts/python/searcher/searcher_settings.py index 93fd8a0..597feab 100644 --- a/scripts/python/searcher/searcher_settings.py +++ b/scripts/python/searcher/searcher_settings.py @@ -1,11 +1,15 @@ from __future__ import division from __future__ import print_function from __future__ import absolute_import -import weakref +from searcher import about +from searcher import bugreport +from searcher import bugreport_ui +from searcher import about_ui from searcher import searcher_data from searcher import util -from inspect import currentframe +from searcher import language_en as la +from searcher import searchersettings_ui from builtins import range from past.utils import old_div @@ -29,40 +33,41 @@ from hutil.Qt import _QtUiTools elif hver < 391 and hver >= 348: from hutil.Qt import QtUiTools -# else: -# os.environ['QT_API'] = 'pyside2' -# from PySide import QtUiTools -# from qtpy import QtGui -# from qtpy import QtCore -# from qtpy import QtWidgets +reload(about) +reload(about_ui) +reload(bugreport) +reload(bugreport_ui) 
+reload(searchersettings_ui) -the_scaled_icon_size = hou.ui.scaledSize(16) -the_icon_size = 16 - -num = 0 -# info +# -------------------------------------------------------------------- App Info +__package__ = "Searcher" +__version__ = "0.1b" __author__ = "instance.id" -__copyright__ = "2020 All rights reserved." +__copyright__ = "2020 All rights reserved. See LICENSE for more details." __status__ = "Prototype" -scriptpath = os.path.dirname(os.path.realpath(__file__)) +the_scaled_icon_size = hou.ui.scaledSize(16) +the_icon_size = 16 +scriptpath = os.path.dirname(os.path.realpath(__file__)) def bc(v): return str(v).lower() in ("yes", "true", "t", "1") - class SearcherSettings(QtWidgets.QWidget): """ Searcher Settings and Debug Menu""" def __init__(self, handler, tmphotkey, parent=None): super(SearcherSettings, self).__init__(parent=parent) - - # ------------------------------------------------- Component variables + # -------------------------------------------- settings + # NOTE settings --------------------------------------- + self.parentwindow = parent self.settings = {} self.context_dict = {} self.command_dict = {} + self.currentsettings = {} + self.performcheck = True self.contexts = None self.commands = None self.addKeyWidget = None @@ -74,79 +79,104 @@ def __init__(self, handler, tmphotkey, parent=None): self.canedit = False self.KeySequence = None self.hkholder = "" - self.hkinput = tmphotkey + self.defaulthotkey = tmphotkey self.datahandler = handler self.tmphotkey = tmphotkey + self.isopened = False self.setObjectName('searcher-settings') - # ------------------------------------------------- Build UI + # --------------------------------------------- beginui + # NOTE beginui ---------------------------------------- self.setAutoFillBackground(True) self.setBackgroundRole(QtGui.QPalette.Window) - self.setWindowFlags(QtCore.Qt.WindowStaysOnTopHint) self.settings = searcher_data.loadsettings() - self.isdebug = bc(self.settings[util.SETTINGS_KEYS[4]]) + 
self.isdebug = util.Dbug(util.bc(self.settings[util.SETTINGS_KEYS[4]]), str(self.settings[util.SETTINGS_KEYS[10]])) + self.la = la.TT_SETTINGS # Load UI File - loader = None - if int(hver) >= 391 and int(hver) <= 394: - loader = _QtUiTools.QUiLoader() - else: - loader = QtUiTools.QUiLoader() - self.ui = loader.load(scriptpath + '/searchersettings.ui') + self.ui = searchersettings_ui.Ui_SearcherSettings() + self.ui.setupUi(self, self.width, self.height, bc(self.settings[util.SETTINGS_KEYS[8]])) + self.ui.retranslateUi(self) + + self.bugreport = bugreport.BugReport(self.parentwindow) + self.bugreport.setAttribute(QtCore.Qt.WA_StyledBackground, True) + self.bugreport.setWindowFlags( + QtCore.Qt.Tool | + QtCore.Qt.WindowStaysOnTopHint | + QtCore.Qt.FramelessWindowHint | + QtCore.Qt.NoDropShadowWindowHint + ) + self.bugreport.setParent(self.parentwindow) + self.bugreport.resize(520, 250) + + self.settingslayout = QtWidgets.QVBoxLayout() # Get UI Elements - self.hotkey_icon = self.ui.findChild( - QtWidgets.QToolButton, - "hotkey_icon" - ) - self.debugflag = self.ui.findChild( - QtWidgets.QCheckBox, - "debugflag_chk" - ) - self.in_memory_db = self.ui.findChild( - QtWidgets.QCheckBox, - "inmemory_chk" - ) - self.savewindowsize = self.ui.findChild( - QtWidgets.QCheckBox, - "windowsize_chk" - ) - self.hkinput = self.ui.findChild( - QtWidgets.QLineEdit, - "hkinput_txt" - ) - self.database_path = self.ui.findChild( - QtWidgets.QLineEdit, - "databasepath_txt" - ) - self.test1 = self.ui.findChild( - QtWidgets.QPushButton, - "test1_btn" - ) - self.testcontext = self.ui.findChild( - QtWidgets.QPushButton, - "test_context_btn" - ) - self.cleardata = self.ui.findChild( - QtWidgets.QPushButton, - "cleardata_btn" - ) - self.savedata = self.ui.findChild( - QtWidgets.QPushButton, - "save_btn" - ) - self.discarddata = self.ui.findChild( - QtWidgets.QPushButton, - "discard_btn" - ) + self.hotkey_icon = self.ui.hotkey_icon + + # headerrow + self.in_memory_db = self.ui.inmemory_chk + 
self.in_memory_db.setToolTip(la.TT_SETTINGS[self.in_memory_db.objectName()]) + self.savewindowsize = self.ui.windowsize_chk + self.savewindowsize.setToolTip(la.TT_SETTINGS[self.savewindowsize.objectName()]) + + # secondrow + self.maxresults = self.ui.maxresults_txt + self.maxresults.setToolTip(la.TT_SETTINGS[self.maxresults.objectName()]) + self.animatedsettings = self.ui.animatedsettings_chk + self.animatedsettings.setToolTip(la.TT_SETTINGS[self.animatedsettings.objectName()]) + # thirdrow + self.defaulthotkey = self.ui.defaulthotkey_txt + self.defaulthotkey.setToolTip(la.TT_SETTINGS[self.defaulthotkey.objectName()]) + self.database_path = self.ui.databasepath_txt + self.database_path.setToolTip(la.TT_SETTINGS[self.database_path.objectName()]) + + # fourthrow + self.test1 = self.ui.test1_btn + self.cleardata = self.ui.cleardata_btn + self.cleardata.setToolTip(la.TT_SETTINGS[self.cleardata.objectName()]) + + # fifthrow + self.about = self.ui.about_btn + self.about.setToolTip(la.TT_SETTINGS[self.about.objectName()]) + about_button_size = hou.ui.scaledSize(32) + self.about.setProperty("flat", True) + self.about.setIcon(util.ABOUT_ICON1) + self.about.setIconSize(QtCore.QSize( + about_button_size, + about_button_size + )) - mainlayout = QtWidgets.QVBoxLayout() - mainlayout.addWidget(self.ui) + self.bugreportbtn = self.ui.bug_btn + self.bugreportbtn.setCheckable(True) + self.bugreportbtn.setChecked(False) + bugreport_button_size = hou.ui.scaledSize(21) + self.bugreportbtn.setProperty("flat", True) + self.bugreportbtn.setIcon(util.BUG_ICON) + self.bugreportbtn.setIconSize(QtCore.QSize( + bugreport_button_size, + bugreport_button_size + )) + + self.debuglevel = self.ui.debuglevel_cbx + for lvl in util.DEBUG_LEVEL: + self.debuglevel.addItem(str(lvl)) + self.debuglevel.setToolTip(la.TT_SETTINGS[self.debuglevel.objectName()]) + self.debugflag = self.ui.debugflag_chk + self.debugflag.setToolTip(la.TT_SETTINGS[self.debugflag.objectName()]) + 
self.debuglevel.setVisible(bc(self.settings[util.SETTINGS_KEYS[4]])) + self.debugflag.setVisible(bc(self.settings[util.SETTINGS_KEYS[4]])) - # ------------------------------------------------- Create Connections + self.savedata = self.ui.save_btn + self.savedata.setToolTip(la.TT_SETTINGS[self.savedata.objectName()]) + + self.discarddata = self.ui.discard_btn + self.discarddata.setToolTip(la.TT_SETTINGS[self.discarddata.objectName()]) + + # -------------------------------------------- sixthrow + # NOTE sixthrow --------------------------------------- # self.in_memory_db.stateChanged.connect(self.toggledebug) self.hotkey_icon.clicked.connect(self.hotkeyicon_cb) - self.hotkey_icon.setIcon(util.INFO_ICON) info_button_size = hou.ui.scaledSize(16) self.hotkey_icon.setProperty("flat", True) self.hotkey_icon.setIcon(util.INFO_ICON) @@ -155,45 +185,85 @@ def __init__(self, handler, tmphotkey, parent=None): info_button_size )) - self.hkinput.setText(self.tmphotkey) - self.hkinput.setStatusTip("Status Tip?") - self.hkinput.setWhatsThis("Whats this?") - # self.hkinput.setToolTip( - # "If left to the default value of (Ctrl+Alt+Shift+F7), in the event that Searcher detects a conflict it will automatically attempt to try different key combinations.") - self.hkinput.setStyleSheet(util.TOOLTIP) - self.database_path.setText(str(self.settings['database_path'])) + self.defaulthotkey.setToolTip(la.TT_SETTINGS[self.discarddata.objectName()]) + self.defaulthotkey.setStyleSheet(util.TOOLTIP) + + # --------------------------------------------- connect + # NOTE connect ---------------------------------------- self.test1.clicked.connect(self.test1_cb) - self.testcontext.clicked.connect(self.testcontext_cb) self.cleardata.clicked.connect(self.cleardata_cb) + self.about.clicked.connect(self.about_cb) + self.bugreportbtn.clicked.connect(self.bug_cb) self.savedata.clicked.connect(self.save_cb) self.discarddata.clicked.connect(self.discard_cb) - # ------------------------------------------------- 
Apply Layout - self.setLayout(mainlayout) + # -------------------------------------------- about_cb + # NOTE about_cb --------------------------------------- + self.settingslayout = self.ui.verticallayout + self.setLayout(self.ui.gridLayout) + + # ---------------------------------------- eventfilters + # NOTE eventfilters ----------------------------------- self.installEventFilter(self) + self.about.installEventFilter(self) + self.cleardata.installEventFilter(self) + self.savedata.installEventFilter(self) + self.discarddata.installEventFilter(self) + self.updatecurrentvalues() + self.fieldsetup() + + # --------------------------------------------------------------- Callbacks + # SECTION Callbacks ------------------------------------------------------- + + def bug_cb(self, toggled): + pos = self.bugreportbtn.mapToGlobal( + QtCore.QPoint( -43, 35)) + self.bugreport.setGeometry( + pos.x(), + pos.y(), + self.bugreport.width(), + self.bugreport.height() + ) + + if toggled == True: + self.bugreport.show() + else: + self.bugreport.close() + self.bugreport.setParent(None) + # -------------------------------------------- about_cb + # NOTE about_cb --------------------------------------- + def about_cb(self): + self.about = about.About(self.parentwindow) + self.about.setAttribute(QtCore.Qt.WA_StyledBackground, True) + self.about.setWindowFlags( + QtCore.Qt.Popup | + QtCore.Qt.WindowStaysOnTopHint | + QtCore.Qt.NoDropShadowWindowHint | + QtCore.Qt.WindowStaysOnTopHint - self.debugflag.setChecked(bc(self.settings[util.SETTINGS_KEYS[4]])) - self.debugflag.setVisible(bc(self.settings[util.SETTINGS_KEYS[4]])) - self.in_memory_db.setChecked(bc(self.settings[util.SETTINGS_KEYS[0]])) - self.savewindowsize.setChecked( - bc(self.settings[util.SETTINGS_KEYS[2]])) - - # ------------------------------------------------- Add EventFilters - self.hkinput.installEventFilter(self) - self.debugflag.installEventFilter(self) - # 
----------------------------------------------------------------------------------- Callbacks - + ) + self.about.setParent(self.parentwindow) + self.about.move(self.pos().x() - 175, self.pos().y()) + self.about.show() + + # NOTE hotkeyicon_cb ---------------------------------- def hotkeyicon_cb(self): self.settings['in_memory_db'] = self.in_memory_db.isChecked() print(self.settings['in_memory_db']) + # ----------------------------------------- toggledebug + # NOTE toggledebug ------------------------------------ def toggledebug(self): self.settings['in_memory_db'] = self.in_memory_db.isChecked() print(self.settings['in_memory_db']) + # ---------------------------------------- defaulthk_cb + # NOTE defaulthk_cb ----------------------------------- def defaulthk_cb(self): return + # -------------------------------------------- test1_cb + # NOTE test1_cb --------------------------------------- def test1_cb(self): hkeys = [] for i in range(len(util.HOTKEYLIST)): @@ -205,96 +275,213 @@ def test1_cb(self): hkeys.append(result) print (hkeys) + # --------------------------------------- cleardata_cb + # NOTE cleardata_cb ---------------------------------- def cleardata_cb(self): self.datahandler.cleardb() + # --------------------------------------------- save_cb + # NOTE save_cb ---------------------------------------- def save_cb(self): - if self.hkinput.text() == "": - buttonindex = hou.ui.displayMessage("Please enter a hotkey") + if self.defaulthotkey.text() == "": + _ = hou.ui.displayMessage("Please enter a hotkey") self.activateWindow() - self.hkinput.setFocus() + self.defaulthotkey.setFocus() self.canedit = True else: - if self.hkinput.text() != self.tmphotkey: - self.tmphotkey = self.hkinput.text() + if self.defaulthotkey.text() != self.tmphotkey: + self.tmphotkey = self.defaulthotkey.text() self.datahandler.updatetmphotkey(self.tmphotkey) for i in range(len(util.SETTINGS_KEYS)): if util.SETTINGS_TYPES[util.SETTINGS_KEYS[i]] == "bool": - 
self.settings[util.SETTINGS_KEYS[i]] = getattr( - self, util.SETTINGS_KEYS[i]).isChecked() + self.settings[util.SETTINGS_KEYS[i]] = getattr(self, util.SETTINGS_KEYS[i]).isChecked() elif util.SETTINGS_TYPES[util.SETTINGS_KEYS[i]] == "text": - self.settings[util.SETTINGS_KEYS[i]] = getattr( - self, util.SETTINGS_KEYS[i]).text() + self.settings[util.SETTINGS_KEYS[i]] = getattr(self, util.SETTINGS_KEYS[i]).text() + elif util.SETTINGS_TYPES[util.SETTINGS_KEYS[i]] == "intval": + self.settings[util.SETTINGS_KEYS[i]] = getattr(self, util.SETTINGS_KEYS[i]).value() + elif util.SETTINGS_TYPES[util.SETTINGS_KEYS[i]] == "cbx": + self.settings[util.SETTINGS_KEYS[i]] = getattr(self, util.SETTINGS_KEYS[i]).currentText() - if self.isdebug: + if self.isdebug and self.isdebug.level in {"ALL"}: print(self.settings) searcher_data.savesettings(self.settings) - self.close() - + self.performcheck = False + if self.animatedsettings: + self.parentwindow.anim.start_animation(False) + self.isopened = True + else: + self.close() + # ------------------------------------------ discard_cb + # NOTE discard_cb ------------------------------------- def discard_cb(self): - self.hkinput.setText(self.tmphotkey) - self.hkholder = "" - self.close() - - def testcontext_cb(self): - return + if self.animatedsettings: + self.parentwindow.anim.start_animation(False) + self.isopened = True + self.performcheck=True + else: + self.close() - # ----------------------------------------------------------------------------------- Actions + # !SECTION + + # ----------------------------------------------------------------- Actions + # SECTION Actions --------------------------------------------------------- + # --------------------------------- updatecurrentvalues + # NOTE updatecurrentvalues ---------------------------- + def updatecurrentvalues(self): + for i in range(len(util.SETTINGS_KEYS)): + self.currentsettings[util.SETTINGS_KEYS[i]] = self.settings[util.SETTINGS_KEYS[i]] + if 
util.SETTINGS_TYPES[util.SETTINGS_KEYS[i]] == "bool": + getattr(self, util.SETTINGS_KEYS[i]).setChecked(bc(self.currentsettings[util.SETTINGS_KEYS[i]])) + elif util.SETTINGS_TYPES[util.SETTINGS_KEYS[i]] == "text": + getattr(self, util.SETTINGS_KEYS[i]).setText(self.currentsettings[util.SETTINGS_KEYS[i]]) + elif util.SETTINGS_TYPES[util.SETTINGS_KEYS[i]] == "intval": + getattr(self, util.SETTINGS_KEYS[i]).setValue(int(self.currentsettings[util.SETTINGS_KEYS[i]])) + elif util.SETTINGS_TYPES[util.SETTINGS_KEYS[i]] == "cbx": + getattr(self, util.SETTINGS_KEYS[i]).setCurrentText(str(self.currentsettings[util.SETTINGS_KEYS[i]])) + + # ------------------------------------------ fieldsetup + # NOTE fieldsetup ------------------------------------- + def fieldsetup(self): + for i in range(len(util.SETTINGS_KEYS)): + if util.SETTINGS_TYPES[util.SETTINGS_KEYS[i]] == "bool": + getattr(self, util.SETTINGS_KEYS[i]).setChecked(bc(self.currentsettings[util.SETTINGS_KEYS[i]])) + elif util.SETTINGS_TYPES[util.SETTINGS_KEYS[i]] == "text": + getattr(self, util.SETTINGS_KEYS[i]).setText(self.currentsettings[util.SETTINGS_KEYS[i]]) + elif util.SETTINGS_TYPES[util.SETTINGS_KEYS[i]] == "intval": + getattr(self, util.SETTINGS_KEYS[i]).setValue(int(self.currentsettings[util.SETTINGS_KEYS[i]])) + elif util.SETTINGS_TYPES[util.SETTINGS_KEYS[i]] == "cbx": + getattr(self, util.SETTINGS_KEYS[i]).setCurrentText(str(self.currentsettings[util.SETTINGS_KEYS[i]])) + try: + getattr(self, util.SETTINGS_KEYS[i]).installEventFilter(self) + except (AttributeError, TypeError): + pass + + if self.isdebug and self.isdebug.level in {"ALL"}: + print(self.currentsettings) + + # ------------------------------------- checkforchanges + # NOTE checkforchanges -------------------------------- + def checkforchanges(self): + for i in range(len(util.SETTINGS_KEYS)): + if util.SETTINGS_TYPES[util.SETTINGS_KEYS[i]] == "bool": + if getattr(self, util.SETTINGS_KEYS[i]).isChecked() != 
bc(self.currentsettings[util.SETTINGS_KEYS[i]]): + return True + elif util.SETTINGS_TYPES[util.SETTINGS_KEYS[i]] == "text": + if getattr(self, util.SETTINGS_KEYS[i]).text() != self.currentsettings[util.SETTINGS_KEYS[i]]: + return True + elif util.SETTINGS_TYPES[util.SETTINGS_KEYS[i]] == "intval": + if getattr(self, util.SETTINGS_KEYS[i]).value() != self.currentsettings[util.SETTINGS_KEYS[i]]: + return True + elif util.SETTINGS_TYPES[util.SETTINGS_KEYS[i]] == "cbx": + if getattr(self, util.SETTINGS_KEYS[i]).currentText() != self.currentsettings[util.SETTINGS_KEYS[i]]: + return True + return False + # ------------------------------------------- savecheck + # NOTE savecheck -------------------------------------- def savecheck(self): buttonindex = hou.ui.displayMessage( - "Save changes?", buttons=('Save', 'Discard'), default_choice=0, title="Unsaved Changes:",) + "Save changes?", + buttons=('Save', 'Discard'), + default_choice=0, + title="Unsaved Changes:" + ) if buttonindex == 0: - self.tmphotkey = self.hkinput.text() - self.datahandler.updatetmphotkey(self.tmphotkey) + self.save_cb() self.hkholder = "" elif buttonindex == 1: - self.hkinput.setText(self.hkholder) + self.defaulthotkey.setText(self.hkholder) self.hkholder = "" + # !SECTION - # ----------------------------------------------------------------------------------- Events + # ------------------------------------------------------------------ Events + # SECTION Events ---------------------------------------------------------- def eventFilter(self, obj, event): + # ------------------------------------------ Window + # NOTE Window ------------------------------------- + if event.type() == QtCore.QEvent.WindowActivate: + self.ui.isopened = True + self.performcheck = True + self.updatecurrentvalues() + return True + + # ------------------------------------------- Mouse + # NOTE Mouse -------------------------------------- if event.type() == QtCore.QEvent.MouseButtonDblClick: - self.hkholder = self.hkinput.text() - 
self.hkinput.setText("") - self.hkinput.setPlaceholderText("Input key sequence") - self.canedit = True - if event.type() == QtCore.QEvent.Close: - if self.canedit is False and self.hkholder is not "": - self.savecheck() + if obj == self.defaulthotkey: + self.hkholder = self.defaulthotkey.text() + self.defaulthotkey.setText("") + self.defaulthotkey.setPlaceholderText("Input key sequence") + self.canedit = True + if event.type() == QtCore.QEvent.Enter: + self.parentwindow.checktooltip(obj) + if event.type() == QtCore.QEvent.Leave: + self.parentwindow.checktooltip(obj, True) + if event.type() == QtCore.QEvent.ToolTip: + return True + + # ---------------------------------------- Keypress + # NOTE Keypress ----------------------------------- if event.type() == QtCore.QEvent.KeyPress: if event.key() == QtCore.Qt.Key_D: - if not self.debugflag.isVisible(): - self.debugflag.setVisible(True) + if obj != self.defaulthotkey: + if not self.debugflag.isVisible(): + self.debugflag.setVisible(True) if event.key() == QtCore.Qt.Key_Escape: - if self.canedit is False: + if self.performcheck: + if self.checkforchanges(): + self.savecheck() + if self.animatedsettings: + self.parentwindow.anim.start_animation(False) + self.isopened = True + self.performcheck=True + else: self.close() else: self.keyindex += 1 self.keystring = hou.qt.qtKeyToString( - event.key(), - int(event.modifiers()), + event.key(), + int(event.modifiers()), event.text() ) if self.canedit: if self.keystring not in ["Esc", "Backspace"]: - if self.hkinput.hasFocus(): - self.KeySequence = QtGui.QKeySequence( - self.keystring).toString() - self.hkinput.setText(self.KeySequence) + if self.defaulthotkey.hasFocus(): + self.KeySequence = QtGui.QKeySequence(self.keystring).toString() + self.defaulthotkey.setText(self.KeySequence) if self.keystring in ["Esc", "Backspace"]: - self.hkinput.setText(self.hkholder) + self.defaulthotkey.setText(self.hkholder) + # -------------------------------------- Keyrelease + # NOTE Keyrelease 
--------------------------------- if event.type() == QtCore.QEvent.KeyRelease: if event.key() == QtCore.Qt.Key_Escape: return QtCore.QObject.eventFilter(self, obj, event) else: self.keyindex -= 1 if self.keyindex == 0: - if self.hkinput.text() == "": - self.hkinput.setText(self.hkholder) - if self.hkinput.text() != "": + if self.defaulthotkey.text() == "": + self.defaulthotkey.setText(self.hkholder) + if self.defaulthotkey.text() != "": self.canedit = False - return False + + # ------------------------------------------- Close + # NOTE Close -------------------------------------- + if event.type() == QtCore.QEvent.Close: + self.ui.isopened = False + self.parentwindow.opensettingstool.setChecked(False) + self.performcheck=True + + return QtCore.QObject.eventFilter(self, obj, event) + + +class LinkLabel(QtWidgets.QLabel): + def __init__(self, parent, text): + super(LinkLabel, self).__init__(parent) + + self.setText(text) + self.setTextFormat(Qt.RichText) + self.setTextInteractionFlags(Qt.TextBrowserInteraction) + self.setOpenExternalLinks(True) \ No newline at end of file diff --git a/scripts/python/searcher/searcher_settings_bak.py b/scripts/python/searcher/searcher_settings_bak.py new file mode 100644 index 0000000..53e2cbf --- /dev/null +++ b/scripts/python/searcher/searcher_settings_bak.py @@ -0,0 +1,297 @@ +from __future__ import division +from __future__ import print_function +from __future__ import absolute_import + +from searcher import searcher_data +from searcher import util + +from builtins import range +from past.utils import old_div +import platform +import os + +import sys +import hou +import hdefereval +from hutil import py23 +hver = 0 +if os.environ["HFS"] != "": + ver = os.environ["HFS"] + hver = int(ver[ver.rindex('.')+1:]) + from hutil.Qt import QtGui + from hutil.Qt import QtCore + from hutil.Qt import QtWidgets + if hver >= 395: + from hutil.Qt import QtUiTools + elif hver <= 394 and hver >= 391: + from hutil.Qt import _QtUiTools + elif hver < 
391 and hver >= 348: + from hutil.Qt import QtUiTools + + +# -------------------------------------------------------------------- App Info +__package__ = "Searcher" +__version__ = "0.1b" +__author__ = "instance.id" +__copyright__ = "2020 All rights reserved. See LICENSE for more details." +__status__ = "Prototype" +# endregion + +the_scaled_icon_size = hou.ui.scaledSize(16) +the_icon_size = 16 + +num = 0 +# info +__author__ = "instance.id" +__copyright__ = "2020 All rights reserved." +__status__ = "Prototype" + +scriptpath = os.path.dirname(os.path.realpath(__file__)) + + +def bc(v): + return str(v).lower() in ("yes", "true", "t", "1") + + +class SearcherSettings(QtWidgets.QWidget): + """ Searcher Settings and Debug Menu""" + + def __init__(self, handler, tmphotkey, parent=None): + super(SearcherSettings, self).__init__(parent=parent) + + # ------------------------------------------------- Component variables + self.settings = {} + self.context_dict = {} + self.command_dict = {} + self.contexts = None + self.commands = None + self.addKeyWidget = None + self.context_data = None + self.command_data = None + self.keys_changed = False + self.keystring = "" + self.keyindex = 0 + self.canedit = False + self.KeySequence = None + self.hkholder = "" + self.defaulthotkey = tmphotkey + self.datahandler = handler + self.tmphotkey = tmphotkey + + self.setObjectName('searcher-settings') + # ------------------------------------------------- Build UI + self.setAutoFillBackground(True) + self.setBackgroundRole(QtGui.QPalette.Window) + self.settings = searcher_data.loadsettings() + self.isdebug = bc(self.settings[util.SETTINGS_KEYS[4]]) + + # Load UI File + loader = None + if int(hver) >= 391 and int(hver) <= 394: + loader = _QtUiTools.QUiLoader() + else: + loader = QtUiTools.QUiLoader() + self.ui = loader.load(scriptpath + '/searchersettings.ui') + + # Get UI Elements + self.hotkey_icon = self.ui.findChild( + QtWidgets.QToolButton, + "hotkey_icon" + ) + self.debugflag = 
self.ui.findChild( + QtWidgets.QCheckBox, + "debugflag_chk" + ) + self.in_memory_db = self.ui.findChild( + QtWidgets.QCheckBox, + "inmemory_chk" + ) + self.savewindowsize = self.ui.findChild( + QtWidgets.QCheckBox, + "windowsize_chk" + ) + self.defaulthotkey = self.ui.findChild( + QtWidgets.QLineEdit, + "defaulthotkey_txt" + ) + self.database_path = self.ui.findChild( + QtWidgets.QLineEdit, + "databasepath_txt" + ) + self.test1 = self.ui.findChild( + QtWidgets.QPushButton, + "test1_btn" + ) + self.cleardata = self.ui.findChild( + QtWidgets.QPushButton, + "cleardata_btn" + ) + self.savedata = self.ui.findChild( + QtWidgets.QPushButton, + "save_btn" + ) + self.discarddata = self.ui.findChild( + QtWidgets.QPushButton, + "discard_btn" + ) + + mainlayout = QtWidgets.QVBoxLayout() + mainlayout.addWidget(self.ui) + + # ------------------------------------------------- Create Connections + # self.in_memory_db.stateChanged.connect(self.toggledebug) + self.hotkey_icon.clicked.connect(self.hotkeyicon_cb) + self.hotkey_icon.setIcon(util.INFO_ICON) + info_button_size = hou.ui.scaledSize(16) + self.hotkey_icon.setProperty("flat", True) + self.hotkey_icon.setIcon(util.INFO_ICON) + self.hotkey_icon.setIconSize(QtCore.QSize( + info_button_size, + info_button_size + )) + + self.defaulthotkey.setText(self.tmphotkey) + self.defaulthotkey.setStatusTip("Status Tip?") + self.defaulthotkey.setWhatsThis("Whats this?") + # self.defaulthotkey.setToolTip( + # "If left to the default value of (Ctrl+Alt+Shift+F7), in the event that Searcher detects a conflict it will automatically attempt to try different key combinations.") + self.defaulthotkey.setStyleSheet(util.TOOLTIP) + self.database_path.setText(str(self.settings['database_path'])) + self.test1.clicked.connect(self.test1_cb) + self.cleardata.clicked.connect(self.cleardata_cb) + self.savedata.clicked.connect(self.save_cb) + self.discarddata.clicked.connect(self.discard_cb) + + # ------------------------------------------------- Apply 
Layout + self.setLayout(mainlayout) + self.installEventFilter(self) + + self.debugflag.setChecked(bc(self.settings[util.SETTINGS_KEYS[4]])) + self.debugflag.setVisible(bc(self.settings[util.SETTINGS_KEYS[4]])) + self.in_memory_db.setChecked(bc(self.settings[util.SETTINGS_KEYS[0]])) + self.savewindowsize.setChecked( + bc(self.settings[util.SETTINGS_KEYS[2]])) + + # ------------------------------------------------- Add EventFilters + self.defaulthotkey.installEventFilter(self) + self.debugflag.installEventFilter(self) + + # ----------------------------------------------------------------------------------- Callbacks + def hotkeyicon_cb(self): + self.settings['in_memory_db'] = self.in_memory_db.isChecked() + print(self.settings['in_memory_db']) + + def toggledebug(self): + self.settings['in_memory_db'] = self.in_memory_db.isChecked() + print(self.settings['in_memory_db']) + + def defaulthk_cb(self): + return + + def test1_cb(self): + hkeys = [] + for i in range(len(util.HOTKEYLIST)): + result = hou.hotkeys.findConflicts("h", util.HOTKEYLIST[i]) + if result: + print ("Confliction found: {}".format(result)) + else: + print("No Confliction: {}".format(result)) + hkeys.append(result) + print (hkeys) + + def cleardata_cb(self): + self.datahandler.cleardb() + + def save_cb(self): + if self.defaulthotkey.text() == "": + buttonindex = hou.ui.displayMessage("Please enter a hotkey") + self.activateWindow() + self.defaulthotkey.setFocus() + self.canedit = True + else: + if self.defaulthotkey.text() != self.tmphotkey: + self.tmphotkey = self.defaulthotkey.text() + self.datahandler.updatetmphotkey(self.tmphotkey) + + for i in range(len(util.SETTINGS_KEYS)): + if util.SETTINGS_TYPES[util.SETTINGS_KEYS[i]] == "bool": + self.settings[util.SETTINGS_KEYS[i]] = getattr( + self, util.SETTINGS_KEYS[i]).isChecked() + elif util.SETTINGS_TYPES[util.SETTINGS_KEYS[i]] == "text": + self.settings[util.SETTINGS_KEYS[i]] = getattr( + self, util.SETTINGS_KEYS[i]).text() + + if self.isdebug: + 
print(self.settings) + + searcher_data.savesettings(self.settings) + self.close() + + def discard_cb(self): + self.defaulthotkey.setText(self.tmphotkey) + self.hkholder = "" + self.close() + + # ----------------------------------------------------------------------------------- Actions + def savecheck(self): + buttonindex = hou.ui.displayMessage( + "Save changes?", buttons=('Save', 'Discard'), default_choice=0, title="Unsaved Changes:",) + if buttonindex == 0: + self.tmphotkey = self.defaulthotkey.text() + self.datahandler.updatetmphotkey(self.tmphotkey) + self.hkholder = "" + elif buttonindex == 1: + self.defaulthotkey.setText(self.hkholder) + self.hkholder = "" + + # ----------------------------------------------------------------------------------- Events + def eventFilter(self, obj, event): + # ------------------------------------------------- Mouse + if event.type() == QtCore.QEvent.MouseButtonDblClick: + self.hkholder = self.defaulthotkey.text() + self.defaulthotkey.setText("") + self.defaulthotkey.setPlaceholderText("Input key sequence") + self.canedit = True + # ------------------------------------------------- Keypress + if event.type() == QtCore.QEvent.KeyPress: + if event.key() == QtCore.Qt.Key_D: + if not self.debugflag.isVisible(): + self.debugflag.setVisible(True) + + if event.key() == QtCore.Qt.Key_Escape: + if self.canedit is False: + self.close() + else: + self.keyindex += 1 + self.keystring = hou.qt.qtKeyToString( + event.key(), + int(event.modifiers()), + event.text() + ) + if self.canedit: + if self.keystring not in ["Esc", "Backspace"]: + if self.defaulthotkey.hasFocus(): + self.KeySequence = QtGui.QKeySequence( + self.keystring).toString() + self.defaulthotkey.setText(self.KeySequence) + if self.keystring in ["Esc", "Backspace"]: + self.defaulthotkey.setText(self.hkholder) + + # ------------------------------------------------- Keyrelease + if event.type() == QtCore.QEvent.KeyRelease: + if event.key() == QtCore.Qt.Key_Escape: + return 
QtCore.QObject.eventFilter(self, obj, event) + else: + self.keyindex -= 1 + if self.keyindex == 0: + if self.defaulthotkey.text() == "": + self.defaulthotkey.setText(self.hkholder) + if self.defaulthotkey.text() != "": + self.canedit = False + + # ------------------------------------------------- Close + if event.type() == QtCore.QEvent.Close: + if self.canedit is False and self.hkholder != "": + self.savecheck() + + return QtCore.QObject.eventFilter(self, obj, event) diff --git a/scripts/python/searcher/searcher_ui.py b/scripts/python/searcher/searcher_ui.py deleted file mode 100644 index 8c0b147..0000000 --- a/scripts/python/searcher/searcher_ui.py +++ /dev/null @@ -1,185 +0,0 @@ -# -*- coding: utf-8 -*- - -################################################################################ -## Form generated from reading UI file 'searcher_ui.ui' -## -## Created by: Qt User Interface Compiler version 5.14.1 -## -## WARNING! All changes made in this file will be lost when recompiling UI file! 
-################################################################################ - -from PySide2.QtCore import (QCoreApplication, QMetaObject, QObject, QPoint, - QRect, QSize, QUrl, Qt) -from PySide2.QtGui import (QBrush, QColor, QConicalGradient, QCursor, QFont, - QFontDatabase, QIcon, QLinearGradient, QPalette, QPainter, QPixmap, - QRadialGradient) -from PySide2.QtWidgets import * - -from .HelpButton import HelpButton - - -class Ui_Searcher(object): - def setupUi(self, Searcher): - if Searcher.objectName(): - Searcher.setObjectName(u"Searcher") - Searcher.setWindowModality(Qt.WindowModal) - Searcher.resize(1000, 329) - sizePolicy = QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) - sizePolicy.setHorizontalStretch(0) - sizePolicy.setVerticalStretch(0) - sizePolicy.setHeightForWidth(Searcher.sizePolicy().hasHeightForWidth()) - Searcher.setSizePolicy(sizePolicy) - Searcher.setMinimumSize(QSize(0, 0)) - Searcher.setBaseSize(QSize(1000, 350)) - Searcher.setStyleSheet(u"QTreeWidget QHeaderView::section {\n" -" font-size: 9pt;\n" -"}") - self.gridLayout = QGridLayout(Searcher) - self.gridLayout.setSpacing(0) - self.gridLayout.setObjectName(u"gridLayout") - self.gridLayout.setContentsMargins(0, 0, 0, 0) - self.verticalLayout = QVBoxLayout() - self.verticalLayout.setSpacing(0) - self.verticalLayout.setObjectName(u"verticalLayout") - self.horizontalLayout = QHBoxLayout() - self.horizontalLayout.setSpacing(0) - self.horizontalLayout.setObjectName(u"horizontalLayout") - self.horizontalSpacer_2 = QSpacerItem(8, 2, QSizePolicy.Fixed, QSizePolicy.Minimum) - - self.horizontalLayout.addItem(self.horizontalSpacer_2) - - self.projectTitle = QLabel(Searcher) - self.projectTitle.setObjectName(u"projectTitle") - sizePolicy1 = QSizePolicy(QSizePolicy.Fixed, QSizePolicy.Preferred) - sizePolicy1.setHorizontalStretch(0) - sizePolicy1.setVerticalStretch(0) - sizePolicy1.setHeightForWidth(self.projectTitle.sizePolicy().hasHeightForWidth()) - 
self.projectTitle.setSizePolicy(sizePolicy1) - font = QFont() - font.setPointSize(15) - self.projectTitle.setFont(font) - self.projectTitle.setAlignment(Qt.AlignCenter) - - self.horizontalLayout.addWidget(self.projectTitle) - - self.horizontalSpacer = QSpacerItem(40, 5, QSizePolicy.Expanding, QSizePolicy.Minimum) - - self.horizontalLayout.addItem(self.horizontalSpacer) - - self.HelpButton = HelpButton(Searcher) - self.HelpButton.setObjectName(u"HelpButton") - - self.horizontalLayout.addWidget(self.HelpButton) - - self.pinwindow_btn = QToolButton(Searcher) - self.pinwindow_btn.setObjectName(u"pinwindow_btn") - - self.horizontalLayout.addWidget(self.pinwindow_btn) - - self.opensettings_btn = QToolButton(Searcher) - self.opensettings_btn.setObjectName(u"opensettings_btn") - - self.horizontalLayout.addWidget(self.opensettings_btn) - - self.horizontalSpacer_3 = QSpacerItem(8, 2, QSizePolicy.Fixed, QSizePolicy.Minimum) - - self.horizontalLayout.addItem(self.horizontalSpacer_3) - - - self.verticalLayout.addLayout(self.horizontalLayout) - - self.horizontalLayout_3 = QHBoxLayout() - self.horizontalLayout_3.setSpacing(0) - self.horizontalLayout_3.setObjectName(u"horizontalLayout_3") - self.frame = QFrame(Searcher) - self.frame.setObjectName(u"frame") - sizePolicy2 = QSizePolicy(QSizePolicy.Preferred, QSizePolicy.Preferred) - sizePolicy2.setHorizontalStretch(2) - sizePolicy2.setVerticalStretch(0) - sizePolicy2.setHeightForWidth(self.frame.sizePolicy().hasHeightForWidth()) - self.frame.setSizePolicy(sizePolicy2) - self.frame.setMinimumSize(QSize(0, 20)) - self.frame.setFrameShape(QFrame.StyledPanel) - self.frame.setFrameShadow(QFrame.Raised) - self.searchfilter_btn = QToolButton(self.frame) - self.searchfilter_btn.setObjectName(u"searchfilter_btn") - self.searchfilter_btn.setGeometry(QRect(0, 0, 26, 20)) - self.searchfilter_btn.setBaseSize(QSize(16, 16)) - self.searchfilter_btn.setStyleSheet(u"background-color: rgb(19, 19, 19);") - 
self.searchfilter_btn.setArrowType(Qt.NoArrow) - - self.horizontalLayout_3.addWidget(self.frame) - - self.searchbox_txt = QLineEdit(Searcher) - self.searchbox_txt.setObjectName(u"searchbox_txt") - sizePolicy3 = QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Minimum) - sizePolicy3.setHorizontalStretch(99) - sizePolicy3.setVerticalStretch(0) - sizePolicy3.setHeightForWidth(self.searchbox_txt.sizePolicy().hasHeightForWidth()) - self.searchbox_txt.setSizePolicy(sizePolicy3) - self.searchbox_txt.setMinimumSize(QSize(50, 0)) - self.searchbox_txt.setMouseTracking(False) - self.searchbox_txt.setStyleSheet(u"background-color: rgb(19, 19, 19);") - self.searchbox_txt.setFrame(False) - - self.horizontalLayout_3.addWidget(self.searchbox_txt) - - - self.verticalLayout.addLayout(self.horizontalLayout_3) - - self.searchresults_tree = QTreeWidget(Searcher) - __qtreewidgetitem = QTreeWidgetItem() - __qtreewidgetitem.setText(0, u"1"); - self.searchresults_tree.setHeaderItem(__qtreewidgetitem) - self.searchresults_tree.setObjectName(u"searchresults_tree") - sizePolicy4 = QSizePolicy(QSizePolicy.Preferred, QSizePolicy.Expanding) - sizePolicy4.setHorizontalStretch(0) - sizePolicy4.setVerticalStretch(0) - sizePolicy4.setHeightForWidth(self.searchresults_tree.sizePolicy().hasHeightForWidth()) - self.searchresults_tree.setSizePolicy(sizePolicy4) - font1 = QFont() - font1.setPointSize(9) - self.searchresults_tree.setFont(font1) - self.searchresults_tree.setMouseTracking(False) - self.searchresults_tree.setFocusPolicy(Qt.NoFocus) - self.searchresults_tree.setFrameShadow(QFrame.Sunken) - self.searchresults_tree.setLineWidth(0) - self.searchresults_tree.setSizeAdjustPolicy(QAbstractScrollArea.AdjustToContents) - self.searchresults_tree.setAlternatingRowColors(True) - self.searchresults_tree.setSelectionMode(QAbstractItemView.SingleSelection) - self.searchresults_tree.setSelectionBehavior(QAbstractItemView.SelectRows) - - self.verticalLayout.addWidget(self.searchresults_tree) - - - 
self.gridLayout.addLayout(self.verticalLayout, 1, 0, 1, 1) - - self.info_lbl = QLabel(Searcher) - self.info_lbl.setObjectName(u"info_lbl") - font2 = QFont() - font2.setPointSize(8) - font2.setBold(False) - font2.setWeight(50) - self.info_lbl.setFont(font2) - self.info_lbl.setStyleSheet(u"background-color: rgb(26, 26, 26);") - self.info_lbl.setMargin(2) - self.info_lbl.setIndent(5) - - self.gridLayout.addWidget(self.info_lbl, 2, 0, 1, 1) - - - self.retranslateUi(Searcher) - - QMetaObject.connectSlotsByName(Searcher) - # setupUi - - def retranslateUi(self, Searcher): - Searcher.setWindowTitle(QCoreApplication.translate("Searcher", u"Searcher", None)) - self.projectTitle.setText(QCoreApplication.translate("Searcher", u"Searcher", None)) - self.HelpButton.setText(QCoreApplication.translate("Searcher", u"...", None)) - self.pinwindow_btn.setText(QCoreApplication.translate("Searcher", u"...", None)) - self.opensettings_btn.setText(QCoreApplication.translate("Searcher", u"...", None)) - self.searchfilter_btn.setText(QCoreApplication.translate("Searcher", u"...", None)) - self.info_lbl.setText(QCoreApplication.translate("Searcher", u"Begin typing to search or click magnifying glass icon to display options", None)) - # retranslateUi - diff --git a/scripts/python/searcher/searcher_ui.ui.bak b/scripts/python/searcher/searcher_ui.ui.bak deleted file mode 100644 index 519e1e7..0000000 --- a/scripts/python/searcher/searcher_ui.ui.bak +++ /dev/null @@ -1,316 +0,0 @@ - - - Searcher - - - Qt::WindowModal - - - - 0 - 0 - 1000 - 329 - - - - - 0 - 0 - - - - - 0 - 0 - - - - - 1000 - 350 - - - - Searcher - - - QTreeWidget QHeaderView::section { - font-size: 9pt; -} - - - - 0 - - - 0 - - - 0 - - - 0 - - - 0 - - - - - 0 - - - - - 0 - - - - - Qt::Horizontal - - - QSizePolicy::Fixed - - - - 8 - 2 - - - - - - - - - 0 - 0 - - - - - 15 - - - - Searcher - - - Qt::AlignCenter - - - - - - - Qt::Horizontal - - - - 40 - 5 - - - - - - - - ... - - - - - - - ... - - - - - - - ... 
- - - - - - - Qt::Horizontal - - - QSizePolicy::Fixed - - - - 8 - 2 - - - - - - - - - - 0 - - - - - - 2 - 0 - - - - - 0 - 20 - - - - QFrame::StyledPanel - - - QFrame::Raised - - - - - 0 - 0 - 26 - 20 - - - - - 16 - 16 - - - - background-color: rgb(19, 19, 19); - - - ... - - - Qt::NoArrow - - - - - - - - - 99 - 0 - - - - - 50 - 0 - - - - false - - - background-color: rgb(19, 19, 19); - - - false - - - - - - - - - - 0 - 0 - - - - - 9 - - - - false - - - Qt::NoFocus - - - QFrame::Sunken - - - 0 - - - QAbstractScrollArea::AdjustToContents - - - true - - - QAbstractItemView::SingleSelection - - - QAbstractItemView::SelectRows - - - - 1 - - - - - - - - - - - 8 - 50 - false - - - - background-color: rgb(26, 26, 26); - - - Begin typing to search or click magnifying glass icon to display options - - - 2 - - - 5 - - - - - - - - HelpButton - QPushButton -
.HelpButton
-
-
- - -
diff --git a/scripts/python/searcher/searcher_util.py b/scripts/python/searcher/searcher_util.py deleted file mode 100644 index 34e4aed..0000000 --- a/scripts/python/searcher/searcher_util.py +++ /dev/null @@ -1,124 +0,0 @@ -# -*- coding: utf-8 -*- - -################################################################################ -## Form generated from reading UI file 'searcher_util.ui' -## -## Created by: Qt User Interface Compiler version 5.14.1 -## -## WARNING! All changes made in this file will be lost when recompiling UI file! -################################################################################ - -from PySide2.QtCore import (QCoreApplication, QMetaObject, QObject, QPoint, - QRect, QSize, QUrl, Qt) -from PySide2.QtGui import (QBrush, QColor, QConicalGradient, QCursor, QFont, - QFontDatabase, QIcon, QLinearGradient, QPalette, QPainter, QPixmap, - QRadialGradient) -from PySide2.QtWidgets import * - - -class Ui_Searcher(object): - def setupUi(self, Searcher): - if Searcher.objectName(): - Searcher.setObjectName(u"Searcher") - Searcher.setWindowModality(Qt.NonModal) - Searcher.resize(1000, 113) - sizePolicy = QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) - sizePolicy.setHorizontalStretch(0) - sizePolicy.setVerticalStretch(0) - sizePolicy.setHeightForWidth(Searcher.sizePolicy().hasHeightForWidth()) - Searcher.setSizePolicy(sizePolicy) - Searcher.setMinimumSize(QSize(0, 0)) - Searcher.setBaseSize(QSize(1000, 350)) - self.gridLayout = QGridLayout(Searcher) - self.gridLayout.setObjectName(u"gridLayout") - self.verticalLayout_4 = QVBoxLayout() - self.verticalLayout_4.setObjectName(u"verticalLayout_4") - self.horizontalLayout_4 = QHBoxLayout() - self.horizontalLayout_4.setObjectName(u"horizontalLayout_4") - self.projectTitle = QLabel(Searcher) - self.projectTitle.setObjectName(u"projectTitle") - font = QFont() - font.setPointSize(15) - self.projectTitle.setFont(font) - self.projectTitle.setAlignment(Qt.AlignCenter) - - 
self.horizontalLayout_4.addWidget(self.projectTitle) - - self.gocommand_txt = QLineEdit(Searcher) - self.gocommand_txt.setObjectName(u"gocommand_txt") - - self.horizontalLayout_4.addWidget(self.gocommand_txt) - - self.sendgocommand_btn = QPushButton(Searcher) - self.sendgocommand_btn.setObjectName(u"sendgocommand_btn") - - self.horizontalLayout_4.addWidget(self.sendgocommand_btn) - - - self.verticalLayout_4.addLayout(self.horizontalLayout_4) - - self.horizontalLayout = QHBoxLayout() - self.horizontalLayout.setObjectName(u"horizontalLayout") - self.opensearchwindow_btn = QPushButton(Searcher) - self.opensearchwindow_btn.setObjectName(u"opensearchwindow_btn") - - self.horizontalLayout.addWidget(self.opensearchwindow_btn) - - self.gethotkeys_btn = QPushButton(Searcher) - self.gethotkeys_btn.setObjectName(u"gethotkeys_btn") - - self.horizontalLayout.addWidget(self.gethotkeys_btn) - - self.updatehotkeys_btn = QPushButton(Searcher) - self.updatehotkeys_btn.setObjectName(u"updatehotkeys_btn") - - self.horizontalLayout.addWidget(self.updatehotkeys_btn) - - self.loadhotkeys_btn = QPushButton(Searcher) - self.loadhotkeys_btn.setObjectName(u"loadhotkeys_btn") - - self.horizontalLayout.addWidget(self.loadhotkeys_btn) - - - self.verticalLayout_4.addLayout(self.horizontalLayout) - - - self.gridLayout.addLayout(self.verticalLayout_4, 1, 0, 1, 1) - - self.horizontalLayout_2 = QHBoxLayout() - self.horizontalLayout_2.setObjectName(u"horizontalLayout_2") - self.callgo_btn = QPushButton(Searcher) - self.callgo_btn.setObjectName(u"callgo_btn") - - self.horizontalLayout_2.addWidget(self.callgo_btn) - - self.other_btn = QPushButton(Searcher) - self.other_btn.setObjectName(u"other_btn") - - self.horizontalLayout_2.addWidget(self.other_btn) - - - self.gridLayout.addLayout(self.horizontalLayout_2, 2, 0, 1, 1) - - self.verticalSpacer = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding) - - self.gridLayout.addItem(self.verticalSpacer, 3, 0, 1, 1) - - - 
self.retranslateUi(Searcher) - - QMetaObject.connectSlotsByName(Searcher) - # setupUi - - def retranslateUi(self, Searcher): - Searcher.setWindowTitle(QCoreApplication.translate("Searcher", u"Form", None)) - self.projectTitle.setText(QCoreApplication.translate("Searcher", u"Searcher Settings", None)) - self.sendgocommand_btn.setText(QCoreApplication.translate("Searcher", u"Send Go Cmd", None)) - self.opensearchwindow_btn.setText(QCoreApplication.translate("Searcher", u"Open Search", None)) - self.gethotkeys_btn.setText(QCoreApplication.translate("Searcher", u"Get Hotkeys", None)) - self.updatehotkeys_btn.setText(QCoreApplication.translate("Searcher", u"Update HotKeys", None)) - self.loadhotkeys_btn.setText(QCoreApplication.translate("Searcher", u"Load Hotkeys", None)) - self.callgo_btn.setText(QCoreApplication.translate("Searcher", u"Call Go", None)) - self.other_btn.setText(QCoreApplication.translate("Searcher", u"Clear Data", None)) - # retranslateUi - diff --git a/scripts/python/searcher/searcher_util.ui b/scripts/python/searcher/searcher_util.ui deleted file mode 100644 index 929d64d..0000000 --- a/scripts/python/searcher/searcher_util.ui +++ /dev/null @@ -1,138 +0,0 @@ - - - Searcher - - - Qt::NonModal - - - - 0 - 0 - 1000 - 113 - - - - - 0 - 0 - - - - - 0 - 0 - - - - - 1000 - 350 - - - - Form - - - - - - - - - - - 15 - - - - Searcher Settings - - - Qt::AlignCenter - - - - - - - - - - Send Go Cmd - - - - - - - - - - - Open Search - - - - - - - Get Hotkeys - - - - - - - Update HotKeys - - - - - - - Load Hotkeys - - - - - - - - - - - - - Call Go - - - - - - - Clear Data - - - - - - - - - Qt::Vertical - - - - 20 - 40 - - - - - - - - - diff --git a/scripts/python/searcher/searchersettings_ui.py b/scripts/python/searcher/searchersettings_ui.py new file mode 100644 index 0000000..636c315 --- /dev/null +++ b/scripts/python/searcher/searchersettings_ui.py @@ -0,0 +1,239 @@ +# -*- coding: utf-8 -*- +from hutil.Qt import QtCore, QtGui, QtWidgets + +def bc(v): + return 
str(v).lower() in ("yes", "true", "t", "1") + + +class Ui_SearcherSettings(object): + def setupUi(self, SearcherSettings, width, height, animated): + self.width = width + self.height = height + self.animated = animated + + SearcherSettings.setObjectName("SearcherSettings") + SearcherSettings.setWindowModality(QtCore.Qt.NonModal) + sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.Preferred) + sizePolicy.setHorizontalStretch(0) + sizePolicy.setVerticalStretch(0) + sizePolicy.setHeightForWidth(SearcherSettings.sizePolicy().hasHeightForWidth()) + SearcherSettings.setSizePolicy(sizePolicy) + SearcherSettings.setMinimumSize(QtCore.QSize(450, 300)) + SearcherSettings.setBaseSize(QtCore.QSize(0, 0)) + + self.gridLayout = QtWidgets.QGridLayout(SearcherSettings) + self.gridLayout.setContentsMargins(-1, -1, -1, -1) + self.gridLayout.setObjectName("gridLayout") + self.verticallayout = QtWidgets.QVBoxLayout() + self.verticallayout.setObjectName("verticalLayout") + self.verticallayout.setSpacing(10) + + # ------------------------------------------------- headerrow + # NOTE headerrow -------------------------------------------- + self.headerrow = QtWidgets.QHBoxLayout() + self.headerrow.setObjectName("headerrow") + + self.projectTitle = QtWidgets.QLabel(SearcherSettings) + font = QtGui.QFont() + font.setPointSize(15) + self.projectTitle.setFont(font) + self.projectTitle.setAlignment(QtCore.Qt.AlignCenter) + self.projectTitle.setObjectName("projectTitle") + self.headerrow.addWidget(self.projectTitle) + + spacerItem = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum) + self.headerrow.addItem(spacerItem) + + self.animatedsettings_chk = QtWidgets.QCheckBox(SearcherSettings) + self.animatedsettings_chk.setLayoutDirection(QtCore.Qt.RightToLeft) + self.animatedsettings_chk.setObjectName("animatedsettings_chk") + self.headerrow.addWidget(self.animatedsettings_chk) + + self.windowsize_chk = 
QtWidgets.QCheckBox(SearcherSettings) + self.windowsize_chk.setLayoutDirection(QtCore.Qt.RightToLeft) + self.windowsize_chk.setObjectName("windowsize_chk") + self.headerrow.addWidget(self.windowsize_chk) + self.verticallayout.addLayout(self.headerrow) + + self.line = QtWidgets.QFrame(SearcherSettings) + self.line.setFrameShape(QtWidgets.QFrame.HLine) + self.line.setFrameShadow(QtWidgets.QFrame.Sunken) + self.line.setObjectName("line") + self.verticallayout.addWidget(self.line) + + # ------------------------------------------------- secondrow + # NOTE Second Row ------------------------------------------- + self.secondrow = QtWidgets.QHBoxLayout() + self.secondrow.setObjectName("secondrow") + + self.lang_cbox = QtWidgets.QComboBox(SearcherSettings) + self.lang_cbox.setObjectName("lang_cbox") + self.lang_cbox.addItem("") + self.secondrow.addWidget(self.lang_cbox) + + spacerItem = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum) + self.secondrow.addItem(spacerItem) + + self.maxresults_lbl = QtWidgets.QLabel(SearcherSettings) + self.maxresults_lbl.setObjectName("label_3") + self.secondrow.addWidget(self.maxresults_lbl) + self.maxresults_txt = QtWidgets.QSpinBox(SearcherSettings) + self.maxresults_txt.setMinimum(1) + self.maxresults_txt.setMaximum(9999) + self.maxresults_txt.setObjectName("maxresults_txt") + self.secondrow.addWidget(self.maxresults_txt) + + self.inmemory_chk = QtWidgets.QCheckBox(SearcherSettings) + self.inmemory_chk.setLayoutDirection(QtCore.Qt.RightToLeft) + self.inmemory_chk.setTristate(False) + self.inmemory_chk.setObjectName("inmemory_chk") + self.secondrow.addWidget(self.inmemory_chk) + + self.verticallayout.addLayout(self.secondrow) + + # -------------------------------------------------- thirdrow + # NOTE Third Row -------------------------------------------- + self.thirdrow = QtWidgets.QHBoxLayout() + self.thirdrow.setObjectName("thirdrow") + + self.defaulthotkey_lbl = 
QtWidgets.QLabel(SearcherSettings) + self.defaulthotkey_lbl.setObjectName("label_2") + self.thirdrow.addWidget(self.defaulthotkey_lbl) + + self.defaulthotkey_txt = QtWidgets.QLineEdit(SearcherSettings) + self.defaulthotkey_txt.setToolTip("") + self.defaulthotkey_txt.setReadOnly(True) + self.defaulthotkey_txt.setObjectName("defaulthotkey_txt") + self.thirdrow.addWidget(self.defaulthotkey_txt) + + self.hotkey_icon = QtWidgets.QToolButton(SearcherSettings) + self.hotkey_icon.setPopupMode(QtWidgets.QToolButton.InstantPopup) + self.hotkey_icon.setObjectName("hotkey_icon") + self.thirdrow.addWidget(self.hotkey_icon) + self.verticallayout.addLayout(self.thirdrow) + + # ------------------------------------------------- fourthrow + # NOTE Fourth Row ------------------------------------------- + self.fourthrow = QtWidgets.QHBoxLayout() + self.fourthrow.setObjectName("fourthrow") + + self.dbpath_lbl = QtWidgets.QLabel(SearcherSettings) + self.dbpath_lbl.setObjectName("label") + self.fourthrow.addWidget(self.dbpath_lbl) + + self.databasepath_txt = QtWidgets.QLineEdit(SearcherSettings) + self.databasepath_txt.setObjectName("databasepath_txt") + self.fourthrow.addWidget(self.databasepath_txt) + + self.dbpath_icon = QtWidgets.QToolButton(SearcherSettings) + self.dbpath_icon.setObjectName("dbpath_icon") + self.fourthrow.addWidget(self.dbpath_icon) + + self.verticallayout.addLayout(self.fourthrow) + + # -------------------------------------------------- fifthrow + # NOTE Fifth Row -------------------------------------------- + self.fifthrow = QtWidgets.QHBoxLayout() + self.fifthrow.setObjectName("fifthrow") + + self.maint_lbl = QtWidgets.QLabel(SearcherSettings) + self.maint_lbl.setObjectName("label_4") + self.fifthrow.addWidget(self.maint_lbl) + + self.test1_btn = QtWidgets.QPushButton(SearcherSettings) + self.test1_btn.setObjectName("test1_btn") + self.fifthrow.addWidget(self.test1_btn) + + self.cleardata_btn = QtWidgets.QPushButton(SearcherSettings) + 
self.cleardata_btn.setObjectName("cleardata_btn") + self.fifthrow.addWidget(self.cleardata_btn) + + self.verticallayout.addLayout(self.fifthrow) + + # ---------------------------------------------------- Spacer + self.line2 = QtWidgets.QFrame(SearcherSettings) + self.line2.setFrameShape(QtWidgets.QFrame.HLine) + self.line2.setFrameShadow(QtWidgets.QFrame.Sunken) + self.line2.setObjectName("line2") + self.verticallayout.addWidget(self.line2) + # ---------------------------------------------------- Spacer + + # -------------------------------------------------- sixthrow + # NOTE Sixth Row -------------------------------------------- + self.sixthrow = QtWidgets.QHBoxLayout() + self.sixthrow.setObjectName("sixthrow") + + self.about_btn = QtWidgets.QToolButton(SearcherSettings) + self.about_btn.setObjectName("about_btn") + self.sixthrow.addWidget(self.about_btn) + + self.bug_btn = QtWidgets.QToolButton(SearcherSettings) + self.bug_btn.setObjectName("bug_btn") + self.sixthrow.addWidget(self.bug_btn) + + spacerItem1 = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum) + self.sixthrow.addItem(spacerItem1) + + self.debuglevel_cbx = QtWidgets.QComboBox(SearcherSettings) + self.debuglevel_cbx.setObjectName("debuglevel_cbx") + self.sixthrow.addWidget(self.debuglevel_cbx) + + self.debugflag_chk = QtWidgets.QCheckBox(SearcherSettings) + self.debugflag_chk.setLayoutDirection(QtCore.Qt.RightToLeft) + self.debugflag_chk.setObjectName("debugflag_chk") + self.sixthrow.addWidget(self.debugflag_chk) + + self.discard_btn = QtWidgets.QPushButton(SearcherSettings) + self.discard_btn.setObjectName("discard_btn") + self.sixthrow.addWidget(self.discard_btn) + + self.save_btn = QtWidgets.QPushButton(SearcherSettings) + self.save_btn.setObjectName("save_btn") + self.sixthrow.addWidget(self.save_btn) + + self.verticallayout.addLayout(self.sixthrow) + + if not self.animated: + self.gridLayout.addLayout(self.verticallayout, 1, 0, 1, 1) + + # 
----------------------------------------------------------- + self.retranslateUi(SearcherSettings) + QtCore.QMetaObject.connectSlotsByName(SearcherSettings) + + def retranslateUi(self, SearcherSettings): + _translate = QtCore.QCoreApplication.translate + SearcherSettings.setWindowTitle(_translate("SearcherSettings", "Form")) + + # ------------------------------------------------- headerrow + self.projectTitle.setText(_translate("SearcherSettings", "Settings")) + self.animatedsettings_chk.setText(_translate("SearcherSettings", "Use Animated Menus:")) + self.windowsize_chk.setText(_translate("SearcherSettings", "Remember Search Window Size")) + + # ------------------------------------------------- secondrow + self.maxresults_lbl.setText(_translate("SearcherSettings", "Maximum Search Results")) + self.inmemory_chk.setText(_translate("SearcherSettings", "Use In-Memory Database")) + + # -------------------------------------------------- thirdrow + # self.label_3.setText(_translate("SearcherSettings", "Language:")) + self.lang_cbox.setCurrentText(_translate("SearcherSettings", "English")) + self.lang_cbox.setItemText(0, _translate("SearcherSettings", "English")) + self.defaulthotkey_lbl.setText(_translate("SearcherSettings", "Hotkey to use for opening unassigned items: ")) + self.defaulthotkey_txt.setPlaceholderText(_translate("SearcherSettings", "Double Click")) + self.hotkey_icon.setText(_translate("SearcherSettings", "...")) + + # ------------------------------------------------- fourthrow + self.dbpath_lbl.setText(_translate("SearcherSettings", "Database location: ")) + self.dbpath_icon.setText(_translate("SearcherSettings", "...")) + + # -------------------------------------------------- fifthrow + self.maint_lbl.setText(_translate("SearcherSettings", "Maintenance utilities:")) + self.test1_btn.setText(_translate("SearcherSettings", "Maint Button")) + self.cleardata_btn.setText(_translate("SearcherSettings", "Clear Data")) + + # 
# session.py -- DPA pipeline session wrapper for Houdini.

import os
import time

# Attempt to import the hou module. If it fails, we are not running inside
# a Houdini session.
try:
    import hou
except ImportError:
    HOU_IMPORTED = False
else:
    HOU_IMPORTED = True

# -----------------------------------------------------------------------------
# Attempt to import the Qt UI modules. If it fails, we are not in the UI.
# FIX: this was a bare "except:", which would also swallow SystemExit and
# KeyboardInterrupt raised during import; only an ImportError means
# "no UI available", so catch exactly that.
try:
    from PySide import QtCore, QtGui
except ImportError:
    HOU_UI_IMPORTED = False
else:
    HOU_UI_IMPORTED = True

from dpa.app.session import RemoteMixin, Session, SessionRegistry, SessionError

# -----------------------------------------------------------------------------
class HoudiniSession(RemoteMixin, Session):
    """Session wrapper around a local or remote Houdini instance."""

    app_name = 'houdini'

    # XXX should come from config
    SERVER_EXECUTABLE = "/home/jtomlin/dev/dpa-pipe/bin/dpa_houdini_server"

    # -------------------------------------------------------------------------
    @classmethod
    def current(cls):
        """Return a session for the running Houdini, or None outside one."""
        if not HOU_IMPORTED:
            return None
        return cls()

    # -------------------------------------------------------------------------
    def __init__(self, filepath=None, remote=False):
        """Create a session.

        :param filepath: optional hip file path to load immediately.
        :param remote: if True, talk to a separate Houdini server process.
        """
        super(HoudiniSession, self).__init__(remote=remote)

        # 'hou' is resolved through the mixin so remote sessions get a proxy.
        self._hou = self.init_module('hou')

        if filepath:
            self.open_file(filepath)

    # -------------------------------------------------------------------------
    def close(self):
        """Shut down a remote session, or clear the local hip file."""
        if self.remote_connection:
            self.shutdown()
        else:
            self.hou.hipFile.clear()

    # -------------------------------------------------------------------------
    def open_file(self, filepath):
        """Load the given hip file into the session.

        :param filepath: path of the hip file to load.
        :raises SessionError: if the file is missing or fails to load.
        """
        if not os.path.exists(filepath):
            raise SessionError(
                "Can not open '{f}'. File does not exist.".format(f=filepath))

        try:
            self.hou.hipFile.load(filepath)
        except RuntimeError as e:
            # hou raises RuntimeError on load problems; re-raise in the
            # pipeline's own exception type.
            raise SessionError(str(e))

    # -------------------------------------------------------------------------
    def save(self, filepath=None, overwrite=False):
        """Save the current hip file.

        :param filepath: optional target path; defaults to the current file.
        :param overwrite: allow clobbering an existing file at filepath.
        :raises SessionError: if filepath exists and overwrite is False.
        """
        if filepath and os.path.exists(filepath) and not overwrite:
            raise SessionError(
                "Can not save '{f}'. File exists.".format(f=filepath))

        self.hou.hipFile.save(file_name=filepath)

    # -------------------------------------------------------------------------
    @property
    def hou(self):
        """The (possibly proxied) hou module for this session."""
        return self._hou

    # -------------------------------------------------------------------------
    @property
    def in_session(self):
        """Returns True if inside a current app session."""
        return HOU_IMPORTED or self.remote_connection

    # -------------------------------------------------------------------------
    @property
    def main_window(self):
        """The active Qt window, or None when no UI is available."""
        if not HOU_UI_IMPORTED:
            return None

        return QtGui.QApplication.activeWindow()

    # -------------------------------------------------------------------------
    @property
    def name(self):
        """Returns the name of the application."""
        return "houdini"

    # -------------------------------------------------------------------------
    @property
    def server_executable(self):
        """Path of the server executable used for remote sessions."""
        return self.__class__.SERVER_EXECUTABLE


# -----------------------------------------------------------------------------
SessionRegistry().register(HoudiniSession)
# -*- coding: utf-8 -*-

# Form implementation generated from reading ui file 'SearcherSettings.ui'
#
# Created by: PyQt5 UI code generator 5.14.1
#
# WARNING! All changes made in this file will be lost!
#
# NOTE(review): this module is machine-generated; regenerate it from
# SearcherSettings.ui with pyuic5 instead of editing it by hand.


from PyQt5 import QtCore, QtGui, QtWidgets


class Ui_SearcherSettings(object):
    # Builds the settings form: a header row plus five option rows stacked
    # in a vertical layout inside a single grid cell.
    def setupUi(self, SearcherSettings):
        SearcherSettings.setObjectName("SearcherSettings")
        SearcherSettings.setWindowModality(QtCore.Qt.NonModal)
        SearcherSettings.resize(450, 211)
        sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.Preferred)
        sizePolicy.setHorizontalStretch(0)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(SearcherSettings.sizePolicy().hasHeightForWidth())
        SearcherSettings.setSizePolicy(sizePolicy)
        SearcherSettings.setMinimumSize(QtCore.QSize(450, 0))
        SearcherSettings.setBaseSize(QtCore.QSize(0, 0))
        SearcherSettings.setStyleSheet("")
        self.gridLayout = QtWidgets.QGridLayout(SearcherSettings)
        self.gridLayout.setContentsMargins(-1, -1, -1, 6)
        self.gridLayout.setSpacing(6)
        self.gridLayout.setObjectName("gridLayout")
        self.verticalLayout_4 = QtWidgets.QVBoxLayout()
        self.verticalLayout_4.setObjectName("verticalLayout_4")
        # -- header row: title, separator, window-size / animated-menu toggles
        self.headerrow = QtWidgets.QHBoxLayout()
        self.headerrow.setObjectName("headerrow")
        self.projectTitle = QtWidgets.QLabel(SearcherSettings)
        font = QtGui.QFont()
        font.setPointSize(15)
        self.projectTitle.setFont(font)
        self.projectTitle.setAlignment(QtCore.Qt.AlignCenter)
        self.projectTitle.setObjectName("projectTitle")
        self.headerrow.addWidget(self.projectTitle)
        self.line_2 = QtWidgets.QFrame(SearcherSettings)
        self.line_2.setFrameShape(QtWidgets.QFrame.VLine)
        self.line_2.setFrameShadow(QtWidgets.QFrame.Sunken)
        self.line_2.setObjectName("line_2")
        self.headerrow.addWidget(self.line_2)
        spacerItem = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum)
        self.headerrow.addItem(spacerItem)
        self.checkBox = QtWidgets.QCheckBox(SearcherSettings)
        self.checkBox.setLayoutDirection(QtCore.Qt.RightToLeft)
        self.checkBox.setObjectName("checkBox")
        self.headerrow.addWidget(self.checkBox)
        self.windowsize_chk = QtWidgets.QCheckBox(SearcherSettings)
        self.windowsize_chk.setLayoutDirection(QtCore.Qt.RightToLeft)
        self.windowsize_chk.setObjectName("windowsize_chk")
        self.headerrow.addWidget(self.windowsize_chk)
        self.verticalLayout_4.addLayout(self.headerrow)
        self.line = QtWidgets.QFrame(SearcherSettings)
        self.line.setFrameShape(QtWidgets.QFrame.HLine)
        self.line.setFrameShadow(QtWidgets.QFrame.Sunken)
        self.line.setObjectName("line")
        self.verticalLayout_4.addWidget(self.line)
        # -- second row: language combo, max-results spinbox, in-memory toggle
        self.secondrow = QtWidgets.QHBoxLayout()
        self.secondrow.setObjectName("secondrow")
        self.lang_cbox = QtWidgets.QComboBox(SearcherSettings)
        self.lang_cbox.setObjectName("lang_cbox")
        self.lang_cbox.addItem("")
        self.secondrow.addWidget(self.lang_cbox)
        spacerItem1 = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum)
        self.secondrow.addItem(spacerItem1)
        self.label_3 = QtWidgets.QLabel(SearcherSettings)
        self.label_3.setObjectName("label_3")
        self.secondrow.addWidget(self.label_3)
        self.maxresults_txt = QtWidgets.QSpinBox(SearcherSettings)
        self.maxresults_txt.setMinimum(1)
        self.maxresults_txt.setMaximum(9999)
        self.maxresults_txt.setObjectName("maxresults_txt")
        self.secondrow.addWidget(self.maxresults_txt)
        self.inmemory_chk = QtWidgets.QCheckBox(SearcherSettings)
        self.inmemory_chk.setLayoutDirection(QtCore.Qt.RightToLeft)
        self.inmemory_chk.setTristate(False)
        self.inmemory_chk.setObjectName("inmemory_chk")
        self.secondrow.addWidget(self.inmemory_chk)
        self.verticalLayout_4.addLayout(self.secondrow)
        # -- third row: default hotkey display (read-only) + picker button
        self.thirdrow = QtWidgets.QHBoxLayout()
        self.thirdrow.setObjectName("thirdrow")
        self.label_2 = QtWidgets.QLabel(SearcherSettings)
        self.label_2.setObjectName("label_2")
        self.thirdrow.addWidget(self.label_2)
        self.defaulthotkey_txt = QtWidgets.QLineEdit(SearcherSettings)
        self.defaulthotkey_txt.setToolTip("")
        self.defaulthotkey_txt.setReadOnly(True)
        self.defaulthotkey_txt.setObjectName("defaulthotkey_txt")
        self.thirdrow.addWidget(self.defaulthotkey_txt)
        self.hotkey_icon = QtWidgets.QToolButton(SearcherSettings)
        self.hotkey_icon.setPopupMode(QtWidgets.QToolButton.InstantPopup)
        self.hotkey_icon.setObjectName("hotkey_icon")
        self.thirdrow.addWidget(self.hotkey_icon)
        self.verticalLayout_4.addLayout(self.thirdrow)
        # -- fourth row: database path text field + browse button
        self.fourthrow = QtWidgets.QHBoxLayout()
        self.fourthrow.setObjectName("fourthrow")
        self.label = QtWidgets.QLabel(SearcherSettings)
        self.label.setObjectName("label")
        self.fourthrow.addWidget(self.label)
        self.databasepath_txt = QtWidgets.QLineEdit(SearcherSettings)
        self.databasepath_txt.setObjectName("databasepath_txt")
        self.fourthrow.addWidget(self.databasepath_txt)
        self.dbpath_icon = QtWidgets.QToolButton(SearcherSettings)
        self.dbpath_icon.setObjectName("dbpath_icon")
        self.fourthrow.addWidget(self.dbpath_icon)
        self.verticalLayout_4.addLayout(self.fourthrow)
        # -- fifth row: maintenance buttons
        self.fifthrow = QtWidgets.QHBoxLayout()
        self.fifthrow.setObjectName("fifthrow")
        self.label_4 = QtWidgets.QLabel(SearcherSettings)
        self.label_4.setObjectName("label_4")
        self.fifthrow.addWidget(self.label_4)
        self.test1_btn = QtWidgets.QPushButton(SearcherSettings)
        self.test1_btn.setObjectName("test1_btn")
        self.fifthrow.addWidget(self.test1_btn)
        self.cleardata_btn = QtWidgets.QPushButton(SearcherSettings)
        self.cleardata_btn.setObjectName("cleardata_btn")
        self.fifthrow.addWidget(self.cleardata_btn)
        self.verticalLayout_4.addLayout(self.fifthrow)
        self.line_3 = QtWidgets.QFrame(SearcherSettings)
        self.line_3.setFrameShape(QtWidgets.QFrame.HLine)
        self.line_3.setFrameShadow(QtWidgets.QFrame.Sunken)
        self.line_3.setObjectName("line_3")
        self.verticalLayout_4.addWidget(self.line_3)
        # -- sixth row: about button, debug level/flag, discard/save buttons
        self.sixthrow = QtWidgets.QHBoxLayout()
        self.sixthrow.setObjectName("sixthrow")
        self.about_btn = QtWidgets.QToolButton(SearcherSettings)
        self.about_btn.setObjectName("about_btn")
        self.sixthrow.addWidget(self.about_btn)
        spacerItem2 = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum)
        self.sixthrow.addItem(spacerItem2)
        self.debuglevel_cbx = QtWidgets.QComboBox(SearcherSettings)
        self.debuglevel_cbx.setObjectName("debuglevel_cbx")
        self.debuglevel_cbx.addItem("")
        self.debuglevel_cbx.addItem("")
        self.debuglevel_cbx.addItem("")
        self.sixthrow.addWidget(self.debuglevel_cbx)
        self.debugflag_chk = QtWidgets.QCheckBox(SearcherSettings)
        self.debugflag_chk.setLayoutDirection(QtCore.Qt.LeftToRight)
        self.debugflag_chk.setObjectName("debugflag_chk")
        self.sixthrow.addWidget(self.debugflag_chk)
        self.discard_btn = QtWidgets.QPushButton(SearcherSettings)
        self.discard_btn.setObjectName("discard_btn")
        self.sixthrow.addWidget(self.discard_btn)
        self.save_btn = QtWidgets.QPushButton(SearcherSettings)
        self.save_btn.setObjectName("save_btn")
        self.sixthrow.addWidget(self.save_btn)
        self.verticalLayout_4.addLayout(self.sixthrow)
        self.gridLayout.addLayout(self.verticalLayout_4, 0, 0, 1, 1)

        self.retranslateUi(SearcherSettings)
        QtCore.QMetaObject.connectSlotsByName(SearcherSettings)

    # Installs all user-visible strings through Qt's translation machinery.
    def retranslateUi(self, SearcherSettings):
        _translate = QtCore.QCoreApplication.translate
        SearcherSettings.setWindowTitle(_translate("SearcherSettings", "Form"))
        self.projectTitle.setText(_translate("SearcherSettings", "Settings"))
        self.checkBox.setText(_translate("SearcherSettings", "Use Animated Menus:"))
        self.windowsize_chk.setText(_translate("SearcherSettings", "Remember Search Window Size"))
        self.lang_cbox.setCurrentText(_translate("SearcherSettings", "English"))
        self.lang_cbox.setItemText(0, _translate("SearcherSettings", "English"))
        self.label_3.setText(_translate("SearcherSettings", "Maximum Search Results"))
        self.inmemory_chk.setText(_translate("SearcherSettings", "Use In-Memory Database"))
        self.label_2.setText(_translate("SearcherSettings", "Hotkey to use for opening unassigned items: "))
        self.defaulthotkey_txt.setPlaceholderText(_translate("SearcherSettings", "Double Click"))
        self.hotkey_icon.setText(_translate("SearcherSettings", "..."))
        self.label.setText(_translate("SearcherSettings", "Database location: "))
        self.dbpath_icon.setText(_translate("SearcherSettings", "..."))
        self.label_4.setText(_translate("SearcherSettings", "Maintenance utilities:"))
        self.test1_btn.setText(_translate("SearcherSettings", "Test Button 1"))
        self.cleardata_btn.setText(_translate("SearcherSettings", "Clear Data"))
        self.about_btn.setText(_translate("SearcherSettings", "..."))
        self.debuglevel_cbx.setItemText(0, _translate("SearcherSettings", "NONE"))
        self.debuglevel_cbx.setItemText(1, _translate("SearcherSettings", "TIMER"))
        self.debuglevel_cbx.setItemText(2, _translate("SearcherSettings", "ALL"))
        self.debugflag_chk.setText(_translate("SearcherSettings", "Debug Mode"))
        self.discard_btn.setText(_translate("SearcherSettings", "Discard"))
        self.save_btn.setText(_translate("SearcherSettings", "Save"))
English + + + + English + + + + @@ -107,6 +139,40 @@ + + + + Maximum Search Results + + + + + + + 1 + + + 9999 + + + + + + + Qt::RightToLeft + + + Use In-Memory Database + + + false + + + + + + + @@ -115,7 +181,7 @@ - + @@ -140,7 +206,7 @@ - + @@ -161,7 +227,7 @@ - + @@ -176,13 +242,6 @@ - - - - Test HContext - - - @@ -193,7 +252,21 @@ - + + + Qt::Horizontal + + + + + + + + + ... + + + @@ -207,10 +280,29 @@ + + + + + NONE + + + + + TIMER + + + + + ALL + + + + - Qt::RightToLeft + Qt::LeftToRight Debug Mode diff --git a/scripts/python/searcher/ui_files/about.ui b/scripts/python/searcher/ui_files/about.ui new file mode 100644 index 0000000..e020f1b --- /dev/null +++ b/scripts/python/searcher/ui_files/about.ui @@ -0,0 +1,120 @@ + + + About + + + Qt::NonModal + + + + 0 + 0 + 142 + 210 + + + + + 0 + 0 + + + + + 100 + 0 + + + + + 0 + 0 + + + + Form + + + + + + + 6 + + + 6 + + + + + + + + 0 + 0 + + + + + 120 + 120 + + + + + + + C:/Users/mosthated/Downloads/483688212.png + + + true + + + + + + + + + github.com/instance-id + + + + + + + + + + + instance.id + + + + + + + + + + + Report Bug + + + + + + + ... + + + + + + + + + + + + diff --git a/scripts/python/searcher/ui_files/searcher_ui.py b/scripts/python/searcher/ui_files/searcher_ui.py new file mode 100644 index 0000000..500e4fd --- /dev/null +++ b/scripts/python/searcher/ui_files/searcher_ui.py @@ -0,0 +1,125 @@ +# -*- coding: utf-8 -*- + +# Form implementation generated from reading ui file 'searcher_ui.ui' +# +# Created by: PyQt5 UI code generator 5.14.1 +# +# WARNING! All changes made in this file will be lost! 
# -*- coding: utf-8 -*-

# Form implementation generated from reading ui file 'searcher_ui.ui'
#
# Created by: PyQt5 UI code generator 5.14.1
#
# WARNING! All changes made in this file will be lost!
#
# NOTE(review): this module is machine-generated; regenerate it from
# searcher_ui.ui with pyuic5 instead of editing it by hand.


from PyQt5 import QtCore, QtGui, QtWidgets


class Ui_Searcher(object):
    # Builds the main search window: filter button + search box on top,
    # a results tree in the middle, and an info bar at the bottom.
    def setupUi(self, Searcher):
        Searcher.setObjectName("Searcher")
        Searcher.setWindowModality(QtCore.Qt.WindowModal)
        Searcher.resize(1000, 329)
        sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)
        sizePolicy.setHorizontalStretch(0)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(Searcher.sizePolicy().hasHeightForWidth())
        Searcher.setSizePolicy(sizePolicy)
        Searcher.setMinimumSize(QtCore.QSize(0, 0))
        Searcher.setBaseSize(QtCore.QSize(1000, 350))
        Searcher.setStyleSheet("QTreeWidget QHeaderView::section {\n"
" font-size: 9pt;\n"
"}")
        self.gridLayout = QtWidgets.QGridLayout(Searcher)
        self.gridLayout.setContentsMargins(0, 0, 0, 0)
        self.gridLayout.setSpacing(0)
        self.gridLayout.setObjectName("gridLayout")
        self.verticalLayout = QtWidgets.QVBoxLayout()
        self.verticalLayout.setSpacing(0)
        self.verticalLayout.setObjectName("verticalLayout")
        # -- top row: filter button (inside a frame) next to the search box
        self.horizontalLayout_3 = QtWidgets.QHBoxLayout()
        self.horizontalLayout_3.setSpacing(0)
        self.horizontalLayout_3.setObjectName("horizontalLayout_3")
        self.frame = QtWidgets.QFrame(Searcher)
        sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.Preferred)
        sizePolicy.setHorizontalStretch(2)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(self.frame.sizePolicy().hasHeightForWidth())
        self.frame.setSizePolicy(sizePolicy)
        self.frame.setMinimumSize(QtCore.QSize(0, 20))
        self.frame.setFrameShape(QtWidgets.QFrame.StyledPanel)
        self.frame.setFrameShadow(QtWidgets.QFrame.Raised)
        self.frame.setObjectName("frame")
        self.searchfilter_btn = QtWidgets.QToolButton(self.frame)
        self.searchfilter_btn.setGeometry(QtCore.QRect(0, 0, 26, 20))
        self.searchfilter_btn.setBaseSize(QtCore.QSize(16, 16))
        self.searchfilter_btn.setStyleSheet("background-color: rgb(19, 19, 19);")
        self.searchfilter_btn.setArrowType(QtCore.Qt.NoArrow)
        self.searchfilter_btn.setObjectName("searchfilter_btn")
        self.horizontalLayout_3.addWidget(self.frame)
        self.searchbox_txt = QtWidgets.QLineEdit(Searcher)
        sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum)
        sizePolicy.setHorizontalStretch(99)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(self.searchbox_txt.sizePolicy().hasHeightForWidth())
        self.searchbox_txt.setSizePolicy(sizePolicy)
        self.searchbox_txt.setMinimumSize(QtCore.QSize(50, 0))
        self.searchbox_txt.setMouseTracking(False)
        self.searchbox_txt.setStyleSheet("background-color: rgb(19, 19, 19);")
        self.searchbox_txt.setFrame(False)
        self.searchbox_txt.setObjectName("searchbox_txt")
        self.horizontalLayout_3.addWidget(self.searchbox_txt)
        self.verticalLayout.addLayout(self.horizontalLayout_3)
        # -- middle: search results tree
        self.searchresults_tree = QtWidgets.QTreeWidget(Searcher)
        sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.Expanding)
        sizePolicy.setHorizontalStretch(0)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(self.searchresults_tree.sizePolicy().hasHeightForWidth())
        self.searchresults_tree.setSizePolicy(sizePolicy)
        font = QtGui.QFont()
        font.setPointSize(9)
        self.searchresults_tree.setFont(font)
        self.searchresults_tree.setMouseTracking(False)
        self.searchresults_tree.setFocusPolicy(QtCore.Qt.NoFocus)
        self.searchresults_tree.setFrameShadow(QtWidgets.QFrame.Sunken)
        self.searchresults_tree.setLineWidth(0)
        self.searchresults_tree.setSizeAdjustPolicy(QtWidgets.QAbstractScrollArea.AdjustToContents)
        self.searchresults_tree.setAlternatingRowColors(True)
        self.searchresults_tree.setSelectionMode(QtWidgets.QAbstractItemView.SingleSelection)
        self.searchresults_tree.setSelectionBehavior(QtWidgets.QAbstractItemView.SelectRows)
        self.searchresults_tree.setObjectName("searchresults_tree")
        self.searchresults_tree.headerItem().setText(0, "1")
        self.verticalLayout.addWidget(self.searchresults_tree)
        self.gridLayout.addLayout(self.verticalLayout, 1, 0, 1, 1)
        # -- bottom: info bar (hint label left, small label right)
        self.infobar = QtWidgets.QHBoxLayout()
        self.infobar.setObjectName("infobar")
        self.infobargrid = QtWidgets.QGridLayout()
        self.infobargrid.setObjectName("infobargrid")
        self.info_lbl = QtWidgets.QLabel(Searcher)
        font = QtGui.QFont()
        font.setPointSize(8)
        font.setBold(False)
        font.setWeight(50)
        self.info_lbl.setFont(font)
        self.info_lbl.setStyleSheet("background-color: rgb(26, 26, 26);")
        self.info_lbl.setIndent(5)
        self.info_lbl.setObjectName("info_lbl")
        self.infobargrid.addWidget(self.info_lbl, 1, 0, 1, 1)
        self.label = QtWidgets.QLabel(Searcher)
        sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Maximum, QtWidgets.QSizePolicy.Preferred)
        sizePolicy.setHorizontalStretch(0)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(self.label.sizePolicy().hasHeightForWidth())
        self.label.setSizePolicy(sizePolicy)
        self.label.setMaximumSize(QtCore.QSize(50, 16777215))
        self.label.setLayoutDirection(QtCore.Qt.RightToLeft)
        self.label.setAlignment(QtCore.Qt.AlignRight|QtCore.Qt.AlignTrailing|QtCore.Qt.AlignVCenter)
        self.label.setObjectName("label")
        self.infobargrid.addWidget(self.label, 1, 1, 1, 1)
        self.infobar.addLayout(self.infobargrid)
        self.gridLayout.addLayout(self.infobar, 3, 0, 1, 1)

        self.retranslateUi(Searcher)
        QtCore.QMetaObject.connectSlotsByName(Searcher)

    # Installs all user-visible strings through Qt's translation machinery.
    def retranslateUi(self, Searcher):
        _translate = QtCore.QCoreApplication.translate
        Searcher.setWindowTitle(_translate("Searcher", "Searcher"))
        self.searchfilter_btn.setText(_translate("Searcher", "..."))
        self.info_lbl.setText(_translate("Searcher", "Begin typing to search or click magnifying glass icon to display options"))
        self.label.setText(_translate("Searcher", "TextLabel"))
b/scripts/python/searcher/ui_files/searcher_ui.ui similarity index 75% rename from scripts/python/searcher/searcher_ui.ui rename to scripts/python/searcher/ui_files/searcher_ui.ui index baa6881..2189a62 100644 --- a/scripts/python/searcher/searcher_ui.ui +++ b/scripts/python/searcher/ui_files/searcher_ui.ui @@ -185,28 +185,61 @@ - - - - - 8 - 50 - false - - - - background-color: rgb(26, 26, 26); - - - Begin typing to search or click magnifying glass icon to display options - - - 2 - - - 5 - - + + + + + + + + + 8 + 50 + false + + + + background-color: rgb(26, 26, 26); + + + Begin typing to search or click magnifying glass icon to display options + + + 2 + + + 5 + + + + + + + + 0 + 0 + + + + + 50 + 16777215 + + + + Qt::RightToLeft + + + TextLabel + + + Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter + + + + + + diff --git a/scripts/python/searcher/util.py b/scripts/python/searcher/util.py index 2c3b3e2..1113b2f 100644 --- a/scripts/python/searcher/util.py +++ b/scripts/python/searcher/util.py @@ -1,7 +1,7 @@ -# region ----------------------------------------------------------------- Imports +# ------------------------------------------------------------------------ Imports from __future__ import print_function from __future__ import absolute_import -import weakref +from searcher import enum from sys import platform from typing import Tuple @@ -11,56 +11,69 @@ if os.environ["HFS"] != "": ver = os.environ["HFS"] hver = int(ver[ver.rindex('.')+1:]) - if int(hver) >= 391: - from hutil.Qt import _QtUiTools - from hutil.Qt import QtGui - from hutil.Qt import QtCore - from hutil.Qt import QtWidgets - elif int(hver) < 391: - from hutil.Qt import QtUiTools - from hutil.Qt import QtGui - from hutil.Qt import QtCore - from hutil.Qt import QtWidgets -# else: -# os.environ['QT_API'] = 'pyside2' -# from PySide import QtUiTools -# from qtpy import QtGui -# from qtpy import QtCore -# from qtpy import QtWidgets - + from hutil.Qt import QtCore # endregion -SequenceT = Tuple[str, 
# ------------------------------------------------------------------------ Debugging

# Available debug verbosity levels (project-local enum recipe; values are
# the same strings offered by the settings UI combo box).
DEBUG_LEVEL = enum.Enum('NONE', 'TIMER', 'ALL')


class Dbug(object):
    """Truthiness flag carrying the debug enabled state and its level.

    An instance is truthy exactly when debugging is enabled, so call
    sites can simply write ``if dbug: ...``.
    """

    def __init__(self, enabled, level):
        self.enabled = enabled  # bool-ish: is debug output on at all?
        self.level = level      # one of the DEBUG_LEVEL values

    def __nonzero__(self):  # Python 2 truth protocol
        return bool(self.enabled)

    # BUG FIX: Python 3 uses __bool__, not __nonzero__. Without this alias
    # every Dbug instance is unconditionally truthy under Python 3, which
    # silently enables all "if dbug:" debug paths.
    __bool__ = __nonzero__


SequenceT = Tuple[str, ...]

# ------------------------------------------------------------------------ Helper Functions
def bc(v):
    """Loosely coerce a string-ish value to bool ("yes"/"true"/"t"/"1")."""
    return str(v).lower() in ("yes", "true", "t", "1")
# endregion

# ------------------------------------------------------------------------ Application Settings
# Canonical ordering of persisted settings; the index comments below are
# referenced by SETTINGS_TYPES / DEFAULT_SETTINGS.
SETTINGS_KEYS = [
    'in_memory_db',       # 0
    'database_path',      # 1
    'savewindowsize',     # 2
    'windowsize',         # 3
    'debugflag',          # 4
    'pinwindow',          # 5
    'defaulthotkey',      # 6
    'showctx',            # 7
    'animatedsettings',   # 8
    'maxresults',         # 9
    'debuglevel'          # 10
]

# Include parameter type if it is to be processed by settings menu, else mark NA
SETTINGS_TYPES = {
    SETTINGS_KEYS[0]: 'bool',     # in_memory_db
    SETTINGS_KEYS[1]: 'text',     # database_path
    SETTINGS_KEYS[2]: 'bool',     # savewindowsize
    SETTINGS_KEYS[3]: 'int',      # windowsize
    SETTINGS_KEYS[4]: 'bool',     # debugflag
    SETTINGS_KEYS[5]: 'NA',       # pinwindow
    SETTINGS_KEYS[6]: 'text',     # defaulthotkey
    SETTINGS_KEYS[7]: 'NA',       # showctx
    SETTINGS_KEYS[8]: 'bool',     # animatedsettings
    SETTINGS_KEYS[9]: 'intval',   # maxresults
    SETTINGS_KEYS[10]: 'cbx',     # debuglevel
}

# Values used when a setting has never been saved. Note that booleans are
# stored as strings ("True"/"False") and decoded with bc() above.
DEFAULT_SETTINGS = {
    SETTINGS_KEYS[0]: "False",                 # in_memory_db
    SETTINGS_KEYS[1]: "",                      # database_path
    SETTINGS_KEYS[2]: "False",                 # savewindowsize
    SETTINGS_KEYS[3]: [1000, 600],             # windowsize
    SETTINGS_KEYS[4]: "False",                 # debugflag
    SETTINGS_KEYS[5]: "False",                 # pinwindow
    SETTINGS_KEYS[6]: u"Ctrl+Alt+Shift+F7",    # defaulthotkey
    SETTINGS_KEYS[7]: "True",                  # showctx
    SETTINGS_KEYS[8]: "True",                  # animatedsettings
    SETTINGS_KEYS[9]: 100,                     # maxresults
    SETTINGS_KEYS[10]: "NONE",                 # debuglevel
}
"""Package version, exposed in the conventional three forms."""

VERSION_MAJOR = 0
VERSION_MINOR = 0
VERSION_PATCH = 1

# (major, minor, patch) tuple for programmatic comparison.
version_info = (VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH)

# Dotted string form, e.g. "0.0.1".
version = '.'.join(str(part) for part in version_info)
__version__ = version

__all__ = ['version', 'version_info', '__version__']
        # Toggle collapse/expand on every press of the header button.
        self.toggle_button.pressed.connect(self.on_pressed)

        # Runs the height animations of the box and its content in parallel.
        self.toggle_animation = QtCore.QParallelAnimationGroup(self)

        # Content starts fully collapsed (height 0); the animations below
        # grow/shrink it.
        self.content_area = QtWidgets.QScrollArea(
            maximumHeight=0, minimumHeight=0
        )
        self.content_area.setSizePolicy(
            QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Fixed
        )
        self.content_area.setFrameShape(QtWidgets.QFrame.NoFrame)

        lay = QtWidgets.QVBoxLayout(self)
        lay.setSpacing(0)
        lay.setContentsMargins(0, 0, 0, 0)
        lay.addWidget(self.toggle_button)
        lay.addWidget(self.content_area)

        # Three animations: the box's min/max height, then the content
        # area's max height. setContentLayout() relies on this order —
        # it treats the LAST animation as the content-area one.
        self.toggle_animation.addAnimation(
            QtCore.QPropertyAnimation(self, b"minimumHeight")
        )
        self.toggle_animation.addAnimation(
            QtCore.QPropertyAnimation(self, b"maximumHeight")
        )
        self.toggle_animation.addAnimation(
            QtCore.QPropertyAnimation(self.content_area, b"maximumHeight")
        )

    @QtCore.pyqtSlot()
    def on_pressed(self):
        """Flip the arrow and run the expand/collapse animation.

        NOTE(review): `pressed` appears to be emitted before the checked
        state toggles, which would explain the inverted `not checked`
        tests below — confirm against the Qt signal ordering docs.
        """
        checked = self.toggle_button.isChecked()
        self.toggle_button.setArrowType(
            QtCore.Qt.DownArrow if not checked else QtCore.Qt.RightArrow
        )
        self.toggle_animation.setDirection(
            QtCore.QAbstractAnimation.Forward
            if not checked
            else QtCore.QAbstractAnimation.Backward
        )
        self.toggle_animation.start()

    def setContentLayout(self, layout):
        """Install `layout` as the collapsible content and size the animations.

        The expanded height is the collapsed header height plus the
        layout's preferred height.
        """
        lay = self.content_area.layout()
        # NOTE(review): `del lay` only drops the local Python reference; it
        # does not destroy a previously installed Qt layout. Harmless when
        # no layout was set yet, but verify if reuse is intended.
        del lay
        self.content_area.setLayout(layout)
        # Height of the widget with the content fully collapsed.
        collapsed_height = (
            self.sizeHint().height() - self.content_area.maximumHeight()
        )
        content_height = layout.sizeHint().height()
        # First animations: grow the whole box from collapsed to expanded.
        for i in range(self.toggle_animation.animationCount()):
            animation = self.toggle_animation.animationAt(i)
            animation.setDuration(200)
            animation.setStartValue(collapsed_height)
            animation.setEndValue(collapsed_height + content_height)

        # Last animation targets the content area itself: 0 -> full height.
        content_animation = self.toggle_animation.animationAt(
            self.toggle_animation.animationCount() - 1
        )
        content_animation.setDuration(200)
        content_animation.setStartValue(0)
        content_animation.setEndValue(content_height)


# Manual demo: a dock full of collapsible boxes with randomly colored rows.
if __name__ == "__main__":
    import sys
    import random

    app = QtWidgets.QApplication(sys.argv)

    w = QtWidgets.QMainWindow()
    w.setCentralWidget(QtWidgets.QWidget())
    dock = QtWidgets.QDockWidget("Collapsible Demo")
    w.addDockWidget(QtCore.Qt.LeftDockWidgetArea, dock)
    scroll = QtWidgets.QScrollArea()
    dock.setWidget(scroll)
    content = QtWidgets.QWidget()
    scroll.setWidget(content)
    scroll.setWidgetResizable(True)
    vlay = QtWidgets.QVBoxLayout(content)
    for i in range(10):
        box = CollapsibleBox("Collapsible Box Header-{}".format(i))
        vlay.addWidget(box)
        lay = QtWidgets.QVBoxLayout()
        for j in range(8):
            label = QtWidgets.QLineEdit("{}".format(j))
            color = QtGui.QColor(*[random.randint(0, 255) for _ in range(3)])
            label.setStyleSheet(
                "background-color: {}; color : white;".format(color.name())
            )
            label.setAlignment(QtCore.Qt.AlignCenter)
            lay.addWidget(label)

        box.setContentLayout(lay)
    vlay.addStretch()
    w.resize(640, 480)
    w.show()
    sys.exit(app.exec_())