diff --git a/python3.10libs/Qt.py b/python3.10libs/Qt.py new file mode 100644 index 0000000..fe4b45f --- /dev/null +++ b/python3.10libs/Qt.py @@ -0,0 +1,1989 @@ +"""Minimal Python 2 & 3 shim around all Qt bindings + +DOCUMENTATION + Qt.py was born in the film and visual effects industry to address + the growing need for the development of software capable of running + with more than one flavour of the Qt bindings for Python - PySide, + PySide2, PyQt4 and PyQt5. + + 1. Build for one, run with all + 2. Explicit is better than implicit + 3. Support co-existence + + Default resolution order: + - PySide2 + - PyQt5 + - PySide + - PyQt4 + + Usage: + >> import sys + >> from Qt import QtWidgets + >> app = QtWidgets.QApplication(sys.argv) + >> button = QtWidgets.QPushButton("Hello World") + >> button.show() + >> app.exec_() + + All members of PySide2 are mapped from other bindings, should they exist. + If no equivalent member exist, it is excluded from Qt.py and inaccessible. + The idea is to highlight members that exist across all supported binding, + and guarantee that code that runs on one binding runs on all others. + + For more details, visit https://github.com/mottosso/Qt.py + +LICENSE + + See end of file for license (MIT, BSD) information. + +""" + +import os +import sys +import types +import shutil +import importlib + + +__version__ = "1.2.3" + +# Enable support for `from Qt import *` +__all__ = [] + +# Flags from environment variables +QT_VERBOSE = bool(os.getenv("QT_VERBOSE")) +QT_PREFERRED_BINDING = os.getenv("QT_PREFERRED_BINDING", "") +QT_SIP_API_HINT = os.getenv("QT_SIP_API_HINT") + +# Reference to Qt.py +Qt = sys.modules[__name__] +Qt.QtCompat = types.ModuleType("QtCompat") + +try: + long +except NameError: + # Python 3 compatibility + long = int + + +"""Common members of all bindings + +This is where each member of Qt.py is explicitly defined. +It is based on a "lowest common denominator" of all bindings; +including members found in each of the 4 bindings. 
+ +The "_common_members" dictionary is generated using the +build_membership.sh script. + +""" + +_common_members = { + "QtCore": [ + "QAbstractAnimation", + "QAbstractEventDispatcher", + "QAbstractItemModel", + "QAbstractListModel", + "QAbstractState", + "QAbstractTableModel", + "QAbstractTransition", + "QAnimationGroup", + "QBasicTimer", + "QBitArray", + "QBuffer", + "QByteArray", + "QByteArrayMatcher", + "QChildEvent", + "QCoreApplication", + "QCryptographicHash", + "QDataStream", + "QDate", + "QDateTime", + "QDir", + "QDirIterator", + "QDynamicPropertyChangeEvent", + "QEasingCurve", + "QElapsedTimer", + "QEvent", + "QEventLoop", + "QEventTransition", + "QFile", + "QFileInfo", + "QFileSystemWatcher", + "QFinalState", + "QGenericArgument", + "QGenericReturnArgument", + "QHistoryState", + "QItemSelectionRange", + "QIODevice", + "QLibraryInfo", + "QLine", + "QLineF", + "QLocale", + "QMargins", + "QMetaClassInfo", + "QMetaEnum", + "QMetaMethod", + "QMetaObject", + "QMetaProperty", + "QMimeData", + "QModelIndex", + "QMutex", + "QMutexLocker", + "QObject", + "QParallelAnimationGroup", + "QPauseAnimation", + "QPersistentModelIndex", + "QPluginLoader", + "QPoint", + "QPointF", + "QProcess", + "QProcessEnvironment", + "QPropertyAnimation", + "QReadLocker", + "QReadWriteLock", + "QRect", + "QRectF", + "QRegExp", + "QResource", + "QRunnable", + "QSemaphore", + "QSequentialAnimationGroup", + "QSettings", + "QSignalMapper", + "QSignalTransition", + "QSize", + "QSizeF", + "QSocketNotifier", + "QState", + "QStateMachine", + "QSysInfo", + "QSystemSemaphore", + "QT_TRANSLATE_NOOP", + "QT_TR_NOOP", + "QT_TR_NOOP_UTF8", + "QTemporaryFile", + "QTextBoundaryFinder", + "QTextCodec", + "QTextDecoder", + "QTextEncoder", + "QTextStream", + "QTextStreamManipulator", + "QThread", + "QThreadPool", + "QTime", + "QTimeLine", + "QTimer", + "QTimerEvent", + "QTranslator", + "QUrl", + "QVariantAnimation", + "QWaitCondition", + "QWriteLocker", + "QXmlStreamAttribute", + "QXmlStreamAttributes", + 
"QXmlStreamEntityDeclaration", + "QXmlStreamEntityResolver", + "QXmlStreamNamespaceDeclaration", + "QXmlStreamNotationDeclaration", + "QXmlStreamReader", + "QXmlStreamWriter", + "Qt", + "QtCriticalMsg", + "QtDebugMsg", + "QtFatalMsg", + "QtMsgType", + "QtSystemMsg", + "QtWarningMsg", + "qAbs", + "qAddPostRoutine", + "qChecksum", + "qCritical", + "qDebug", + "qFatal", + "qFuzzyCompare", + "qIsFinite", + "qIsInf", + "qIsNaN", + "qIsNull", + "qRegisterResourceData", + "qUnregisterResourceData", + "qVersion", + "qWarning", + "qrand", + "qsrand" + ], + "QtGui": [ + "QAbstractTextDocumentLayout", + "QActionEvent", + "QBitmap", + "QBrush", + "QClipboard", + "QCloseEvent", + "QColor", + "QConicalGradient", + "QContextMenuEvent", + "QCursor", + "QDesktopServices", + "QDoubleValidator", + "QDrag", + "QDragEnterEvent", + "QDragLeaveEvent", + "QDragMoveEvent", + "QDropEvent", + "QFileOpenEvent", + "QFocusEvent", + "QFont", + "QFontDatabase", + "QFontInfo", + "QFontMetrics", + "QFontMetricsF", + "QGradient", + "QHelpEvent", + "QHideEvent", + "QHoverEvent", + "QIcon", + "QIconDragEvent", + "QIconEngine", + "QImage", + "QImageIOHandler", + "QImageReader", + "QImageWriter", + "QInputEvent", + "QInputMethodEvent", + "QIntValidator", + "QKeyEvent", + "QKeySequence", + "QLinearGradient", + "QMatrix2x2", + "QMatrix2x3", + "QMatrix2x4", + "QMatrix3x2", + "QMatrix3x3", + "QMatrix3x4", + "QMatrix4x2", + "QMatrix4x3", + "QMatrix4x4", + "QMouseEvent", + "QMoveEvent", + "QMovie", + "QPaintDevice", + "QPaintEngine", + "QPaintEngineState", + "QPaintEvent", + "QPainter", + "QPainterPath", + "QPainterPathStroker", + "QPalette", + "QPen", + "QPicture", + "QPictureIO", + "QPixmap", + "QPixmapCache", + "QPolygon", + "QPolygonF", + "QQuaternion", + "QRadialGradient", + "QRegExpValidator", + "QRegion", + "QResizeEvent", + "QSessionManager", + "QShortcutEvent", + "QShowEvent", + "QStandardItem", + "QStandardItemModel", + "QStatusTipEvent", + "QSyntaxHighlighter", + "QTabletEvent", + "QTextBlock", + 
"QTextBlockFormat", + "QTextBlockGroup", + "QTextBlockUserData", + "QTextCharFormat", + "QTextCursor", + "QTextDocument", + "QTextDocumentFragment", + "QTextFormat", + "QTextFragment", + "QTextFrame", + "QTextFrameFormat", + "QTextImageFormat", + "QTextInlineObject", + "QTextItem", + "QTextLayout", + "QTextLength", + "QTextLine", + "QTextList", + "QTextListFormat", + "QTextObject", + "QTextObjectInterface", + "QTextOption", + "QTextTable", + "QTextTableCell", + "QTextTableCellFormat", + "QTextTableFormat", + "QTouchEvent", + "QTransform", + "QValidator", + "QVector2D", + "QVector3D", + "QVector4D", + "QWhatsThisClickedEvent", + "QWheelEvent", + "QWindowStateChangeEvent", + "qAlpha", + "qBlue", + "qGray", + "qGreen", + "qIsGray", + "qRed", + "qRgb", + "qRgba" + ], + "QtHelp": [ + "QHelpContentItem", + "QHelpContentModel", + "QHelpContentWidget", + "QHelpEngine", + "QHelpEngineCore", + "QHelpIndexModel", + "QHelpIndexWidget", + "QHelpSearchEngine", + "QHelpSearchQuery", + "QHelpSearchQueryWidget", + "QHelpSearchResultWidget" + ], + "QtMultimedia": [ + "QAbstractVideoBuffer", + "QAbstractVideoSurface", + "QAudio", + "QAudioDeviceInfo", + "QAudioFormat", + "QAudioInput", + "QAudioOutput", + "QVideoFrame", + "QVideoSurfaceFormat" + ], + "QtNetwork": [ + "QAbstractNetworkCache", + "QAbstractSocket", + "QAuthenticator", + "QHostAddress", + "QHostInfo", + "QLocalServer", + "QLocalSocket", + "QNetworkAccessManager", + "QNetworkAddressEntry", + "QNetworkCacheMetaData", + "QNetworkConfiguration", + "QNetworkConfigurationManager", + "QNetworkCookie", + "QNetworkCookieJar", + "QNetworkDiskCache", + "QNetworkInterface", + "QNetworkProxy", + "QNetworkProxyFactory", + "QNetworkProxyQuery", + "QNetworkReply", + "QNetworkRequest", + "QNetworkSession", + "QSsl", + "QTcpServer", + "QTcpSocket", + "QUdpSocket" + ], + "QtOpenGL": [ + "QGL", + "QGLContext", + "QGLFormat", + "QGLWidget" + ], + "QtPrintSupport": [ + "QAbstractPrintDialog", + "QPageSetupDialog", + "QPrintDialog", + 
"QPrintEngine", + "QPrintPreviewDialog", + "QPrintPreviewWidget", + "QPrinter", + "QPrinterInfo" + ], + "QtSql": [ + "QSql", + "QSqlDatabase", + "QSqlDriver", + "QSqlDriverCreatorBase", + "QSqlError", + "QSqlField", + "QSqlIndex", + "QSqlQuery", + "QSqlQueryModel", + "QSqlRecord", + "QSqlRelation", + "QSqlRelationalDelegate", + "QSqlRelationalTableModel", + "QSqlResult", + "QSqlTableModel" + ], + "QtSvg": [ + "QGraphicsSvgItem", + "QSvgGenerator", + "QSvgRenderer", + "QSvgWidget" + ], + "QtTest": [ + "QTest" + ], + "QtWidgets": [ + "QAbstractButton", + "QAbstractGraphicsShapeItem", + "QAbstractItemDelegate", + "QAbstractItemView", + "QAbstractScrollArea", + "QAbstractSlider", + "QAbstractSpinBox", + "QAction", + "QActionGroup", + "QApplication", + "QBoxLayout", + "QButtonGroup", + "QCalendarWidget", + "QCheckBox", + "QColorDialog", + "QColumnView", + "QComboBox", + "QCommandLinkButton", + "QCommonStyle", + "QCompleter", + "QDataWidgetMapper", + "QDateEdit", + "QDateTimeEdit", + "QDesktopWidget", + "QDial", + "QDialog", + "QDialogButtonBox", + "QDirModel", + "QDockWidget", + "QDoubleSpinBox", + "QErrorMessage", + "QFileDialog", + "QFileIconProvider", + "QFileSystemModel", + "QFocusFrame", + "QFontComboBox", + "QFontDialog", + "QFormLayout", + "QFrame", + "QGesture", + "QGestureEvent", + "QGestureRecognizer", + "QGraphicsAnchor", + "QGraphicsAnchorLayout", + "QGraphicsBlurEffect", + "QGraphicsColorizeEffect", + "QGraphicsDropShadowEffect", + "QGraphicsEffect", + "QGraphicsEllipseItem", + "QGraphicsGridLayout", + "QGraphicsItem", + "QGraphicsItemGroup", + "QGraphicsLayout", + "QGraphicsLayoutItem", + "QGraphicsLineItem", + "QGraphicsLinearLayout", + "QGraphicsObject", + "QGraphicsOpacityEffect", + "QGraphicsPathItem", + "QGraphicsPixmapItem", + "QGraphicsPolygonItem", + "QGraphicsProxyWidget", + "QGraphicsRectItem", + "QGraphicsRotation", + "QGraphicsScale", + "QGraphicsScene", + "QGraphicsSceneContextMenuEvent", + "QGraphicsSceneDragDropEvent", + 
"QGraphicsSceneEvent", + "QGraphicsSceneHelpEvent", + "QGraphicsSceneHoverEvent", + "QGraphicsSceneMouseEvent", + "QGraphicsSceneMoveEvent", + "QGraphicsSceneResizeEvent", + "QGraphicsSceneWheelEvent", + "QGraphicsSimpleTextItem", + "QGraphicsTextItem", + "QGraphicsTransform", + "QGraphicsView", + "QGraphicsWidget", + "QGridLayout", + "QGroupBox", + "QHBoxLayout", + "QHeaderView", + "QInputDialog", + "QItemDelegate", + "QItemEditorCreatorBase", + "QItemEditorFactory", + "QKeyEventTransition", + "QLCDNumber", + "QLabel", + "QLayout", + "QLayoutItem", + "QLineEdit", + "QListView", + "QListWidget", + "QListWidgetItem", + "QMainWindow", + "QMdiArea", + "QMdiSubWindow", + "QMenu", + "QMenuBar", + "QMessageBox", + "QMouseEventTransition", + "QPanGesture", + "QPinchGesture", + "QPlainTextDocumentLayout", + "QPlainTextEdit", + "QProgressBar", + "QProgressDialog", + "QPushButton", + "QRadioButton", + "QRubberBand", + "QScrollArea", + "QScrollBar", + "QShortcut", + "QSizeGrip", + "QSizePolicy", + "QSlider", + "QSpacerItem", + "QSpinBox", + "QSplashScreen", + "QSplitter", + "QSplitterHandle", + "QStackedLayout", + "QStackedWidget", + "QStatusBar", + "QStyle", + "QStyleFactory", + "QStyleHintReturn", + "QStyleHintReturnMask", + "QStyleHintReturnVariant", + "QStyleOption", + "QStyleOptionButton", + "QStyleOptionComboBox", + "QStyleOptionComplex", + "QStyleOptionDockWidget", + "QStyleOptionFocusRect", + "QStyleOptionFrame", + "QStyleOptionGraphicsItem", + "QStyleOptionGroupBox", + "QStyleOptionHeader", + "QStyleOptionMenuItem", + "QStyleOptionProgressBar", + "QStyleOptionRubberBand", + "QStyleOptionSizeGrip", + "QStyleOptionSlider", + "QStyleOptionSpinBox", + "QStyleOptionTab", + "QStyleOptionTabBarBase", + "QStyleOptionTabWidgetFrame", + "QStyleOptionTitleBar", + "QStyleOptionToolBar", + "QStyleOptionToolBox", + "QStyleOptionToolButton", + "QStyleOptionViewItem", + "QStylePainter", + "QStyledItemDelegate", + "QSwipeGesture", + "QSystemTrayIcon", + "QTabBar", + "QTabWidget", + 
"QTableView", + "QTableWidget", + "QTableWidgetItem", + "QTableWidgetSelectionRange", + "QTapAndHoldGesture", + "QTapGesture", + "QTextBrowser", + "QTextEdit", + "QTimeEdit", + "QToolBar", + "QToolBox", + "QToolButton", + "QToolTip", + "QTreeView", + "QTreeWidget", + "QTreeWidgetItem", + "QTreeWidgetItemIterator", + "QUndoCommand", + "QUndoGroup", + "QUndoStack", + "QUndoView", + "QVBoxLayout", + "QWhatsThis", + "QWidget", + "QWidgetAction", + "QWidgetItem", + "QWizard", + "QWizardPage" + ], + "QtX11Extras": [ + "QX11Info" + ], + "QtXml": [ + "QDomAttr", + "QDomCDATASection", + "QDomCharacterData", + "QDomComment", + "QDomDocument", + "QDomDocumentFragment", + "QDomDocumentType", + "QDomElement", + "QDomEntity", + "QDomEntityReference", + "QDomImplementation", + "QDomNamedNodeMap", + "QDomNode", + "QDomNodeList", + "QDomNotation", + "QDomProcessingInstruction", + "QDomText", + "QXmlAttributes", + "QXmlContentHandler", + "QXmlDTDHandler", + "QXmlDeclHandler", + "QXmlDefaultHandler", + "QXmlEntityResolver", + "QXmlErrorHandler", + "QXmlInputSource", + "QXmlLexicalHandler", + "QXmlLocator", + "QXmlNamespaceSupport", + "QXmlParseException", + "QXmlReader", + "QXmlSimpleReader" + ], + "QtXmlPatterns": [ + "QAbstractMessageHandler", + "QAbstractUriResolver", + "QAbstractXmlNodeModel", + "QAbstractXmlReceiver", + "QSourceLocation", + "QXmlFormatter", + "QXmlItem", + "QXmlName", + "QXmlNamePool", + "QXmlNodeModelIndex", + "QXmlQuery", + "QXmlResultItems", + "QXmlSchema", + "QXmlSchemaValidator", + "QXmlSerializer" + ] +} + +""" Missing members + +This mapping describes members that have been deprecated +in one or more bindings and have been left out of the +_common_members mapping. + +The member can provide an extra details string to be +included in exceptions and warnings. 
+""" + +_missing_members = { + "QtGui": { + "QMatrix": "Deprecated in PyQt5", + }, +} + + +def _qInstallMessageHandler(handler): + """Install a message handler that works in all bindings + + Args: + handler: A function that takes 3 arguments, or None + """ + def messageOutputHandler(*args): + # In Qt4 bindings, message handlers are passed 2 arguments + # In Qt5 bindings, message handlers are passed 3 arguments + # The first argument is a QtMsgType + # The last argument is the message to be printed + # The Middle argument (if passed) is a QMessageLogContext + if len(args) == 3: + msgType, logContext, msg = args + elif len(args) == 2: + msgType, msg = args + logContext = None + else: + raise TypeError( + "handler expected 2 or 3 arguments, got {0}".format(len(args))) + + if isinstance(msg, bytes): + # In python 3, some bindings pass a bytestring, which cannot be + # used elsewhere. Decoding a python 2 or 3 bytestring object will + # consistently return a unicode object. + msg = msg.decode() + + handler(msgType, logContext, msg) + + passObject = messageOutputHandler if handler else handler + if Qt.IsPySide or Qt.IsPyQt4: + return Qt._QtCore.qInstallMsgHandler(passObject) + elif Qt.IsPySide2 or Qt.IsPyQt5: + return Qt._QtCore.qInstallMessageHandler(passObject) + + +def _getcpppointer(object): + if hasattr(Qt, "_shiboken2"): + return getattr(Qt, "_shiboken2").getCppPointer(object)[0] + elif hasattr(Qt, "_shiboken"): + return getattr(Qt, "_shiboken").getCppPointer(object)[0] + elif hasattr(Qt, "_sip"): + return getattr(Qt, "_sip").unwrapinstance(object) + raise AttributeError("'module' has no attribute 'getCppPointer'") + + +def _wrapinstance(ptr, base=None): + """Enable implicit cast of pointer to most suitable class + + This behaviour is available in sip per default. + + Based on http://nathanhorne.com/pyqtpyside-wrap-instance + + Usage: + This mechanism kicks in under these circumstances. + 1. Qt.py is using PySide 1 or 2. + 2. A `base` argument is not provided. 
+ + See :func:`QtCompat.wrapInstance()` + + Arguments: + ptr (long): Pointer to QObject in memory + base (QObject, optional): Base class to wrap with. Defaults to QObject, + which should handle anything. + + """ + + assert isinstance(ptr, long), "Argument 'ptr' must be of type " + assert (base is None) or issubclass(base, Qt.QtCore.QObject), ( + "Argument 'base' must be of type ") + + if Qt.IsPyQt4 or Qt.IsPyQt5: + func = getattr(Qt, "_sip").wrapinstance + elif Qt.IsPySide2: + func = getattr(Qt, "_shiboken2").wrapInstance + elif Qt.IsPySide: + func = getattr(Qt, "_shiboken").wrapInstance + else: + raise AttributeError("'module' has no attribute 'wrapInstance'") + + if base is None: + q_object = func(long(ptr), Qt.QtCore.QObject) + meta_object = q_object.metaObject() + class_name = meta_object.className() + super_class_name = meta_object.superClass().className() + + if hasattr(Qt.QtWidgets, class_name): + base = getattr(Qt.QtWidgets, class_name) + + elif hasattr(Qt.QtWidgets, super_class_name): + base = getattr(Qt.QtWidgets, super_class_name) + + else: + base = Qt.QtCore.QObject + + return func(long(ptr), base) + + +def _isvalid(object): + """Check if the object is valid to use in Python runtime. + + Usage: + See :func:`QtCompat.isValid()` + + Arguments: + object (QObject): QObject to check the validity of. 
+ + """ + + assert isinstance(object, Qt.QtCore.QObject) + + if hasattr(Qt, "_shiboken2"): + return getattr(Qt, "_shiboken2").isValid(object) + + elif hasattr(Qt, "_shiboken"): + return getattr(Qt, "_shiboken").isValid(object) + + elif hasattr(Qt, "_sip"): + return not getattr(Qt, "_sip").isdeleted(object) + + else: + raise AttributeError("'module' has no attribute isValid") + + +def _translate(context, sourceText, *args): + # In Qt4 bindings, translate can be passed 2 or 3 arguments + # In Qt5 bindings, translate can be passed 2 arguments + # The first argument is disambiguation[str] + # The last argument is n[int] + # The middle argument can be encoding[QtCore.QCoreApplication.Encoding] + if len(args) == 3: + disambiguation, encoding, n = args + elif len(args) == 2: + disambiguation, n = args + encoding = None + else: + raise TypeError( + "Expected 4 or 5 arguments, got {0}.".format(len(args) + 2)) + + if hasattr(Qt.QtCore, "QCoreApplication"): + app = getattr(Qt.QtCore, "QCoreApplication") + else: + raise NotImplementedError( + "Missing QCoreApplication implementation for {binding}".format( + binding=Qt.__binding__, + ) + ) + if Qt.__binding__ in ("PySide2", "PyQt5"): + sanitized_args = [context, sourceText, disambiguation, n] + else: + sanitized_args = [ + context, + sourceText, + disambiguation, + encoding or app.CodecForTr, + n + ] + return app.translate(*sanitized_args) + + +def _loadUi(uifile, baseinstance=None): + """Dynamically load a user interface from the given `uifile` + + This function calls `uic.loadUi` if using PyQt bindings, + else it implements a comparable binding for PySide. + + Documentation: + http://pyqt.sourceforge.net/Docs/PyQt5/designer.html#PyQt5.uic.loadUi + + Arguments: + uifile (str): Absolute path to Qt Designer file. + baseinstance (QWidget): Instantiated QWidget or subclass thereof + + Return: + baseinstance if `baseinstance` is not `None`. Otherwise + return the newly created instance of the user interface. 
+ + """ + if hasattr(Qt, "_uic"): + return Qt._uic.loadUi(uifile, baseinstance) + + elif hasattr(Qt, "_QtUiTools"): + # Implement `PyQt5.uic.loadUi` for PySide(2) + + class _UiLoader(Qt._QtUiTools.QUiLoader): + """Create the user interface in a base instance. + + Unlike `Qt._QtUiTools.QUiLoader` itself this class does not + create a new instance of the top-level widget, but creates the user + interface in an existing instance of the top-level class if needed. + + This mimics the behaviour of `PyQt5.uic.loadUi`. + + """ + + def __init__(self, baseinstance): + super(_UiLoader, self).__init__(baseinstance) + self.baseinstance = baseinstance + self.custom_widgets = {} + + def _loadCustomWidgets(self, etree): + """ + Workaround to pyside-77 bug. + + From QUiLoader doc we should use registerCustomWidget method. + But this causes a segfault on some platforms. + + Instead we fetch from customwidgets DOM node the python class + objects. Then we can directly use them in createWidget method. + """ + + def headerToModule(header): + """ + Translate a header file to python module path + foo/bar.h => foo.bar + """ + # Remove header extension + module = os.path.splitext(header)[0] + + # Replace os separator by python module separator + return module.replace("/", ".").replace("\\", ".") + + custom_widgets = etree.find("customwidgets") + + if custom_widgets is None: + return + + for custom_widget in custom_widgets: + class_name = custom_widget.find("class").text + header = custom_widget.find("header").text + module = importlib.import_module(headerToModule(header)) + self.custom_widgets[class_name] = getattr(module, + class_name) + + def load(self, uifile, *args, **kwargs): + from xml.etree.ElementTree import ElementTree + + # For whatever reason, if this doesn't happen then + # reading an invalid or non-existing .ui file throws + # a RuntimeError. 
+ etree = ElementTree() + etree.parse(uifile) + self._loadCustomWidgets(etree) + + widget = Qt._QtUiTools.QUiLoader.load( + self, uifile, *args, **kwargs) + + # Workaround for PySide 1.0.9, see issue #208 + widget.parentWidget() + + return widget + + def createWidget(self, class_name, parent=None, name=""): + """Called for each widget defined in ui file + + Overridden here to populate `baseinstance` instead. + + """ + + if parent is None and self.baseinstance: + # Supposed to create the top-level widget, + # return the base instance instead + return self.baseinstance + + # For some reason, Line is not in the list of available + # widgets, but works fine, so we have to special case it here. + if class_name in self.availableWidgets() + ["Line"]: + # Create a new widget for child widgets + widget = Qt._QtUiTools.QUiLoader.createWidget(self, + class_name, + parent, + name) + elif class_name in self.custom_widgets: + widget = self.custom_widgets[class_name](parent) + else: + raise Exception("Custom widget '%s' not supported" + % class_name) + + if self.baseinstance: + # Set an attribute for the new child widget on the base + # instance, just like PyQt5.uic.loadUi does. 
+ setattr(self.baseinstance, name, widget) + + return widget + + widget = _UiLoader(baseinstance).load(uifile) + Qt.QtCore.QMetaObject.connectSlotsByName(widget) + + return widget + + else: + raise NotImplementedError("No implementation available for loadUi") + + +"""Misplaced members + +These members from the original submodule are misplaced relative PySide2 + +""" +_misplaced_members = { + "PySide2": { + "QtCore.QStringListModel": "QtCore.QStringListModel", + "QtGui.QStringListModel": "QtCore.QStringListModel", + "QtCore.Property": "QtCore.Property", + "QtCore.Signal": "QtCore.Signal", + "QtCore.Slot": "QtCore.Slot", + "QtCore.QAbstractProxyModel": "QtCore.QAbstractProxyModel", + "QtCore.QSortFilterProxyModel": "QtCore.QSortFilterProxyModel", + "QtCore.QItemSelection": "QtCore.QItemSelection", + "QtCore.QItemSelectionModel": "QtCore.QItemSelectionModel", + "QtCore.QItemSelectionRange": "QtCore.QItemSelectionRange", + "QtUiTools.QUiLoader": ["QtCompat.loadUi", _loadUi], + "shiboken2.wrapInstance": ["QtCompat.wrapInstance", _wrapinstance], + "shiboken2.getCppPointer": ["QtCompat.getCppPointer", _getcpppointer], + "shiboken2.isValid": ["QtCompat.isValid", _isvalid], + "QtWidgets.qApp": "QtWidgets.QApplication.instance()", + "QtCore.QCoreApplication.translate": [ + "QtCompat.translate", _translate + ], + "QtWidgets.QApplication.translate": [ + "QtCompat.translate", _translate + ], + "QtCore.qInstallMessageHandler": [ + "QtCompat.qInstallMessageHandler", _qInstallMessageHandler + ], + "QtWidgets.QStyleOptionViewItem": "QtCompat.QStyleOptionViewItemV4", + }, + "PyQt5": { + "QtCore.pyqtProperty": "QtCore.Property", + "QtCore.pyqtSignal": "QtCore.Signal", + "QtCore.pyqtSlot": "QtCore.Slot", + "QtCore.QAbstractProxyModel": "QtCore.QAbstractProxyModel", + "QtCore.QSortFilterProxyModel": "QtCore.QSortFilterProxyModel", + "QtCore.QStringListModel": "QtCore.QStringListModel", + "QtCore.QItemSelection": "QtCore.QItemSelection", + "QtCore.QItemSelectionModel": 
"QtCore.QItemSelectionModel", + "QtCore.QItemSelectionRange": "QtCore.QItemSelectionRange", + "uic.loadUi": ["QtCompat.loadUi", _loadUi], + "sip.wrapinstance": ["QtCompat.wrapInstance", _wrapinstance], + "sip.unwrapinstance": ["QtCompat.getCppPointer", _getcpppointer], + "sip.isdeleted": ["QtCompat.isValid", _isvalid], + "QtWidgets.qApp": "QtWidgets.QApplication.instance()", + "QtCore.QCoreApplication.translate": [ + "QtCompat.translate", _translate + ], + "QtWidgets.QApplication.translate": [ + "QtCompat.translate", _translate + ], + "QtCore.qInstallMessageHandler": [ + "QtCompat.qInstallMessageHandler", _qInstallMessageHandler + ], + "QtWidgets.QStyleOptionViewItem": "QtCompat.QStyleOptionViewItemV4", + }, + "PySide": { + "QtGui.QAbstractProxyModel": "QtCore.QAbstractProxyModel", + "QtGui.QSortFilterProxyModel": "QtCore.QSortFilterProxyModel", + "QtGui.QStringListModel": "QtCore.QStringListModel", + "QtGui.QItemSelection": "QtCore.QItemSelection", + "QtGui.QItemSelectionModel": "QtCore.QItemSelectionModel", + "QtCore.Property": "QtCore.Property", + "QtCore.Signal": "QtCore.Signal", + "QtCore.Slot": "QtCore.Slot", + "QtGui.QItemSelectionRange": "QtCore.QItemSelectionRange", + "QtGui.QAbstractPrintDialog": "QtPrintSupport.QAbstractPrintDialog", + "QtGui.QPageSetupDialog": "QtPrintSupport.QPageSetupDialog", + "QtGui.QPrintDialog": "QtPrintSupport.QPrintDialog", + "QtGui.QPrintEngine": "QtPrintSupport.QPrintEngine", + "QtGui.QPrintPreviewDialog": "QtPrintSupport.QPrintPreviewDialog", + "QtGui.QPrintPreviewWidget": "QtPrintSupport.QPrintPreviewWidget", + "QtGui.QPrinter": "QtPrintSupport.QPrinter", + "QtGui.QPrinterInfo": "QtPrintSupport.QPrinterInfo", + "QtUiTools.QUiLoader": ["QtCompat.loadUi", _loadUi], + "shiboken.wrapInstance": ["QtCompat.wrapInstance", _wrapinstance], + "shiboken.unwrapInstance": ["QtCompat.getCppPointer", _getcpppointer], + "shiboken.isValid": ["QtCompat.isValid", _isvalid], + "QtGui.qApp": "QtWidgets.QApplication.instance()", + 
"QtCore.QCoreApplication.translate": [ + "QtCompat.translate", _translate + ], + "QtGui.QApplication.translate": [ + "QtCompat.translate", _translate + ], + "QtCore.qInstallMsgHandler": [ + "QtCompat.qInstallMessageHandler", _qInstallMessageHandler + ], + "QtGui.QStyleOptionViewItemV4": "QtCompat.QStyleOptionViewItemV4", + }, + "PyQt4": { + "QtGui.QAbstractProxyModel": "QtCore.QAbstractProxyModel", + "QtGui.QSortFilterProxyModel": "QtCore.QSortFilterProxyModel", + "QtGui.QItemSelection": "QtCore.QItemSelection", + "QtGui.QStringListModel": "QtCore.QStringListModel", + "QtGui.QItemSelectionModel": "QtCore.QItemSelectionModel", + "QtCore.pyqtProperty": "QtCore.Property", + "QtCore.pyqtSignal": "QtCore.Signal", + "QtCore.pyqtSlot": "QtCore.Slot", + "QtGui.QItemSelectionRange": "QtCore.QItemSelectionRange", + "QtGui.QAbstractPrintDialog": "QtPrintSupport.QAbstractPrintDialog", + "QtGui.QPageSetupDialog": "QtPrintSupport.QPageSetupDialog", + "QtGui.QPrintDialog": "QtPrintSupport.QPrintDialog", + "QtGui.QPrintEngine": "QtPrintSupport.QPrintEngine", + "QtGui.QPrintPreviewDialog": "QtPrintSupport.QPrintPreviewDialog", + "QtGui.QPrintPreviewWidget": "QtPrintSupport.QPrintPreviewWidget", + "QtGui.QPrinter": "QtPrintSupport.QPrinter", + "QtGui.QPrinterInfo": "QtPrintSupport.QPrinterInfo", + # "QtCore.pyqtSignature": "QtCore.Slot", + "uic.loadUi": ["QtCompat.loadUi", _loadUi], + "sip.wrapinstance": ["QtCompat.wrapInstance", _wrapinstance], + "sip.unwrapinstance": ["QtCompat.getCppPointer", _getcpppointer], + "sip.isdeleted": ["QtCompat.isValid", _isvalid], + "QtCore.QString": "str", + "QtGui.qApp": "QtWidgets.QApplication.instance()", + "QtCore.QCoreApplication.translate": [ + "QtCompat.translate", _translate + ], + "QtGui.QApplication.translate": [ + "QtCompat.translate", _translate + ], + "QtCore.qInstallMsgHandler": [ + "QtCompat.qInstallMessageHandler", _qInstallMessageHandler + ], + "QtGui.QStyleOptionViewItemV4": "QtCompat.QStyleOptionViewItemV4", + } +} + +""" 
Compatibility Members + +This dictionary is used to build Qt.QtCompat objects that provide a consistent +interface for obsolete members, and differences in binding return values. + +{ + "binding": { + "classname": { + "targetname": "binding_namespace", + } + } +} +""" +_compatibility_members = { + "PySide2": { + "QWidget": { + "grab": "QtWidgets.QWidget.grab", + }, + "QHeaderView": { + "sectionsClickable": "QtWidgets.QHeaderView.sectionsClickable", + "setSectionsClickable": + "QtWidgets.QHeaderView.setSectionsClickable", + "sectionResizeMode": "QtWidgets.QHeaderView.sectionResizeMode", + "setSectionResizeMode": + "QtWidgets.QHeaderView.setSectionResizeMode", + "sectionsMovable": "QtWidgets.QHeaderView.sectionsMovable", + "setSectionsMovable": "QtWidgets.QHeaderView.setSectionsMovable", + }, + "QFileDialog": { + "getOpenFileName": "QtWidgets.QFileDialog.getOpenFileName", + "getOpenFileNames": "QtWidgets.QFileDialog.getOpenFileNames", + "getSaveFileName": "QtWidgets.QFileDialog.getSaveFileName", + }, + }, + "PyQt5": { + "QWidget": { + "grab": "QtWidgets.QWidget.grab", + }, + "QHeaderView": { + "sectionsClickable": "QtWidgets.QHeaderView.sectionsClickable", + "setSectionsClickable": + "QtWidgets.QHeaderView.setSectionsClickable", + "sectionResizeMode": "QtWidgets.QHeaderView.sectionResizeMode", + "setSectionResizeMode": + "QtWidgets.QHeaderView.setSectionResizeMode", + "sectionsMovable": "QtWidgets.QHeaderView.sectionsMovable", + "setSectionsMovable": "QtWidgets.QHeaderView.setSectionsMovable", + }, + "QFileDialog": { + "getOpenFileName": "QtWidgets.QFileDialog.getOpenFileName", + "getOpenFileNames": "QtWidgets.QFileDialog.getOpenFileNames", + "getSaveFileName": "QtWidgets.QFileDialog.getSaveFileName", + }, + }, + "PySide": { + "QWidget": { + "grab": "QtWidgets.QPixmap.grabWidget", + }, + "QHeaderView": { + "sectionsClickable": "QtWidgets.QHeaderView.isClickable", + "setSectionsClickable": "QtWidgets.QHeaderView.setClickable", + "sectionResizeMode": 
"QtWidgets.QHeaderView.resizeMode", + "setSectionResizeMode": "QtWidgets.QHeaderView.setResizeMode", + "sectionsMovable": "QtWidgets.QHeaderView.isMovable", + "setSectionsMovable": "QtWidgets.QHeaderView.setMovable", + }, + "QFileDialog": { + "getOpenFileName": "QtWidgets.QFileDialog.getOpenFileName", + "getOpenFileNames": "QtWidgets.QFileDialog.getOpenFileNames", + "getSaveFileName": "QtWidgets.QFileDialog.getSaveFileName", + }, + }, + "PyQt4": { + "QWidget": { + "grab": "QtWidgets.QPixmap.grabWidget", + }, + "QHeaderView": { + "sectionsClickable": "QtWidgets.QHeaderView.isClickable", + "setSectionsClickable": "QtWidgets.QHeaderView.setClickable", + "sectionResizeMode": "QtWidgets.QHeaderView.resizeMode", + "setSectionResizeMode": "QtWidgets.QHeaderView.setResizeMode", + "sectionsMovable": "QtWidgets.QHeaderView.isMovable", + "setSectionsMovable": "QtWidgets.QHeaderView.setMovable", + }, + "QFileDialog": { + "getOpenFileName": "QtWidgets.QFileDialog.getOpenFileName", + "getOpenFileNames": "QtWidgets.QFileDialog.getOpenFileNames", + "getSaveFileName": "QtWidgets.QFileDialog.getSaveFileName", + }, + }, +} + + +def _apply_site_config(): + try: + import QtSiteConfig + except ImportError: + # If no QtSiteConfig module found, no modifications + # to _common_members are needed. + pass + else: + # Provide the ability to modify the dicts used to build Qt.py + if hasattr(QtSiteConfig, 'update_members'): + QtSiteConfig.update_members(_common_members) + + if hasattr(QtSiteConfig, 'update_misplaced_members'): + QtSiteConfig.update_misplaced_members(members=_misplaced_members) + + if hasattr(QtSiteConfig, 'update_compatibility_members'): + QtSiteConfig.update_compatibility_members( + members=_compatibility_members) + + +def _new_module(name): + return types.ModuleType(__name__ + "." + name) + + +def _import_sub_module(module, name): + """import_sub_module will mimic the function of importlib.import_module""" + module = __import__(module.__name__ + "." 
+ name) + for level in name.split("."): + module = getattr(module, level) + return module + + + +def _setup(module, extras): + """Install common submodules""" + + Qt.__binding__ = module.__name__ + + for name in list(_common_members) + extras: + try: + submodule = _import_sub_module( + module, name) + except ImportError: + try: + # For extra modules like sip and shiboken that may not be + # children of the binding. + submodule = __import__(name) + except ImportError: + continue + + setattr(Qt, "_" + name, submodule) + + if name not in extras: + # Store reference to original binding, + # but don't store speciality modules + # such as uic or QtUiTools + setattr(Qt, name, _new_module(name)) + + +def _reassign_misplaced_members(binding): + """Apply misplaced members from `binding` to Qt.py + + Arguments: + binding (dict): Misplaced members + + """ + + for src, dst in _misplaced_members[binding].items(): + dst_value = None + + src_parts = src.split(".") + src_module = src_parts[0] + src_member = None + if len(src_parts) > 1: + src_member = src_parts[1:] + + if isinstance(dst, (list, tuple)): + dst, dst_value = dst + + dst_parts = dst.split(".") + dst_module = dst_parts[0] + dst_member = None + if len(dst_parts) > 1: + dst_member = dst_parts[1] + + # Get the member we want to store in the namespace. + if not dst_value: + try: + _part = getattr(Qt, "_" + src_module) + while src_member: + member = src_member.pop(0) + _part = getattr(_part, member) + dst_value = _part + except AttributeError: + # If the member we want to store in the namespace does not + # exist, there is no need to continue. This can happen if a + # request was made to rename a member that didn't exist, for + # example if QtWidgets isn't available on the target platform. 
+ _log("Misplaced member has no source: {0}".format(src)) + continue + + try: + src_object = getattr(Qt, dst_module) + except AttributeError: + if dst_module not in _common_members: + # Only create the Qt parent module if it's listed in + # _common_members. Without this check, if you remove QtCore + # from _common_members, the default _misplaced_members will add + # Qt.QtCore so it can add Signal, Slot, etc. + msg = 'Not creating missing member module "{m}" for "{c}"' + _log(msg.format(m=dst_module, c=dst_member)) + continue + # If the dst is valid but the Qt parent module does not exist + # then go ahead and create a new module to contain the member. + setattr(Qt, dst_module, _new_module(dst_module)) + src_object = getattr(Qt, dst_module) + # Enable direct import of the new module + sys.modules[__name__ + "." + dst_module] = src_object + + if not dst_value: + dst_value = getattr(Qt, "_" + src_module) + if src_member: + dst_value = getattr(dst_value, src_member) + + setattr( + src_object, + dst_member or dst_module, + dst_value + ) + + +def _build_compatibility_members(binding, decorators=None): + """Apply `binding` to QtCompat + + Arguments: + binding (str): Top level binding in _compatibility_members. + decorators (dict, optional): Provides the ability to decorate the + original Qt methods when needed by a binding. This can be used + to change the returned value to a standard value. The key should + be the classname, the value is a dict where the keys are the + target method names, and the values are the decorator functions. + + """ + + decorators = decorators or dict() + + # Allow optional site-level customization of the compatibility members. + # This method does not need to be implemented in QtSiteConfig. 
+ try: + import QtSiteConfig + except ImportError: + pass + else: + if hasattr(QtSiteConfig, 'update_compatibility_decorators'): + QtSiteConfig.update_compatibility_decorators(binding, decorators) + + _QtCompat = type("QtCompat", (object,), {}) + + for classname, bindings in _compatibility_members[binding].items(): + attrs = {} + for target, binding in bindings.items(): + namespaces = binding.split('.') + try: + src_object = getattr(Qt, "_" + namespaces[0]) + except AttributeError as e: + _log("QtCompat: AttributeError: %s" % e) + # Skip reassignment of non-existing members. + # This can happen if a request was made to + # rename a member that didn't exist, for example + # if QtWidgets isn't available on the target platform. + continue + + # Walk down any remaining namespace getting the object assuming + # that if the first namespace exists the rest will exist. + for namespace in namespaces[1:]: + src_object = getattr(src_object, namespace) + + # decorate the Qt method if a decorator was provided. + if target in decorators.get(classname, []): + # staticmethod must be called on the decorated method to + # prevent a TypeError being raised when the decorated method + # is called. 
+ src_object = staticmethod( + decorators[classname][target](src_object)) + + attrs[target] = src_object + + # Create the QtCompat class and install it into the namespace + compat_class = type(classname, (_QtCompat,), attrs) + setattr(Qt.QtCompat, classname, compat_class) + + +def _pyside2(): + """Initialise PySide2 + + These functions serve to test the existence of a binding + along with set it up in such a way that it aligns with + the final step; adding members from the original binding + to Qt.py + + """ + + import PySide2 as module + extras = ["QtUiTools"] + try: + try: + # Before merge of PySide and shiboken + import shiboken2 + except ImportError: + # After merge of PySide and shiboken, May 2017 + from PySide2 import shiboken2 + extras.append("shiboken2") + except ImportError: + pass + + _setup(module, extras) + Qt.__binding_version__ = module.__version__ + + if hasattr(Qt, "_shiboken2"): + Qt.QtCompat.wrapInstance = _wrapinstance + Qt.QtCompat.getCppPointer = _getcpppointer + Qt.QtCompat.delete = shiboken2.delete + + if hasattr(Qt, "_QtUiTools"): + Qt.QtCompat.loadUi = _loadUi + + if hasattr(Qt, "_QtCore"): + Qt.__qt_version__ = Qt._QtCore.qVersion() + Qt.QtCompat.dataChanged = ( + lambda self, topleft, bottomright, roles=None: + self.dataChanged.emit(topleft, bottomright, roles or []) + ) + + if hasattr(Qt, "_QtWidgets"): + Qt.QtCompat.setSectionResizeMode = \ + Qt._QtWidgets.QHeaderView.setSectionResizeMode + + _reassign_misplaced_members("PySide2") + _build_compatibility_members("PySide2") + + +def _pyside(): + """Initialise PySide""" + + import PySide as module + extras = ["QtUiTools"] + try: + try: + # Before merge of PySide and shiboken + import shiboken + except ImportError: + # After merge of PySide and shiboken, May 2017 + from PySide import shiboken + extras.append("shiboken") + except ImportError: + pass + + _setup(module, extras) + Qt.__binding_version__ = module.__version__ + + if hasattr(Qt, "_shiboken"): + Qt.QtCompat.wrapInstance = 
_wrapinstance + Qt.QtCompat.getCppPointer = _getcpppointer + Qt.QtCompat.delete = shiboken.delete + + if hasattr(Qt, "_QtUiTools"): + Qt.QtCompat.loadUi = _loadUi + + if hasattr(Qt, "_QtGui"): + setattr(Qt, "QtWidgets", _new_module("QtWidgets")) + setattr(Qt, "_QtWidgets", Qt._QtGui) + if hasattr(Qt._QtGui, "QX11Info"): + setattr(Qt, "QtX11Extras", _new_module("QtX11Extras")) + Qt.QtX11Extras.QX11Info = Qt._QtGui.QX11Info + + Qt.QtCompat.setSectionResizeMode = Qt._QtGui.QHeaderView.setResizeMode + + if hasattr(Qt, "_QtCore"): + Qt.__qt_version__ = Qt._QtCore.qVersion() + Qt.QtCompat.dataChanged = ( + lambda self, topleft, bottomright, roles=None: + self.dataChanged.emit(topleft, bottomright) + ) + + _reassign_misplaced_members("PySide") + _build_compatibility_members("PySide") + + +def _pyqt5(): + """Initialise PyQt5""" + + import PyQt5 as module + extras = ["uic"] + + try: + import sip + extras += ["sip"] + except ImportError: + + # Relevant to PyQt5 5.11 and above + try: + from PyQt5 import sip + extras += ["sip"] + except ImportError: + sip = None + + _setup(module, extras) + if hasattr(Qt, "_sip"): + Qt.QtCompat.wrapInstance = _wrapinstance + Qt.QtCompat.getCppPointer = _getcpppointer + Qt.QtCompat.delete = sip.delete + + if hasattr(Qt, "_uic"): + Qt.QtCompat.loadUi = _loadUi + + if hasattr(Qt, "_QtCore"): + Qt.__binding_version__ = Qt._QtCore.PYQT_VERSION_STR + Qt.__qt_version__ = Qt._QtCore.QT_VERSION_STR + Qt.QtCompat.dataChanged = ( + lambda self, topleft, bottomright, roles=None: + self.dataChanged.emit(topleft, bottomright, roles or []) + ) + + if hasattr(Qt, "_QtWidgets"): + Qt.QtCompat.setSectionResizeMode = \ + Qt._QtWidgets.QHeaderView.setSectionResizeMode + + _reassign_misplaced_members("PyQt5") + _build_compatibility_members('PyQt5') + + +def _pyqt4(): + """Initialise PyQt4""" + + import sip + + # Validation of environment variable. Prevents an error if + # the variable is invalid since it's just a hint. 
+ try: + hint = int(QT_SIP_API_HINT) + except TypeError: + hint = None # Variable was None, i.e. not set. + except ValueError: + raise ImportError("QT_SIP_API_HINT=%s must be a 1 or 2") + + for api in ("QString", + "QVariant", + "QDate", + "QDateTime", + "QTextStream", + "QTime", + "QUrl"): + try: + sip.setapi(api, hint or 2) + except AttributeError: + raise ImportError("PyQt4 < 4.6 isn't supported by Qt.py") + except ValueError: + actual = sip.getapi(api) + if not hint: + raise ImportError("API version already set to %d" % actual) + else: + # Having provided a hint indicates a soft constraint, one + # that doesn't throw an exception. + sys.stderr.write( + "Warning: API '%s' has already been set to %d.\n" + % (api, actual) + ) + + import PyQt4 as module + extras = ["uic"] + try: + import sip + extras.append(sip.__name__) + except ImportError: + sip = None + + _setup(module, extras) + if hasattr(Qt, "_sip"): + Qt.QtCompat.wrapInstance = _wrapinstance + Qt.QtCompat.getCppPointer = _getcpppointer + Qt.QtCompat.delete = sip.delete + + if hasattr(Qt, "_uic"): + Qt.QtCompat.loadUi = _loadUi + + if hasattr(Qt, "_QtGui"): + setattr(Qt, "QtWidgets", _new_module("QtWidgets")) + setattr(Qt, "_QtWidgets", Qt._QtGui) + if hasattr(Qt._QtGui, "QX11Info"): + setattr(Qt, "QtX11Extras", _new_module("QtX11Extras")) + Qt.QtX11Extras.QX11Info = Qt._QtGui.QX11Info + + Qt.QtCompat.setSectionResizeMode = \ + Qt._QtGui.QHeaderView.setResizeMode + + if hasattr(Qt, "_QtCore"): + Qt.__binding_version__ = Qt._QtCore.PYQT_VERSION_STR + Qt.__qt_version__ = Qt._QtCore.QT_VERSION_STR + Qt.QtCompat.dataChanged = ( + lambda self, topleft, bottomright, roles=None: + self.dataChanged.emit(topleft, bottomright) + ) + + _reassign_misplaced_members("PyQt4") + + # QFileDialog QtCompat decorator + def _standardizeQFileDialog(some_function): + """Decorator that makes PyQt4 return conform to other bindings""" + def wrapper(*args, **kwargs): + ret = (some_function(*args, **kwargs)) + + # PyQt4 only returns 
the selected filename, force it to a + # standard return of the selected filename, and an empty string + # for the selected filter + return ret, '' + + wrapper.__doc__ = some_function.__doc__ + wrapper.__name__ = some_function.__name__ + + return wrapper + + decorators = { + "QFileDialog": { + "getOpenFileName": _standardizeQFileDialog, + "getOpenFileNames": _standardizeQFileDialog, + "getSaveFileName": _standardizeQFileDialog, + } + } + _build_compatibility_members('PyQt4', decorators) + + +def _none(): + """Internal option (used in installer)""" + + Mock = type("Mock", (), {"__getattr__": lambda Qt, attr: None}) + + Qt.__binding__ = "None" + Qt.__qt_version__ = "0.0.0" + Qt.__binding_version__ = "0.0.0" + Qt.QtCompat.loadUi = lambda uifile, baseinstance=None: None + Qt.QtCompat.setSectionResizeMode = lambda *args, **kwargs: None + + for submodule in _common_members.keys(): + setattr(Qt, submodule, Mock()) + setattr(Qt, "_" + submodule, Mock()) + + +def _log(text): + if QT_VERBOSE: + sys.stdout.write(text + "\n") + + +def _convert(lines): + """Convert compiled .ui file from PySide2 to Qt.py + + Arguments: + lines (list): Each line of the .ui file + + Usage: + >> with open("myui.py") as f: + .. lines = _convert(f.readlines()) + + """ + + def parse(line): + line = line.replace("from PySide2 import", "from Qt import QtCompat,") + line = line.replace("QtWidgets.QApplication.translate", + "QtCompat.translate") + if "QtCore.SIGNAL" in line: + raise NotImplementedError("QtCore.SIGNAL is missing from PyQt5 " + "and so Qt.py does not support it: you " + "should avoid defining signals inside " + "your ui files.") + return line + + parsed = list() + for line in lines: + line = parse(line) + parsed.append(line) + + return parsed + + +def _cli(args): + """Qt.py command-line interface""" + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--convert", + help="Path to compiled Python module, e.g. 
my_ui.py") + parser.add_argument("--compile", + help="Accept raw .ui file and compile with native " + "PySide2 compiler.") + parser.add_argument("--stdout", + help="Write to stdout instead of file", + action="store_true") + parser.add_argument("--stdin", + help="Read from stdin instead of file", + action="store_true") + + args = parser.parse_args(args) + + if args.stdout: + raise NotImplementedError("--stdout") + + if args.stdin: + raise NotImplementedError("--stdin") + + if args.compile: + raise NotImplementedError("--compile") + + if args.convert: + sys.stdout.write("#\n" + "# WARNING: --convert is an ALPHA feature.\n#\n" + "# See https://github.com/mottosso/Qt.py/pull/132\n" + "# for details.\n" + "#\n") + + # + # ------> Read + # + with open(args.convert) as f: + lines = _convert(f.readlines()) + + backup = "%s_backup%s" % os.path.splitext(args.convert) + sys.stdout.write("Creating \"%s\"..\n" % backup) + shutil.copy(args.convert, backup) + + # + # <------ Write + # + with open(args.convert, "w") as f: + f.write("".join(lines)) + + sys.stdout.write("Successfully converted \"%s\"\n" % args.convert) + + +class MissingMember(object): + """ + A placeholder type for a missing Qt object not + included in Qt.py + + Args: + name (str): The name of the missing type + details (str): An optional custom error message + """ + ERR_TMPL = ("{} is not a common object across PySide2 " + "and the other Qt bindings. 
It is not included " + "as a common member in the Qt.py layer") + + def __init__(self, name, details=''): + self.__name = name + self.__err = self.ERR_TMPL.format(name) + + if details: + self.__err = "{}: {}".format(self.__err, details) + + def __repr__(self): + return "<{}: {}>".format(self.__class__.__name__, self.__name) + + def __getattr__(self, name): + raise NotImplementedError(self.__err) + + def __call__(self, *a, **kw): + raise NotImplementedError(self.__err) + + +def _install(): + # Default order (customise order and content via QT_PREFERRED_BINDING) + default_order = ("PySide2", "PyQt5", "PySide", "PyQt4") + preferred_order = list( + b for b in QT_PREFERRED_BINDING.split(os.pathsep) if b + ) + + order = preferred_order or default_order + + available = { + "PySide2": _pyside2, + "PyQt5": _pyqt5, + "PySide": _pyside, + "PyQt4": _pyqt4, + "None": _none + } + + _log("Order: '%s'" % "', '".join(order)) + + # Allow site-level customization of the available modules. + _apply_site_config() + + found_binding = False + for name in order: + _log("Trying %s" % name) + + try: + available[name]() + found_binding = True + break + + except ImportError as e: + _log("ImportError: %s" % e) + + except KeyError: + _log("ImportError: Preferred binding '%s' not found." % name) + + if not found_binding: + # If no binding was found, raise this error + raise ImportError("No Qt binding were found.") + + # Install individual members + for name, members in _common_members.items(): + try: + their_submodule = getattr(Qt, "_%s" % name) + except AttributeError: + continue + + our_submodule = getattr(Qt, name) + + # Enable import * + __all__.append(name) + + # Enable direct import of submodule, + # e.g. import Qt.QtCore + sys.modules[__name__ + "." + name] = our_submodule + + for member in members: + # Accept that a submodule may miss certain members. + try: + their_member = getattr(their_submodule, member) + except AttributeError: + _log("'%s.%s' was missing." 
% (name, member)) + continue + + setattr(our_submodule, member, their_member) + + # Install missing member placeholders + for name, members in _missing_members.items(): + our_submodule = getattr(Qt, name) + + for member in members: + + # If the submodule already has this member installed, + # either by the common members, or the site config, + # then skip installing this one over it. + if hasattr(our_submodule, member): + continue + + placeholder = MissingMember("{}.{}".format(name, member), + details=members[member]) + setattr(our_submodule, member, placeholder) + + # Enable direct import of QtCompat + sys.modules['Qt.QtCompat'] = Qt.QtCompat + + # Backwards compatibility + if hasattr(Qt.QtCompat, 'loadUi'): + Qt.QtCompat.load_ui = Qt.QtCompat.loadUi + + +_install() + +# Setup Binding Enum states +Qt.IsPySide2 = Qt.__binding__ == 'PySide2' +Qt.IsPyQt5 = Qt.__binding__ == 'PyQt5' +Qt.IsPySide = Qt.__binding__ == 'PySide' +Qt.IsPyQt4 = Qt.__binding__ == 'PyQt4' + +"""Augment QtCompat + +QtCompat contains wrappers and added functionality +to the original bindings, such as the CLI interface +and otherwise incompatible members between bindings, +such as `QHeaderView.setSectionResizeMode`. 
+ +""" + +Qt.QtCompat._cli = _cli +Qt.QtCompat._convert = _convert + +# Enable command-line interface +if __name__ == "__main__": + _cli(sys.argv[1:]) + + +# The MIT License (MIT) +# +# Copyright (c) 2016-2017 Marcus Ottosson +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +# In PySide(2), loadUi does not exist, so we implement it +# +# `_UiLoader` is adapted from the qtpy project, which was further influenced +# by qt-helpers which was released under a 3-clause BSD license which in turn +# is based on a solution at: +# +# - https://gist.github.com/cpbotha/1b42a20c8f3eb9bb7cb8 +# +# The License for this code is as follows: +# +# qt-helpers - a common front-end to various Qt modules +# +# Copyright (c) 2015, Chris Beaumont and Thomas Robitaille +# +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the +# distribution. +# * Neither the name of the Glue project nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +# IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +# Which itself was based on the solution at +# +# https://gist.github.com/cpbotha/1b42a20c8f3eb9bb7cb8 +# +# which was released under the MIT license: +# +# Copyright (c) 2011 Sebastian Wiesner +# Modifications by Charl Botha +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files +# (the "Software"),to deal in the Software without restriction, +# including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
diff --git a/python3.10libs/QtPy-2.4.1.dist-info/AUTHORS.md b/python3.10libs/QtPy-2.4.1.dist-info/AUTHORS.md new file mode 100644 index 0000000..bec8765 --- /dev/null +++ b/python3.10libs/QtPy-2.4.1.dist-info/AUTHORS.md @@ -0,0 +1,20 @@ +# Authors + +## Original Authors + +* [pyqode.qt](https://github.com/pyQode/pyqode.qt): Colin Duquesnoy ([@ColinDuquesnoy](https://github.com/ColinDuquesnoy)) +* [spyderlib.qt](https://github.com/spyder-ide/spyder/commits/2.3/spyderlib/qt): Pierre Raybaut ([@PierreRaybaut](https://github.com/PierreRaybaut)) +* [qt-helpers](https://github.com/glue-viz/qt-helpers): Thomas Robitaille ([@astrofrog](https://www.github.com/astrofrog)) + + +## Current Maintainers + +* Daniel Althviz ([@dalthviz](https://github.com/dalthviz)) +* Carlos Cordoba ([@ccordoba12](https://github.com/ccordoba12)) +* C.A.M. Gerlach ([@CAM-Gerlach](https://github.com/CAM-Gerlach)) +* Spyder Development Team ([Spyder-IDE](https://github.com/spyder-ide)) + + +## Contributors + +* [The QtPy Contributors](https://github.com/spyder-ide/qtpy/graphs/contributors) diff --git a/python3.10libs/QtPy-2.4.1.dist-info/INSTALLER b/python3.10libs/QtPy-2.4.1.dist-info/INSTALLER new file mode 100644 index 0000000..a1b589e --- /dev/null +++ b/python3.10libs/QtPy-2.4.1.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/python3.10libs/QtPy-2.4.1.dist-info/LICENSE.txt b/python3.10libs/QtPy-2.4.1.dist-info/LICENSE.txt new file mode 100644 index 0000000..fed9ab7 --- /dev/null +++ b/python3.10libs/QtPy-2.4.1.dist-info/LICENSE.txt @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2011- QtPy contributors and others (see AUTHORS.md) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons 
to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/python3.10libs/QtPy-2.4.1.dist-info/METADATA b/python3.10libs/QtPy-2.4.1.dist-info/METADATA new file mode 100644 index 0000000..95744eb --- /dev/null +++ b/python3.10libs/QtPy-2.4.1.dist-info/METADATA @@ -0,0 +1,266 @@ +Metadata-Version: 2.1 +Name: QtPy +Version: 2.4.1 +Summary: Provides an abstraction layer on top of the various Qt bindings (PyQt5/6 and PySide2/6). 
+Home-page: https://github.com/spyder-ide/qtpy +Author: Colin Duquesnoy and the Spyder Development Team +Author-email: spyder.python@gmail.com +Maintainer: Spyder Development Team and QtPy Contributors +Maintainer-email: spyder.python@gmail.com +License: MIT +Project-URL: Github, https://github.com/spyder-ide/qtpy +Project-URL: Bug Tracker, https://github.com/spyder-ide/qtpy/issues +Project-URL: Parent Project, https://www.spyder-ide.org/ +Keywords: qt PyQt5 PyQt6 PySide2 PySide6 +Classifier: Development Status :: 5 - Production/Stable +Classifier: Environment :: MacOS X +Classifier: Environment :: Win32 (MS Windows) +Classifier: Environment :: X11 Applications :: Qt +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: MIT License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Topic :: Software Development :: Libraries +Classifier: Topic :: Software Development :: User Interfaces +Classifier: Topic :: Software Development :: Widget Sets +Requires-Python: >=3.7 +Description-Content-Type: text/markdown +License-File: AUTHORS.md +License-File: LICENSE.txt +Requires-Dist: packaging +Provides-Extra: test +Requires-Dist: pytest !=7.0.0,!=7.0.1,>=6 ; extra == 'test' +Requires-Dist: pytest-cov >=3.0.0 ; extra == 'test' +Requires-Dist: pytest-qt ; extra == 'test' + +# QtPy: Abstraction layer for PyQt5/PySide2/PyQt6/PySide6 + +[![license](https://img.shields.io/pypi/l/qtpy.svg)](./LICENSE) +[![pypi version](https://img.shields.io/pypi/v/qtpy.svg)](https://pypi.org/project/QtPy/) +[![conda 
version](https://img.shields.io/conda/vn/conda-forge/qtpy.svg)](https://www.anaconda.com/download/) +[![download count](https://img.shields.io/conda/dn/conda-forge/qtpy.svg)](https://www.anaconda.com/download/) +[![OpenCollective Backers](https://opencollective.com/spyder/backers/badge.svg?color=blue)](#backers) +[![Join the chat at https://gitter.im/spyder-ide/public](https://badges.gitter.im/spyder-ide/spyder.svg)](https://gitter.im/spyder-ide/public)
+[![PyPI status](https://img.shields.io/pypi/status/qtpy.svg)](https://github.com/spyder-ide/qtpy) +[![Github build status](https://github.com/spyder-ide/qtpy/workflows/Tests/badge.svg)](https://github.com/spyder-ide/qtpy/actions) +[![Coverage Status](https://coveralls.io/repos/github/spyder-ide/qtpy/badge.svg?branch=master)](https://coveralls.io/github/spyder-ide/qtpy?branch=master) + +*Copyright © 2009– The Spyder Development Team* + + +## Description + +**QtPy** is a small abstraction layer that lets you +write applications using a single API call to either PyQt or PySide. + +It provides support for PyQt5, PySide2, PyQt6 and PySide6 using the Qt5 layout +(where the QtGui module has been split into QtGui and QtWidgets). + +Basically, you can write your code as if you were using PyQt or PySide directly, +but import Qt modules from `qtpy` instead of `PyQt5`, `PySide2`, `PyQt6` or `PySide6`. + +Accordingly, when porting code between different Qt bindings (PyQt vs PySide) or Qt versions (Qt5 vs Qt6), QtPy makes this much more painless, and allows you to easily and incrementally transition between them. QtPy handles incompatibilities and differences between bindings or Qt versions for you while keeping your project running, so you can focus more on your own code and less on keeping track of supporting every Qt version and binding. Furthermore, when you do want to upgrade or support new bindings, it allows you to update your project module by module rather than all at once. You can check out examples of this approach in projects using QtPy, like [git-cola](https://github.com/git-cola/git-cola/issues/232). 
+ +### Attribution and acknowledgments + +This project is based on the [pyqode.qt](https://github.com/pyQode/pyqode.qt) +project and the [spyderlib.qt](https://github.com/spyder-ide/spyder/tree/2.3/spyderlib/qt) +module from the [Spyder](https://github.com/spyder-ide/spyder) project, and +also includes contributions adapted from +[qt-helpers](https://github.com/glue-viz/qt-helpers), developed as part of the +[glue](http://glueviz.org) project. + +Unlike `pyqode.qt` this is not a namespace package, so it is not tied +to a particular project or namespace. + + +### License + +This project is released under the [MIT license](LICENSE.txt). + + +### Requirements + +You need PyQt5, PySide2, PyQt6 or PySide6 installed in your system to make use +of QtPy. If several of these packages are found, PyQt5 is used by +default unless you set the `QT_API` environment variable. + +`QT_API` can take the following values: + +* `pyqt5` (to use PyQt5). +* `pyside2` (to use PySide2). +* `pyqt6` (to use PyQt6). +* `pyside6` (to use PySide6). + + +### Module aliases and constants + +* `QtCore.pyqtSignal`, `QtCore.pyqtSlot` and `QtCore.pyqtProperty` (available on PyQt5/6) are instead exposed as `QtCore.Signal`, `QtCore.Slot` and `QtCore.Property`, respectively, following the Qt5 module layout. + +* The Qt version being used can be checked with `QtCore.__version__` (instead of `QtCore.QT_VERSION_STR`) as well as from `qtpy.QT_VERSION`. + +* For PyQt6 enums, unscoped enum access was added by promoting the enums of the `QtCore`, `QtGui`, `QtTest` and `QtWidgets` modules. + +* Compatibility is added between the `QtGui` and `QtOpenGL` modules for the `QOpenGL*` classes. + +* To check the current binding version, you can use `qtpy.PYSIDE_VERSION` for PySide2/6 and `qtpy.PYQT_VERSION` for PyQt5/6. If the respective binding is not being used, the value of its attribute will be `None`. 
+ +* To check the current selected binding, you can use `qtpy.API_NAME` + +* There are boolean values to check if Qt5/6, PyQt5/6 or PySide2/6 are being used: `qtpy.QT5`, `qtpy.QT6`, `qtpy.PYQT5`, `qtpy.PYQT6`, `qtpy.PYSIDE2` and `qtpy.PYSIDE6`. `True` if currently being used, `False` otherwise. + +#### Compat module + +In the `qtpy.compat` module, you can find wrappers for `QFileDialog` static methods and SIP/Shiboken functions, such as: + +* `QFileDialog.getExistingDirectory` wrapped with `qtpy.compat.getexistingdirectory` + +* `QFileDialog.getOpenFileName` wrapped with `qtpy.compat.getopenfilename` + +* `QFileDialog.getOpenFileNames` wrapped with `qtpy.compat.getopenfilenames` + +* `QFileDialog.getSaveFileName` wrapped with `qtpy.compat.getsavefilename` + +* `sip.isdeleted` and `shiboken.isValid` wrapped with `qtpy.compat.isalive` + + +### Installation + +```bash +pip install qtpy +``` + +or + +```bash +conda install qtpy +``` + + +### Type checker integration + +Type checkers have no knowledge of installed packages, so these tools require +additional configuration. + +A Command Line Interface (CLI) is offered to help with usage of QtPy (to get MyPy +and Pyright/Pylance args/configurations). + +#### Mypy + +The `mypy-args` command helps you to generate command line arguments for Mypy +that will enable it to process the QtPy source files with the same API +as QtPy itself would have selected. + +If you run + +```bash +qtpy mypy-args +``` + +QtPy will output a string of Mypy CLI args that will reflect the currently +selected Qt API. 
+For example, in an environment where PyQt5 is installed and selected +(or the default fallback, if no binding can be found in the environment), +this would output the following: + +```text +--always-true=PYQT5 --always-false=PYSIDE2 --always-false=PYQT6 --always-false=PYSIDE6 +``` + +Using Bash or a similar shell, this can be injected into +the Mypy command line invocation as follows: + +```bash +mypy --package mypackage $(qtpy mypy-args) +``` + +#### Pyright/Pylance + +In the case of Pyright, instead of runtime arguments, it is required to create a +config file for the project, called `pyrightconfig.json` or a `pyright` section +in `pyproject.toml`. See [here](https://github.com/microsoft/pyright/blob/main/docs/configuration.md) +for reference. In order to set this configuration, QtPy offers the `pyright-config` +command for guidance. + +If you run + +```bash +qtpy pyright-config +``` + +you will get the necessary configs to be included in your project files. If you don't +have them, it is recommended to create the latter. For example, in an environment where PyQt5 +is installed and selected (or the default fallback, if no binding can be found in the +environment), this would output the following: + +```text +pyrightconfig.json: +{"defineConstant": {"PYQT5": true, "PYSIDE2": false, "PYQT6": false, "PYSIDE6": false}} + +pyproject.toml: +[tool.pyright.defineConstant] +PYQT5 = true +PYSIDE2 = false +PYQT6 = false +PYSIDE6 = false +``` + +**Note**: These configurations are necessary for the correct usage of the default VSCode's type +checking feature while using QtPy in your source code. + + +## Testing matrix + +Currently, QtPy runs tests for different bindings on Linux, Windows and macOS, using +Python 3.7 and 3.11, and installing those bindings with `conda` and `pip`. For the +PyQt bindings, we also check the installation of extra packages via `pip`. 
+ +Following this, the current test matrix looks something like this: + +| | Python | 3.7 | | 3.11 | | +|---------|-----------------|--------------------------------------------|------|--------------------|----------------------------| +| OS | Binding / manager | conda | pip | conda | pip | +| Linux | PyQt5 | 5.12 | 5.15 | 5.15 | 5.15 (with extras) | +| | PyQt6 | skip (unavailable) | 6.3 | skip (unavailable) | 6.5 (with extras) | +| | PySide2 | 5.13 | 5.12 | 5.15 | skip (no wheels available) | +| | PySide6 | 6.4 | 6.3 | 6.5 | 6.5 | +| Windows | PyQt5 | 5.9 | 5.15 | 5.15 | 5.15 (with extras) | +| | PyQt6 | skip (unavailable) | 6.2 | skip (unavailable) | 6.5 (with extras) | +| | PySide2 | 5.13 | 5.12 | 5.15 | skip (no wheels available) | +| | PySide6 | skip (test hang with 6.4. 6.5 unavailable) | 6.2 | 6.5 | 6.5 | +| MacOS | PyQt5 | 5.12 | 5.15 | 5.15 | 5.15 (with extras) | +| | PyQt6 | skip (unavailable) | 6.3 | skip (unavailable) | 6.5 (with extras) | +| | PySide2 | 5.13 | 5.12 | 5.15 | skip (no wheels available) | +| | PySide6 | 6.4 | 6.3 | 6.5 | 6.5 | + +**Note**: The mentioned extra packages for the PyQt bindings are the following: + +* `PyQt3D` and `PyQt6-3D` +* `PyQtChart` and `PyQt6-Charts` +* `PyQtDataVisualization` and `PyQt6-DataVisualization` +* `PyQtNetworkAuth` and `PyQt6-NetworkAuth` +* `PyQtPurchasing` +* `PyQtWebEngine` and `PyQt6-WebEngine` +* `QScintilla` and `PyQt6-QScintilla` + +## Contributing + +Everyone is welcome to contribute! See our [Contributing guide](CONTRIBUTING.md) for more details. 
+ + +## Sponsors + +QtPy is funded thanks to the generous support of + + +[![Quansight](https://user-images.githubusercontent.com/16781833/142477716-53152d43-99a0-470c-a70b-c04bbfa97dd4.png)](https://www.quansight.com/)[![Numfocus](https://i2.wp.com/numfocus.org/wp-content/uploads/2017/07/NumFocus_LRG.png?fit=320%2C148&ssl=1)](https://numfocus.org/) + +and the donations we have received from our users around the world through [Open Collective](https://opencollective.com/spyder/): + +[![Sponsors](https://opencollective.com/spyder/sponsors.svg)](https://opencollective.com/spyder#support) diff --git a/python3.10libs/QtPy-2.4.1.dist-info/RECORD b/python3.10libs/QtPy-2.4.1.dist-info/RECORD new file mode 100644 index 0000000..64db7ae --- /dev/null +++ b/python3.10libs/QtPy-2.4.1.dist-info/RECORD @@ -0,0 +1,287 @@ +../../../bin/qtpy,sha256=EnXLa38f1oucNIn4c1pLBQvQshBMeRbvlwH1cU8Lx-Y,260 +QtPy-2.4.1.dist-info/AUTHORS.md,sha256=RhSdGR1L5Pz5RO00m0bcm_gIm9Kvp57AiTY9Y78a7HQ,816 +QtPy-2.4.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +QtPy-2.4.1.dist-info/LICENSE.txt,sha256=WexCJb04DjSaguZIJDf_lHXuscLmdqLRGFu1MxXUW_k,1113 +QtPy-2.4.1.dist-info/METADATA,sha256=xJGie-i45MwmIGKKA_BsLTD9VsOwVqa8ENyawklGI7A,12616 +QtPy-2.4.1.dist-info/RECORD,, +QtPy-2.4.1.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +QtPy-2.4.1.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92 +QtPy-2.4.1.dist-info/entry_points.txt,sha256=i9L-wdRb3aEbLLbb0V4yV3TMOrbHefwQmZPtQD1QwNY,44 +QtPy-2.4.1.dist-info/top_level.txt,sha256=P_I2N1064Bw78JAT09wjPsZSW63PhTkGce3YwzuqZEM,5 +qtpy/Qsci.py,sha256=SFP-AbB1gfDMo7Ak6aL8IpRa822Y7MMr-VCRplD-HaY,993 +qtpy/Qt3DAnimation.py,sha256=fm-X5lXqVJzLNTqvrXWQNfH03eF00wW86gOV1IFRcjg,1407 +qtpy/Qt3DCore.py,sha256=SY1iXGBtrG0D3bw1bS7DhgSOchyWCMKEEC40FgnxRmA,1362 +qtpy/Qt3DExtras.py,sha256=RW8gVTi12SOZySFU7aww6pKTH9TZOrU8XzGgM2kbnZg,1380 +qtpy/Qt3DInput.py,sha256=Qm7nvZvUnztWyrD8R0G3USAC8ZjdiZIV2NNbmIfD12I,1371 
+qtpy/Qt3DLogic.py,sha256=HfC2rmHA6DMJmPVdig9L0f1H6AfTQA2Ow94cLuMB4Ks,1371 +qtpy/Qt3DRender.py,sha256=XwAZCPEq4YiL83BLSsgu9Q6WLjK3UiUz2_ehaIQ4G9U,1380 +qtpy/QtAxContainer.py,sha256=old8cahiot2FhERXe7CQR1OjW-N4ufzszQG4Npc2KXM,630 +qtpy/QtBluetooth.py,sha256=u1F1URx85Gab-IqUBk-XoeW99sOAjD6ATeFmShRXSm0,659 +qtpy/QtCharts.py,sha256=USPwj9wbqlTVsG_UIqGhSEjVAK2_pVDgb8xoZOFnlMY,1330 +qtpy/QtConcurrent.py,sha256=LyBpnonRoeSSbcgzAN6KYRvv_vB2a3BpkclV5qI6yzo,626 +qtpy/QtCore.py,sha256=9ka0FsiAjYy18raX8c1IlAFbCH1BMq4EfnaeeiXiJfo,6474 +qtpy/QtDBus.py,sha256=eWYrnWxqD1q4uA69yiQw39sZZiaTbCHYEwn1-TldyuQ,768 +qtpy/QtDataVisualization.py,sha256=jZz5RtqwJYCqX44IT9l3EuEuFEEon5eHvPbVCZPLTZk,1294 +qtpy/QtDesigner.py,sha256=tuAj2ly_WSPaJ6lLhHxSswXX6WEFzrDGJqE6DeArPg8,646 +qtpy/QtGui.py,sha256=8zf_BbR5bfKCMrI6KfI4weKF3pOc6sv1aJGqx2pRHKw,8624 +qtpy/QtHelp.py,sha256=jRtW0PKfSdjIn2F0zf7zhJDh5ALLvUTFrp31NY0RGd0,537 +qtpy/QtLocation.py,sha256=wVcTvhJJQju3eAwIcijex5hd3Uh9EoDICiO4_LTrcvE,676 +qtpy/QtMacExtras.py,sha256=Trp7FkMkbwtMU16vHkOLfh6-JpAFhkGVn9CkXl-g5Y8,868 +qtpy/QtMultimedia.py,sha256=ORiMN8H57lSWl_9r5--4XMeIXRUKRypmvAq4d36ADg4,590 +qtpy/QtMultimediaWidgets.py,sha256=QJgnatAwIaQVHLaQzZ2OUeD5lvWb7ByV1CTtqWjOlGI,625 +qtpy/QtNetwork.py,sha256=uYxl7GrM_nq89HeQfKXoA2w2QvQwHJDFwDqe0gm3bWU,616 +qtpy/QtNetworkAuth.py,sha256=fa6iCB5JVFXgpuyXrzTg5xRvWczCJxzzyJPjI0oTMSU,1096 +qtpy/QtNfc.py,sha256=ssXZhTwu--SG_8DIxrRzwG_C7U6vVX8MnDi9sCmztIg,629 +qtpy/QtOpenGL.py,sha256=blhkA-QvjWVT2IDe9g5U1pUlhAV74UVan05AqVn9tMM,2032 +qtpy/QtOpenGLWidgets.py,sha256=c7Kj7MWL7SJEXWH39kd0uSgLvTLZFbJDzXzvHGEiPEY,701 +qtpy/QtPdf.py,sha256=8YJ25jTBFaHottFSWvM08XRILwLVaqIXVK00jOgYaJA,725 +qtpy/QtPdfWidgets.py,sha256=xM7-07kbu2V9QzuNmlWkT3OCGBQ57TWZPqDiFtipRLI,760 +qtpy/QtPositioning.py,sha256=wlyoKl90i-lD4B0Tbag0t1kE8xjj_37qAwGplMoqQEw,581 +qtpy/QtPrintSupport.py,sha256=BRhsP1Qr69LUx6Dh-Na468RBPNt2kQDlreh3alckbAM,1181 +qtpy/QtPurchasing.py,sha256=JZh3T6kSR6d8MPb2eEXtoIHzwegkyTt2eoGd5GuB7aw,808 
+qtpy/QtQml.py,sha256=jhM9kdgMT9jrlQGX2NpTEM6rKfjcxDiAC8zPI73FDy4,555 +qtpy/QtQuick.py,sha256=kBQDiOQmCgwsXIInj9wmMK-MwQix1i9KOMDONDPuo_c,565 +qtpy/QtQuick3D.py,sha256=P-eJMwz9k4CnMRtz3DvddheUNrMwk_IgUdU-I3ViyVA,649 +qtpy/QtQuickControls2.py,sha256=-7yvdZbibrx9wJrLsEyy35D7Y46YZpl6wk3E2NkPEs0,642 +qtpy/QtQuickWidgets.py,sha256=eUoQdOWv5qpS5_ruKuTWeTLflvTxmbtsZX32QGB4RN8,600 +qtpy/QtRemoteObjects.py,sha256=k1m1tXXf8KTJyAbhD5Yn3KDqYAybaNv9rt8NFo7ajm4,605 +qtpy/QtScxml.py,sha256=rUh0ux_CVbXnZMp7M7gp2WWGIYjOA08Um8r5fXgrmPI,606 +qtpy/QtSensors.py,sha256=OXldt4zne5aQkFQguPXM2NRInOVNY-ty7m67BY8ys3k,575 +qtpy/QtSerialPort.py,sha256=HuXULHflOzjGDxZVCe7xt3NeNV7iQ4bQNL5RjZdeBEo,623 +qtpy/QtSql.py,sha256=j49PlFTqShGdkdJXJM4Ouip_hT0i94BNr3nFucTlNYQ,1122 +qtpy/QtStateMachine.py,sha256=puCh5SvQnpmkZmiHocg26dzNZL9OJ-JrLgM8cgQgaeA,590 +qtpy/QtSvg.py,sha256=x-786tv5dgNGsWmCrKExBXsdK0y5v_tL704zY0ofP4Y,555 +qtpy/QtSvgWidgets.py,sha256=FuqVxHGQiFx3J-tQhn0zytdZDdf4TJ8iXpG2Ng_Xw8I,686 +qtpy/QtTest.py,sha256=IE14FyJ5wXLGWE_PfDYXptDR_yhjFsBkrNS_HpLcMxI,771 +qtpy/QtTextToSpeech.py,sha256=42zk_a1ZNM0196h1rnHHuwAxGKuTQmvw0aIThO6lEgs,696 +qtpy/QtUiTools.py,sha256=4evPwC99gIpGLPK0NxbHwpnJnAElVoSbclVACBGpIgI,614 +qtpy/QtWebChannel.py,sha256=QvsGHoRol-_CLGmVRBSMROPbuxYP7cnktW8fFsDFe2U,590 +qtpy/QtWebEngine.py,sha256=WrePBNXgS3QfDTZm52leUlgGwgyTxkcX4GUeZ6r8S1A,946 +qtpy/QtWebEngineCore.py,sha256=eHSxe7IRpmX8ukJ3mb4VVaa_I6vX8hjxatyZ-FSdLP4,1053 +qtpy/QtWebEngineQuick.py,sha256=ae--mnYknnO9hOj7rDpveJnffcMowkUaPNy3y1d8XXA,937 +qtpy/QtWebEngineWidgets.py,sha256=d7kff4ah0_5mCWNCaMGZmYSCr5JXHp_4xJknlmyPYZs,2044 +qtpy/QtWebSockets.py,sha256=6a1A2rvQZAZlKug6e0obPnmleJepu1Ubg0bEvVVcs0w,590 +qtpy/QtWidgets.py,sha256=wxJSKk-i8b1iXAgcG5Jiu2uJnQ6U7DEwe3_visV3eSE,7068 +qtpy/QtWinExtras.py,sha256=rGAPMlNG-d1SxjUW0wHz3Z9zPODSZYOYk_fdbtKcX_Q,828 +qtpy/QtX11Extras.py,sha256=WOuSmwAHPShgYHQ85O_1vguyJtCb4LoD8XGtnffkqxA,826 +qtpy/QtXml.py,sha256=FCdmIQB_DCdI_7XkpG4vhlSmXJG47Pt4atRj6b21i_I,555 
+qtpy/QtXmlPatterns.py,sha256=FSRvBZ0BKxEjP6Kkv5nzO0GYKXO6bPVajiJxloyfQVM,691 +qtpy/__init__.py,sha256=ke665Hqdc2qFv8aqUoDwfT4JUbv7pak7YT_RLXEm0v4,10356 +qtpy/__main__.py,sha256=Qjix6fqUHeUIu7k9DMX2B59j2U3L9nyFYXdu2c9Nlo8,461 +qtpy/__pycache__/Qsci.cpython-310.pyc,, +qtpy/__pycache__/Qt3DAnimation.cpython-310.pyc,, +qtpy/__pycache__/Qt3DCore.cpython-310.pyc,, +qtpy/__pycache__/Qt3DExtras.cpython-310.pyc,, +qtpy/__pycache__/Qt3DInput.cpython-310.pyc,, +qtpy/__pycache__/Qt3DLogic.cpython-310.pyc,, +qtpy/__pycache__/Qt3DRender.cpython-310.pyc,, +qtpy/__pycache__/QtAxContainer.cpython-310.pyc,, +qtpy/__pycache__/QtBluetooth.cpython-310.pyc,, +qtpy/__pycache__/QtCharts.cpython-310.pyc,, +qtpy/__pycache__/QtConcurrent.cpython-310.pyc,, +qtpy/__pycache__/QtCore.cpython-310.pyc,, +qtpy/__pycache__/QtDBus.cpython-310.pyc,, +qtpy/__pycache__/QtDataVisualization.cpython-310.pyc,, +qtpy/__pycache__/QtDesigner.cpython-310.pyc,, +qtpy/__pycache__/QtGui.cpython-310.pyc,, +qtpy/__pycache__/QtHelp.cpython-310.pyc,, +qtpy/__pycache__/QtLocation.cpython-310.pyc,, +qtpy/__pycache__/QtMacExtras.cpython-310.pyc,, +qtpy/__pycache__/QtMultimedia.cpython-310.pyc,, +qtpy/__pycache__/QtMultimediaWidgets.cpython-310.pyc,, +qtpy/__pycache__/QtNetwork.cpython-310.pyc,, +qtpy/__pycache__/QtNetworkAuth.cpython-310.pyc,, +qtpy/__pycache__/QtNfc.cpython-310.pyc,, +qtpy/__pycache__/QtOpenGL.cpython-310.pyc,, +qtpy/__pycache__/QtOpenGLWidgets.cpython-310.pyc,, +qtpy/__pycache__/QtPdf.cpython-310.pyc,, +qtpy/__pycache__/QtPdfWidgets.cpython-310.pyc,, +qtpy/__pycache__/QtPositioning.cpython-310.pyc,, +qtpy/__pycache__/QtPrintSupport.cpython-310.pyc,, +qtpy/__pycache__/QtPurchasing.cpython-310.pyc,, +qtpy/__pycache__/QtQml.cpython-310.pyc,, +qtpy/__pycache__/QtQuick.cpython-310.pyc,, +qtpy/__pycache__/QtQuick3D.cpython-310.pyc,, +qtpy/__pycache__/QtQuickControls2.cpython-310.pyc,, +qtpy/__pycache__/QtQuickWidgets.cpython-310.pyc,, +qtpy/__pycache__/QtRemoteObjects.cpython-310.pyc,, 
+qtpy/__pycache__/QtScxml.cpython-310.pyc,, +qtpy/__pycache__/QtSensors.cpython-310.pyc,, +qtpy/__pycache__/QtSerialPort.cpython-310.pyc,, +qtpy/__pycache__/QtSql.cpython-310.pyc,, +qtpy/__pycache__/QtStateMachine.cpython-310.pyc,, +qtpy/__pycache__/QtSvg.cpython-310.pyc,, +qtpy/__pycache__/QtSvgWidgets.cpython-310.pyc,, +qtpy/__pycache__/QtTest.cpython-310.pyc,, +qtpy/__pycache__/QtTextToSpeech.cpython-310.pyc,, +qtpy/__pycache__/QtUiTools.cpython-310.pyc,, +qtpy/__pycache__/QtWebChannel.cpython-310.pyc,, +qtpy/__pycache__/QtWebEngine.cpython-310.pyc,, +qtpy/__pycache__/QtWebEngineCore.cpython-310.pyc,, +qtpy/__pycache__/QtWebEngineQuick.cpython-310.pyc,, +qtpy/__pycache__/QtWebEngineWidgets.cpython-310.pyc,, +qtpy/__pycache__/QtWebSockets.cpython-310.pyc,, +qtpy/__pycache__/QtWidgets.cpython-310.pyc,, +qtpy/__pycache__/QtWinExtras.cpython-310.pyc,, +qtpy/__pycache__/QtX11Extras.cpython-310.pyc,, +qtpy/__pycache__/QtXml.cpython-310.pyc,, +qtpy/__pycache__/QtXmlPatterns.cpython-310.pyc,, +qtpy/__pycache__/__init__.cpython-310.pyc,, +qtpy/__pycache__/__main__.cpython-310.pyc,, +qtpy/__pycache__/_utils.cpython-310.pyc,, +qtpy/__pycache__/cli.cpython-310.pyc,, +qtpy/__pycache__/compat.cpython-310.pyc,, +qtpy/__pycache__/enums_compat.cpython-310.pyc,, +qtpy/__pycache__/shiboken.cpython-310.pyc,, +qtpy/__pycache__/sip.cpython-310.pyc,, +qtpy/__pycache__/uic.cpython-310.pyc,, +qtpy/_utils.py,sha256=l6n0SBRQLi35r3mT42yugQlEzh5aFSdgQ47RafbfEsY,5187 +qtpy/cli.py,sha256=ohYT7O0KaO08uxQBNi3PnZLPvVMBd5laNYV2paLVedk,4945 +qtpy/compat.py,sha256=RJDkzd3HAno2i7KFeyVLOk7l-5XO0dOuOShW2xnjr2w,5624 +qtpy/enums_compat.py,sha256=WXnb5V17ATiBYWt15f5Y2PWfm1AdhFSmUasvlpWqge0,1454 +qtpy/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +qtpy/shiboken.py,sha256=Qv4JJpNB7yxkAYdX_X5pgywXWMupU1VNWqUKetDNSdg,584 +qtpy/sip.py,sha256=AjUR3ilW11aNpcjYhNWKfXwtgXqMxigl-L8XWk3bs4s,574 +qtpy/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 
+qtpy/tests/__pycache__/__init__.cpython-310.pyc,, +qtpy/tests/__pycache__/conftest.cpython-310.pyc,, +qtpy/tests/__pycache__/test_cli.cpython-310.pyc,, +qtpy/tests/__pycache__/test_compat.cpython-310.pyc,, +qtpy/tests/__pycache__/test_macos_checks.cpython-310.pyc,, +qtpy/tests/__pycache__/test_main.cpython-310.pyc,, +qtpy/tests/__pycache__/test_missing_optional_deps.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qdesktopservice_split.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qsci.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qt3danimation.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qt3dcore.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qt3dextras.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qt3dinput.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qt3dlogic.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qt3drender.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtaxcontainer.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtbluetooth.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtcharts.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtconcurrent.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtcore.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtdatavisualization.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtdbus.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtdesigner.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtgui.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qthelp.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtlocation.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtmacextras.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtmultimedia.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtmultimediawidgets.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtnetwork.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtnetworkauth.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtopengl.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtopenglwidgets.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtpdf.cpython-310.pyc,, 
+qtpy/tests/__pycache__/test_qtpdfwidgets.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtpositioning.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtprintsupport.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtpurchasing.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtqml.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtquick.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtquick3d.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtquickcontrols2.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtquickwidgets.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtremoteobjects.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtscxml.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtsensors.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtserialport.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtsql.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtstatemachine.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtsvg.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtsvgwidgets.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qttest.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qttexttospeech.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtuitools.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtwebchannel.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtwebenginecore.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtwebenginequick.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtwebenginewidgets.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtwebsockets.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtwidgets.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtwinextras.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtx11extras.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtxml.cpython-310.pyc,, +qtpy/tests/__pycache__/test_qtxmlpatterns.cpython-310.pyc,, +qtpy/tests/__pycache__/test_shiboken.cpython-310.pyc,, +qtpy/tests/__pycache__/test_sip.cpython-310.pyc,, +qtpy/tests/__pycache__/test_uic.cpython-310.pyc,, +qtpy/tests/__pycache__/utils.cpython-310.pyc,, 
+qtpy/tests/conftest.py,sha256=RospoL0sZ7uX1OGwsW0j4u-2YUuifg9d_zX6GbguJtw,1982 +qtpy/tests/optional_deps/__init__.py,sha256=6Pwy6j9NfOnXrEqUvoMvFu9q9ye5pIkBKVGPIvwmMjw,755 +qtpy/tests/optional_deps/__pycache__/__init__.cpython-310.pyc,, +qtpy/tests/optional_deps/__pycache__/optional_dep.cpython-310.pyc,, +qtpy/tests/optional_deps/optional_dep.py,sha256=HZFdG82EfcftfWuazvRjwg4KK9casodupFtGJGXyerU,98 +qtpy/tests/test.ui,sha256=rD0hQHmHcyBMVdLnfLW3BjrIXIBHNktgbR_2BKIGWFg,882 +qtpy/tests/test_cli.py,sha256=2G7q9nlMO4WfdtObr85S0XIKmQYnGXhV28wHhVInebs,3983 +qtpy/tests/test_compat.py,sha256=P_OBOj2cz14JE9xhu_nQXfYG0fV9ff1gkw2TJdMJp_s,649 +qtpy/tests/test_custom.ui,sha256=A3SQiTsjVB9aFPZ1FtyHOeX_cymvZ8ILbD-vbLRuCVc,1076 +qtpy/tests/test_macos_checks.py,sha256=OJQmmQM0qxDG1CfIwHQymO0AEStsZ8-MleQR25VyLVo,2935 +qtpy/tests/test_main.py,sha256=8iodVt0jcE_t2pAKtBH-muY6B5DTei7xVSp6XjfiIy8,3791 +qtpy/tests/test_missing_optional_deps.py,sha256=v-rZK4VpEsBVMFvgs62YI_p4UtILLS8q9-9HPMbdXRg,635 +qtpy/tests/test_qdesktopservice_split.py,sha256=9GJErD1EjIbiqsJW3r-B0aVzLFesC_pXnImcWItGGMA,575 +qtpy/tests/test_qsci.py,sha256=TSztZozQdWUrcEyxW2_ukBOEzSsx3CORz9-g41cy5l4,2629 +qtpy/tests/test_qt3danimation.py,sha256=zXT2Hsfa9Yf63DCzoJDybwtS09r5K6GUA1PsntSTfWA,991 +qtpy/tests/test_qt3dcore.py,sha256=UCQ0Ebhj6k9U_PBCjJh_aXzYcKTCLAgnsIsch582-Cg,2034 +qtpy/tests/test_qt3dextras.py,sha256=zZL13La7wY50QIBoQ_dmRaQAulduUOBOmqX-vQY_nVk,2064 +qtpy/tests/test_qt3dinput.py,sha256=EMzowzEeXIseyVNXuXkmJhW_yuIpMVhtdU6PA3892QI,1182 +qtpy/tests/test_qt3dlogic.py,sha256=Cgpb2Czej-wy8OC8tnIGvjzRzaGono_k16rx4Fvky4s,229 +qtpy/tests/test_qt3drender.py,sha256=CZK36OBlDQQ1GciEN-5QnucqKhueDs-XquHgp9HDtEY,5686 +qtpy/tests/test_qtaxcontainer.py,sha256=wrrnjYOhTQiSKctr9Dl5XibLs3U5oRfOb2XJlL5c_B8,247 +qtpy/tests/test_qtbluetooth.py,sha256=TplntM4wo_hpiQOKS8Y4b54HCFbApy3xGVIF6mh7p_c,521 +qtpy/tests/test_qtcharts.py,sha256=Sts9UkwMuPsE0o4hP84JDZXsrToXBFk42oJl3ehOTms,358 
+qtpy/tests/test_qtconcurrent.py,sha256=s8UypAbKFs8cCEOMU9ItPtD900fBnzDZbFMFK5vCetY,566 +qtpy/tests/test_qtcore.py,sha256=Inu_RPQ_H25ug4yo54IgrKav1-x-33D5A5pmKuQHOds,6752 +qtpy/tests/test_qtdatavisualization.py,sha256=rGehCaCg2c-GhxC6247zth0ZSJJXK7rwDtzoR2RpBdI,4885 +qtpy/tests/test_qtdbus.py,sha256=vC-IajtPNJvU8G4A8334MrjvfOugQR4SrGzw6_aC2lo,319 +qtpy/tests/test_qtdesigner.py,sha256=2eoubExPv2thuGino7Soe0UOljuTrY3iaegiNq-xmhs,1437 +qtpy/tests/test_qtgui.py,sha256=5q_suo5-7XA4QW7TV3ZW48XGCGqx67W2kBisaMRCKfs,6739 +qtpy/tests/test_qthelp.py,sha256=_maYHpvw3_6a0trDYe7h4iqvvGhSLAFg7IU0OPiJPHw,652 +qtpy/tests/test_qtlocation.py,sha256=-Uzp36XmCe99k2BCC20zvQdZrkvGDni3zcmZKIrNbGk,2256 +qtpy/tests/test_qtmacextras.py,sha256=wc7q_E_cWtPs4TglrUHdoyAs-ysfGm-kx23LxyvnZUY,622 +qtpy/tests/test_qtmultimedia.py,sha256=sPRt-l4X_kAzrRcj-1rHeaduA4a50N-uprGTCqkqfOA,456 +qtpy/tests/test_qtmultimediawidgets.py,sha256=SsNNAoYSZ1AW_7SqcHR_txWuMckntyoZOgJFLEavVtw,473 +qtpy/tests/test_qtnetwork.py,sha256=70TCShIeRO_zxGLXhS53TcrEybGpkmt1rAxBdZ2wn44,1799 +qtpy/tests/test_qtnetworkauth.py,sha256=qObJEP2HU7Pzebx29TqiN64tyfAAGtkaenE-h6p6NdU,592 +qtpy/tests/test_qtopengl.py,sha256=UIaczTS3ZhJxuYJl6TWz_LaOcYbi9yjyAqOBjyMdYBQ,1011 +qtpy/tests/test_qtopenglwidgets.py,sha256=EJr8DnxvObqGjgLIhj-ZRW-xlxKOFo08jktcxsEiS50,297 +qtpy/tests/test_qtpdf.py,sha256=zt84FP5nfqOC2EmaVHElQ5og2DTJVSf5VQ1Ebt7GGaI,244 +qtpy/tests/test_qtpdfwidgets.py,sha256=ZZI9IjKDfC5GHTts1l4uENFFMxy46oeBkpvxuijS7w0,194 +qtpy/tests/test_qtpositioning.py,sha256=oPyLhWER81lmSxR7kCMialv9eHvvjnWiaBJ4cQBNiIo,1399 +qtpy/tests/test_qtprintsupport.py,sha256=HWubiDiGXLM2dtvPY9shb0StxWxAk93M0TemM7QUreg,1154 +qtpy/tests/test_qtpurchasing.py,sha256=-jHqjXd8JJSuUXNPrlkn05QxuYFpsuy_sPVuHaR-2iQ,301 +qtpy/tests/test_qtqml.py,sha256=F7wHe85z7DdjuESY82KKaC4YKnhvrBPbme9cVY9KjWQ,1291 +qtpy/tests/test_qtquick.py,sha256=6xtm6mRo-iVygIG2tNrT9L4smixm1ywde37dE-yeI2g,2053 +qtpy/tests/test_qtquick3d.py,sha256=WG1iU839xfqczxAHPzcn0ovSI_AsXioKgSoHpLKrOY8,277 
+qtpy/tests/test_qtquickcontrols2.py,sha256=t43bAX5sZiKK2j2mOYnmQ_pygLWyIIi9Q6BTTyAcJaI,217 +qtpy/tests/test_qtquickwidgets.py,sha256=jrxC3uhp_6Ef0A7sUDYdwVEjPgR78oHBWeuM3iwvM7M,164 +qtpy/tests/test_qtremoteobjects.py,sha256=eGoegKzSP0CNAeOd8PFNpj6K8PL-kRr59NxYDuTkMO4,478 +qtpy/tests/test_qtscxml.py,sha256=Uclwx5MbqaAqYRVv-5S1947fHassE6-Ss05_QKdwf4Y,294 +qtpy/tests/test_qtsensors.py,sha256=XX52GxasupGrgV9Gq0pjtMNK0jJf7rQnEYWwycknThU,294 +qtpy/tests/test_qtserialport.py,sha256=PuhzQfTzLjcU0B0l7TqAjvF0_Gkkpz963OrpvFus4bw,334 +qtpy/tests/test_qtsql.py,sha256=laT-EM1S2e2R17qvd-sIDOzLNo9gpsZ9lNCXHVOhmxo,2443 +qtpy/tests/test_qtstatemachine.py,sha256=vDQIOSEIgiHh_qR9e9JulXZyh1aKdVPZThi_u81MJ9U,644 +qtpy/tests/test_qtsvg.py,sha256=U6vSGNhwFEbcZ6Qowfi74CeUFdYbcf2V8EGEnPZqw7c,364 +qtpy/tests/test_qtsvgwidgets.py,sha256=tlotjZpWfeXdcKt_d2jh4YZzF-HAlj3Rm6KQsVumZXs,249 +qtpy/tests/test_qttest.py,sha256=HtsMH71EXeL9KcJym9sNjnBp-CS2zhVJe_PuAtB77zU,835 +qtpy/tests/test_qttexttospeech.py,sha256=6_GsjAdxTqk9vX8Fki5ZlPyi1hHOQ0mkF8V97PPLnVg,591 +qtpy/tests/test_qtuitools.py,sha256=fMldfjrWNs0BTqqzh0_KqiPZGPY1la9Gl4m53oQ8G0w,180 +qtpy/tests/test_qtwebchannel.py,sha256=qA0QC8ylUspJ6x7XdRBVqzIyPXZDpfsKIjBbE4YMGaI,220 +qtpy/tests/test_qtwebenginecore.py,sha256=1tfAV93VMoUH2p5ghm1R0V8zyxvU4RO1OGXJepUECsE,222 +qtpy/tests/test_qtwebenginequick.py,sha256=dS8T3A2tLHYETrK8ASyaR9WyMkoaZ9-ljq86D-vs0wg,398 +qtpy/tests/test_qtwebenginewidgets.py,sha256=iaRQ_qT__X3-kJzTkBHEysS2abcccHTFfe-yLlpdOZ4,795 +qtpy/tests/test_qtwebsockets.py,sha256=WP_Vp92ktHSwSei9SGXWQi_K6P2-Ms2gF-FNqU0_dR8,377 +qtpy/tests/test_qtwidgets.py,sha256=-_DQ2dlpRGKzQeFpL75GxYodEbRuRZRDEu3l-wsgQ8I,9827 +qtpy/tests/test_qtwinextras.py,sha256=UOy5g0y_tUJV9NIMHeVf4jbpmKKmeZO8FntcYkHMqhk,1176 +qtpy/tests/test_qtx11extras.py,sha256=a_c8ssSS3o7YNI2F643Ogg6a69SP06IvqvdDdPY3V0g,245 +qtpy/tests/test_qtxml.py,sha256=oXQEBxuRRDTyr1LuhM-Gm6_brz9xRm-mA7somTTprMA,835 
+qtpy/tests/test_qtxmlpatterns.py,sha256=80bxjuKN0_S0yCtjSOfZd0rCeBv9F7zhbi08Xh3j0vM,1120 +qtpy/tests/test_shiboken.py,sha256=8M9AcjDhb--L7_0SrBqHpfCOwa91eqdoU75cPFuPCDs,340 +qtpy/tests/test_sip.py,sha256=THb-8-npVx0kVfu_ZqD637osk1q_ae7MkKXd8niQKd8,794 +qtpy/tests/test_uic.py,sha256=ceWc05CrOmHRNxBknISee_XV4oxH4Nnrt1ymtZAR3Fg,3654 +qtpy/tests/utils.py,sha256=ikYpL__NpZzIe8M4CPCUnjYw1I84ZBKz6IzgKGXn1e0,200 +qtpy/uic.py,sha256=8c9XqJ3pxW5TZ_RaEAHZmbO6I0vjsb9RkhpXEFkdw90,11647 diff --git a/python3.10libs/QtPy-2.4.1.dist-info/REQUESTED b/python3.10libs/QtPy-2.4.1.dist-info/REQUESTED new file mode 100644 index 0000000..e69de29 diff --git a/python3.10libs/QtPy-2.4.1.dist-info/WHEEL b/python3.10libs/QtPy-2.4.1.dist-info/WHEEL new file mode 100644 index 0000000..7e68873 --- /dev/null +++ b/python3.10libs/QtPy-2.4.1.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.41.2) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/python3.10libs/QtPy-2.4.1.dist-info/entry_points.txt b/python3.10libs/QtPy-2.4.1.dist-info/entry_points.txt new file mode 100644 index 0000000..1248032 --- /dev/null +++ b/python3.10libs/QtPy-2.4.1.dist-info/entry_points.txt @@ -0,0 +1,2 @@ +[console_scripts] +qtpy = qtpy.__main__:main diff --git a/python3.10libs/QtPy-2.4.1.dist-info/top_level.txt b/python3.10libs/QtPy-2.4.1.dist-info/top_level.txt new file mode 100644 index 0000000..086fa2e --- /dev/null +++ b/python3.10libs/QtPy-2.4.1.dist-info/top_level.txt @@ -0,0 +1 @@ +qtpy diff --git a/python3.10libs/__init__.py b/python3.10libs/__init__.py new file mode 100644 index 0000000..a207b60 --- /dev/null +++ b/python3.10libs/__init__.py @@ -0,0 +1,2 @@ +from playhouse import sqlite_ext +from peewee import * \ No newline at end of file diff --git a/python3.10libs/peewee-3.17.0.dist-info/INSTALLER b/python3.10libs/peewee-3.17.0.dist-info/INSTALLER new file mode 100644 index 0000000..a1b589e --- /dev/null +++ b/python3.10libs/peewee-3.17.0.dist-info/INSTALLER @@ -0,0 +1 @@ 
+pip diff --git a/python3.10libs/peewee-3.17.0.dist-info/LICENSE b/python3.10libs/peewee-3.17.0.dist-info/LICENSE new file mode 100644 index 0000000..c752ab3 --- /dev/null +++ b/python3.10libs/peewee-3.17.0.dist-info/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2010 Charles Leifer + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
diff --git a/python3.10libs/peewee-3.17.0.dist-info/METADATA b/python3.10libs/peewee-3.17.0.dist-info/METADATA new file mode 100644 index 0000000..3de1b98 --- /dev/null +++ b/python3.10libs/peewee-3.17.0.dist-info/METADATA @@ -0,0 +1,167 @@ +Metadata-Version: 2.1 +Name: peewee +Version: 3.17.0 +Summary: a little orm +Home-page: https://github.com/coleifer/peewee/ +Author: Charles Leifer +Author-email: coleifer@gmail.com +License: MIT License +Project-URL: Documentation, http://docs.peewee-orm.com +Project-URL: Source, https://github.com/coleifer/peewee +Platform: any +Classifier: Development Status :: 5 - Production/Stable +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: MIT License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 2 +Classifier: Programming Language :: Python :: 2.7 +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.4 +Classifier: Programming Language :: Python :: 3.5 +Classifier: Programming Language :: Python :: 3.6 +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Topic :: Database +Classifier: Topic :: Software Development :: Libraries :: Python Modules +License-File: LICENSE + +.. image:: https://media.charlesleifer.com/blog/photos/peewee3-logo.png + +peewee +====== + +Peewee is a simple and small ORM. It has few (but expressive) concepts, making it easy to learn and intuitive to use. + +* a small, expressive ORM +* python 2.7+ and 3.4+ +* supports sqlite, mysql, postgresql and cockroachdb +* tons of `extensions `_ + +New to peewee? 
These may help: + +* `Quickstart `_ +* `Example twitter app `_ +* `Using peewee interactively `_ +* `Models and fields `_ +* `Querying `_ +* `Relationships and joins `_ + +Examples +-------- + +Defining models is similar to Django or SQLAlchemy: + +.. code-block:: python + + from peewee import * + import datetime + + + db = SqliteDatabase('my_database.db') + + class BaseModel(Model): + class Meta: + database = db + + class User(BaseModel): + username = CharField(unique=True) + + class Tweet(BaseModel): + user = ForeignKeyField(User, backref='tweets') + message = TextField() + created_date = DateTimeField(default=datetime.datetime.now) + is_published = BooleanField(default=True) + +Connect to the database and create tables: + +.. code-block:: python + + db.connect() + db.create_tables([User, Tweet]) + +Create a few rows: + +.. code-block:: python + + charlie = User.create(username='charlie') + huey = User(username='huey') + huey.save() + + # No need to set `is_published` or `created_date` since they + # will just use the default values we specified. + Tweet.create(user=charlie, message='My first tweet') + +Queries are expressive and composable: + +.. code-block:: python + + # A simple query selecting a user. + User.get(User.username == 'charlie') + + # Get tweets created by one of several users. + usernames = ['charlie', 'huey', 'mickey'] + users = User.select().where(User.username.in_(usernames)) + tweets = Tweet.select().where(Tweet.user.in_(users)) + + # We could accomplish the same using a JOIN: + tweets = (Tweet + .select() + .join(User) + .where(User.username.in_(usernames))) + + # How many tweets were published today? + tweets_today = (Tweet + .select() + .where( + (Tweet.created_date >= datetime.date.today()) & + (Tweet.is_published == True)) + .count()) + + # Paginate the user table and show me page 3 (users 41-60). 
+ User.select().order_by(User.username).paginate(3, 20) + + # Order users by the number of tweets they've created: + tweet_ct = fn.Count(Tweet.id) + users = (User + .select(User, tweet_ct.alias('ct')) + .join(Tweet, JOIN.LEFT_OUTER) + .group_by(User) + .order_by(tweet_ct.desc())) + + # Do an atomic update (for illustrative purposes only, imagine a simple + # table for tracking a "count" associated with each URL). We don't want to + # naively get the save in two separate steps since this is prone to race + # conditions. + Counter.update(count=Counter.count + 1).where(Counter.url == request.url) + +Check out the `example twitter app `_. + +Learning more +------------- + +Check the `documentation `_ for more examples. + +Specific question? Come hang out in the #peewee channel on irc.libera.chat, or post to the mailing list, http://groups.google.com/group/peewee-orm . If you would like to report a bug, `create a new issue `_ on GitHub. + +Still want more info? +--------------------- + +.. image:: https://media.charlesleifer.com/blog/photos/wat.jpg + +I've written a number of blog posts about building applications and web-services with peewee (and usually Flask). If you'd like to see some real-life applications that use peewee, the following resources may be useful: + +* `Building a note-taking app with Flask and Peewee `_ as well as `Part 2 `_ and `Part 3 `_. +* `Analytics web service built with Flask and Peewee `_. +* `Personalized news digest (with a boolean query parser!) `_. +* `Structuring Flask apps with Peewee `_. +* `Creating a lastpass clone with Flask and Peewee `_. +* `Creating a bookmarking web-service that takes screenshots of your bookmarks `_. +* `Building a pastebin, wiki and a bookmarking service using Flask and Peewee `_. +* `Encrypted databases with Python and SQLCipher `_. +* `Dear Diary: An Encrypted, Command-Line Diary with Peewee `_. +* `Query Tree Structures in SQLite using Peewee and the Transitive Closure Extension `_. 
diff --git a/python3.10libs/peewee-3.17.0.dist-info/RECORD b/python3.10libs/peewee-3.17.0.dist-info/RECORD new file mode 100644 index 0000000..6767df5 --- /dev/null +++ b/python3.10libs/peewee-3.17.0.dist-info/RECORD @@ -0,0 +1,61 @@ +../../../bin/__pycache__/pwiz.cpython-310.pyc,, +../../../bin/pwiz.py,sha256=oSAF-z_-2grw9piy3r2NfBkZgK-0HzewsO4g5MdG1ug,8238 +__pycache__/peewee.cpython-310.pyc,, +__pycache__/pwiz.cpython-310.pyc,, +peewee-3.17.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +peewee-3.17.0.dist-info/LICENSE,sha256=N0AJYSWwhzWiR7jdCM2C4LqYTTvr2SIdN4V2Y35SQNo,1058 +peewee-3.17.0.dist-info/METADATA,sha256=Pul7QQXN86BKMkxA1AVyYWmTFIo_-sBvs5VvfMX74Mo,7354 +peewee-3.17.0.dist-info/RECORD,, +peewee-3.17.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +peewee-3.17.0.dist-info/WHEEL,sha256=Y5IjLrH3IwsncIuLdTVHz-QBJTvM9A9z8sCLa_1_XO4,105 +peewee-3.17.0.dist-info/top_level.txt,sha256=uV7RZ61bWm9zDrPVGNrGay4E4WDonEqtU2NPe5GGUWs,22 +peewee.py,sha256=bbh0_O7_b0IfTNH7u-UDKzfbW4-oC-V2AHZx3AJH7Zk,277280 +playhouse/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +playhouse/__pycache__/__init__.cpython-310.pyc,, +playhouse/__pycache__/apsw_ext.cpython-310.pyc,, +playhouse/__pycache__/cockroachdb.cpython-310.pyc,, +playhouse/__pycache__/dataset.cpython-310.pyc,, +playhouse/__pycache__/db_url.cpython-310.pyc,, +playhouse/__pycache__/fields.cpython-310.pyc,, +playhouse/__pycache__/flask_utils.cpython-310.pyc,, +playhouse/__pycache__/hybrid.cpython-310.pyc,, +playhouse/__pycache__/kv.cpython-310.pyc,, +playhouse/__pycache__/migrate.cpython-310.pyc,, +playhouse/__pycache__/mysql_ext.cpython-310.pyc,, +playhouse/__pycache__/pool.cpython-310.pyc,, +playhouse/__pycache__/postgres_ext.cpython-310.pyc,, +playhouse/__pycache__/psycopg3_ext.cpython-310.pyc,, +playhouse/__pycache__/reflection.cpython-310.pyc,, +playhouse/__pycache__/shortcuts.cpython-310.pyc,, +playhouse/__pycache__/signals.cpython-310.pyc,, 
+playhouse/__pycache__/sqlcipher_ext.cpython-310.pyc,, +playhouse/__pycache__/sqlite_changelog.cpython-310.pyc,, +playhouse/__pycache__/sqlite_ext.cpython-310.pyc,, +playhouse/__pycache__/sqlite_udf.cpython-310.pyc,, +playhouse/__pycache__/sqliteq.cpython-310.pyc,, +playhouse/__pycache__/test_utils.cpython-310.pyc,, +playhouse/_sqlite_ext.cpython-310-x86_64-linux-gnu.so,sha256=AbOME396-26V62qWIt6qBAWO8eHA5nH4Nh9cUnthK-8,1064704 +playhouse/_sqlite_udf.cpython-310-x86_64-linux-gnu.so,sha256=hi5LjWyv-HgRIKVlx97w4n9B8t0NHd07SZjntkv5BDM,383880 +playhouse/apsw_ext.py,sha256=YltEYvR-nTiqUgXvvstKgj8ydeI7RnpNkJ0tF52nMto,5018 +playhouse/cockroachdb.py,sha256=Z183VMQKfBFtAcmLTWuNhPBt-q7zpqh1qR0cN345VYw,9169 +playhouse/dataset.py,sha256=vzY_BxlcLNxi2miGqzJYlSFzdXnqt3G62pDHgpzOYbA,14488 +playhouse/db_url.py,sha256=JFhMZN268SbumQeQWpYZwCiKxZeXI76cbwXCCi9SAZw,4246 +playhouse/fields.py,sha256=dyO8d3l-3uq_XP7gmQRVdL6k70oUfAiXGhCXjAbayeY,1699 +playhouse/flask_utils.py,sha256=Ou26hV7d5NLzm-3oq4dJArs1VUAWh8vDclyuHDNrfTI,8197 +playhouse/hybrid.py,sha256=rRAPBImP2x61DoYh53mLm4JMPoNCLPFTd7_WIrq6_gU,1528 +playhouse/kv.py,sha256=tAG__zzti3zM98PXMoPypUobJF6zgH_CBq3x2VgTxeU,5608 +playhouse/migrate.py,sha256=EOeKJHpitM0thQEcGW6SpUPIn8mgwDsp9CB07eoPBmA,31379 +playhouse/mysql_ext.py,sha256=i2XwK2UDPVc9MMCfqGky27qqgQNvaTNJccZfTNlyYAE,3874 +playhouse/pool.py,sha256=nrp-zLRmzDQsbIVvT8r4GI6NwIP53Are2Sj8jm0uC3c,11476 +playhouse/postgres_ext.py,sha256=dfaV_1xbXqBgp4PKVMzx95My34dvJ6_4XdZeqHz3Q5g,14758 +playhouse/psycopg3_ext.py,sha256=u1FlKWz4V_Hm0pC8mi2nLMblzeRK20dsCvyHg-7ANPw,1173 +playhouse/reflection.py,sha256=Z4hu8suYzoBKRCojDYVBDp-QTdQumRnbY1DgtnaOQ3U,31129 +playhouse/shortcuts.py,sha256=bSNWCVJ29TVBPn2b0HHtffsgT9VZFP3CkrowyvsWU8I,11680 +playhouse/signals.py,sha256=FeHi7SxJ3ThG0tVes6M_x_K-wx93RhmsIqkSBy0YKkI,2511 +playhouse/sqlcipher_ext.py,sha256=ZO8zN6pM4_gA-5ML3P7cnlCop5aLU2w9JLFTktNznkU,3632 +playhouse/sqlite_changelog.py,sha256=c3FaYNZ-aWnDrZ9tgy9WxDRB6a8sFikDMajOzYmYeIQ,4793 
+playhouse/sqlite_ext.py,sha256=hOWJApyRz0D2D0f1EroofcydvKX-sCZ6dBZ4J-srnDg,46750 +playhouse/sqlite_udf.py,sha256=wl356xkDKRq6rNPZOF2SPKm8UY4W1u7o1xGliqPeUQk,13665 +playhouse/sqliteq.py,sha256=90mvAPmriHUPCD9Crsb17uMFvh0WOyBv8KKtvhiTGe8,9982 +playhouse/test_utils.py,sha256=AAxfQsWFmoPdaq-JW13tWnPbZBIiU8qLBWfWiQgLQ3A,1854 +pwiz.py,sha256=7ctwTZ44cPsCCKTtd-8pPVMqNon9gTVy4K_Ue3gRLTQ,8193 diff --git a/python3.10libs/peewee-3.17.0.dist-info/REQUESTED b/python3.10libs/peewee-3.17.0.dist-info/REQUESTED new file mode 100644 index 0000000..e69de29 diff --git a/python3.10libs/peewee-3.17.0.dist-info/WHEEL b/python3.10libs/peewee-3.17.0.dist-info/WHEEL new file mode 100644 index 0000000..9f1aa26 --- /dev/null +++ b/python3.10libs/peewee-3.17.0.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.41.3) +Root-Is-Purelib: false +Tag: cp310-cp310-linux_x86_64 + diff --git a/python3.10libs/peewee-3.17.0.dist-info/top_level.txt b/python3.10libs/peewee-3.17.0.dist-info/top_level.txt new file mode 100644 index 0000000..1d507be --- /dev/null +++ b/python3.10libs/peewee-3.17.0.dist-info/top_level.txt @@ -0,0 +1,3 @@ +peewee +playhouse +pwiz diff --git a/python3.10libs/peewee.py b/python3.10libs/peewee.py new file mode 100644 index 0000000..13b2e21 --- /dev/null +++ b/python3.10libs/peewee.py @@ -0,0 +1,8084 @@ +from bisect import bisect_left +from bisect import bisect_right +from contextlib import contextmanager +from copy import deepcopy +from functools import wraps +from inspect import isclass +import calendar +import collections +import datetime +import decimal +import hashlib +import itertools +import logging +import operator +import re +import socket +import struct +import sys +import threading +import time +import uuid +import warnings +try: + from collections.abc import Mapping +except ImportError: + from collections import Mapping + +try: + from pysqlite3 import dbapi2 as pysq3 +except ImportError: + try: + from pysqlite2 import dbapi2 as pysq3 + 
except ImportError: + pysq3 = None +try: + import sqlite3 +except ImportError: + sqlite3 = pysq3 +else: + if pysq3 and pysq3.sqlite_version_info >= sqlite3.sqlite_version_info: + sqlite3 = pysq3 +try: + from psycopg2cffi import compat + compat.register() +except ImportError: + pass +try: + import psycopg2 + from psycopg2 import extensions as pg_extensions + try: + from psycopg2 import errors as pg_errors + except ImportError: + pg_errors = None +except ImportError: + psycopg2 = pg_errors = None +try: + from psycopg2.extras import register_uuid as pg_register_uuid + pg_register_uuid() +except Exception: + pass + +mysql_passwd = False +try: + import pymysql as mysql +except ImportError: + try: + import MySQLdb as mysql + mysql_passwd = True + except ImportError: + mysql = None + + +__version__ = '3.17.0' +__all__ = [ + 'AnyField', + 'AsIs', + 'AutoField', + 'BareField', + 'BigAutoField', + 'BigBitField', + 'BigIntegerField', + 'BinaryUUIDField', + 'BitField', + 'BlobField', + 'BooleanField', + 'Case', + 'Cast', + 'CharField', + 'Check', + 'chunked', + 'Column', + 'CompositeKey', + 'Context', + 'Database', + 'DatabaseError', + 'DatabaseProxy', + 'DataError', + 'DateField', + 'DateTimeField', + 'DecimalField', + 'DeferredForeignKey', + 'DeferredThroughModel', + 'DJANGO_MAP', + 'DoesNotExist', + 'DoubleField', + 'DQ', + 'EXCLUDED', + 'Field', + 'FixedCharField', + 'FloatField', + 'fn', + 'ForeignKeyField', + 'IdentityField', + 'ImproperlyConfigured', + 'Index', + 'IntegerField', + 'IntegrityError', + 'InterfaceError', + 'InternalError', + 'IPField', + 'JOIN', + 'ManyToManyField', + 'Model', + 'ModelIndex', + 'MySQLDatabase', + 'NotSupportedError', + 'OP', + 'OperationalError', + 'PostgresqlDatabase', + 'PrimaryKeyField', # XXX: Deprecated, change to AutoField. 
+ 'prefetch', + 'PREFETCH_TYPE', + 'ProgrammingError', + 'Proxy', + 'QualifiedNames', + 'SchemaManager', + 'SmallIntegerField', + 'Select', + 'SQL', + 'SqliteDatabase', + 'Table', + 'TextField', + 'TimeField', + 'TimestampField', + 'Tuple', + 'UUIDField', + 'Value', + 'ValuesList', + 'Window', +] + +try: # Python 2.7+ + from logging import NullHandler +except ImportError: + class NullHandler(logging.Handler): + def emit(self, record): + pass + +logger = logging.getLogger('peewee') +logger.addHandler(NullHandler()) + + +if sys.version_info[0] == 2: + text_type = unicode + bytes_type = str + buffer_type = buffer + izip_longest = itertools.izip_longest + callable_ = callable + multi_types = (list, tuple, frozenset, set) + exec('def reraise(tp, value, tb=None): raise tp, value, tb') + def print_(s): + sys.stdout.write(s) + sys.stdout.write('\n') +else: + import builtins + try: + from collections.abc import Callable + except ImportError: + from collections import Callable + from functools import reduce + callable_ = lambda c: isinstance(c, Callable) + text_type = str + bytes_type = bytes + buffer_type = memoryview + basestring = str + long = int + multi_types = (list, tuple, frozenset, set, range) + print_ = getattr(builtins, 'print') + izip_longest = itertools.zip_longest + def reraise(tp, value, tb=None): + if value.__traceback__ is not tb: + raise value.with_traceback(tb) + raise value + + +if sqlite3: + sqlite3.register_adapter(decimal.Decimal, str) + sqlite3.register_adapter(datetime.date, str) + sqlite3.register_adapter(datetime.time, str) + __sqlite_version__ = sqlite3.sqlite_version_info +else: + __sqlite_version__ = (0, 0, 0) + + +__date_parts__ = set(('year', 'month', 'day', 'hour', 'minute', 'second')) + +# Sqlite does not support the `date_part` SQL function, so we will define an +# implementation in python. 
# ---------------------------------------------------------------------------
# Sqlite helpers: date_part()/date_trunc() are implemented in Python because
# Sqlite does not provide them as SQL functions.
# ---------------------------------------------------------------------------

# Formats tried, in order, when parsing date/time strings stored by Sqlite.
__sqlite_datetime_formats__ = (
    '%Y-%m-%d %H:%M:%S',
    '%Y-%m-%d %H:%M:%S.%f',
    '%Y-%m-%d',
    '%H:%M:%S',
    '%H:%M:%S.%f',
    '%H:%M')

# strftime() templates that truncate a datetime to the named resolution.
__sqlite_date_trunc__ = {
    'year': '%Y-01-01 00:00:00',
    'month': '%Y-%m-01 00:00:00',
    'day': '%Y-%m-%d 00:00:00',
    'hour': '%Y-%m-%d %H:00:00',
    'minute': '%Y-%m-%d %H:%M:00',
    'second': '%Y-%m-%d %H:%M:%S'}

# MySQL's DATE_FORMAT() mini-language spells minutes as %i rather than %M.
__mysql_date_trunc__ = __sqlite_date_trunc__.copy()
__mysql_date_trunc__['minute'] = '%Y-%m-%d %H:%i:00'
__mysql_date_trunc__['second'] = '%Y-%m-%d %H:%i:%S'


def _sqlite_date_part(lookup_type, datetime_string):
    """Extract one named component (year, month, ...) from a datetime string.

    Returns None for empty/NULL-ish stored values.
    """
    assert lookup_type in __date_parts__
    if not datetime_string:
        return None
    parsed = format_date_time(datetime_string, __sqlite_datetime_formats__)
    return getattr(parsed, lookup_type)


def _sqlite_date_trunc(lookup_type, datetime_string):
    """Truncate a datetime string to the given resolution.

    Returns None for empty/NULL-ish stored values.
    """
    assert lookup_type in __sqlite_date_trunc__
    if not datetime_string:
        return None
    parsed = format_date_time(datetime_string, __sqlite_datetime_formats__)
    return parsed.strftime(__sqlite_date_trunc__[lookup_type])


def __deprecated__(s):
    """Funnel deprecation notices through the standard warnings machinery."""
    warnings.warn(s, DeprecationWarning)


class attrdict(dict):
    """A dict whose keys are also readable and writable as attributes.

    Missing keys surface as AttributeError (so hasattr() behaves sanely).
    ``a + b`` merges into a brand-new attrdict; ``a += b`` merges in place.
    """

    def __getattr__(self, attr):
        if attr in self:
            return self[attr]
        raise AttributeError(attr)

    def __setattr__(self, attr, value):
        self[attr] = value

    def __iadd__(self, rhs):
        self.update(rhs)
        return self

    def __add__(self, rhs):
        merged = attrdict(self)
        merged.update(rhs)
        return merged


# Unique marker distinguishing "argument not provided" from None.
SENTINEL = object()

#: Operations for use in SQL expressions.
OP = attrdict({
    'AND': 'AND',
    'OR': 'OR',
    'ADD': '+',
    'SUB': '-',
    'MUL': '*',
    'DIV': '/',
    'BIN_AND': '&',
    'BIN_OR': '|',
    'XOR': '#',
    'MOD': '%',
    'EQ': '=',
    'LT': '<',
    'LTE': '<=',
    'GT': '>',
    'GTE': '>=',
    'NE': '!=',
    'IN': 'IN',
    'NOT_IN': 'NOT IN',
    'IS': 'IS',
    'IS_NOT': 'IS NOT',
    'LIKE': 'LIKE',
    'ILIKE': 'ILIKE',
    'BETWEEN': 'BETWEEN',
    'REGEXP': 'REGEXP',
    'IREGEXP': 'IREGEXP',
    'CONCAT': '||',
    'BITWISE_NEGATION': '~'})

# To support "django-style" double-underscore filters, map each filter suffix
# (e.g. "__eq") to the callable or operator that builds the expression.
DJANGO_MAP = attrdict({
    'eq': operator.eq,
    'lt': operator.lt,
    'lte': operator.le,
    'gt': operator.gt,
    'gte': operator.ge,
    'ne': operator.ne,
    'in': operator.lshift,
    'is': lambda l, r: Expression(l, OP.IS, r),
    'like': lambda l, r: Expression(l, OP.LIKE, r),
    'ilike': lambda l, r: Expression(l, OP.ILIKE, r),
    'regexp': lambda l, r: Expression(l, OP.REGEXP, r),
})

#: Mapping of field type to the data-type supported by the database.
#: Databases may override or add to this list.
FIELD = attrdict({
    'AUTO': 'INTEGER',
    'BIGAUTO': 'BIGINT',
    'BIGINT': 'BIGINT',
    'BLOB': 'BLOB',
    'BOOL': 'SMALLINT',
    'CHAR': 'CHAR',
    'DATE': 'DATE',
    'DATETIME': 'DATETIME',
    'DECIMAL': 'DECIMAL',
    'DEFAULT': '',
    'DOUBLE': 'REAL',
    'FLOAT': 'REAL',
    'INT': 'INTEGER',
    'SMALLINT': 'SMALLINT',
    'TEXT': 'TEXT',
    'TIME': 'TIME',
    'UUID': 'TEXT',
    'UUIDB': 'BLOB',
    'VARCHAR': 'VARCHAR'})

#: Join helpers (for convenience) -- all join types are supported; this
#: object just helps avoid errors introduced by typing join strings by hand.
JOIN = attrdict({
    'INNER': 'INNER JOIN',
    'LEFT_OUTER': 'LEFT OUTER JOIN',
    'RIGHT_OUTER': 'RIGHT OUTER JOIN',
    'FULL': 'FULL JOIN',
    'FULL_OUTER': 'FULL OUTER JOIN',
    'CROSS': 'CROSS JOIN',
    'NATURAL': 'NATURAL JOIN',
    'LATERAL': 'LATERAL',
    'LEFT_LATERAL': 'LEFT JOIN LATERAL'})

# Row representations.
ROW = attrdict({
    'TUPLE': 1,
    'DICT': 2,
    'NAMED_TUPLE': 3,
    'CONSTRUCTOR': 4,
    'MODEL': 5})

# Query type to use with prefetch().
PREFETCH_TYPE = attrdict({
    'WHERE': 1,
    'JOIN': 2})

# Bit-flags describing the scope in which SQL is currently being generated.
SCOPE_NORMAL = 1
SCOPE_SOURCE = 2
SCOPE_VALUES = 4
SCOPE_CTE = 8
SCOPE_COLUMN = 16

# Rules for parentheses around subqueries in compound select.
CSQ_PARENTHESES_NEVER = 0
CSQ_PARENTHESES_ALWAYS = 1
CSQ_PARENTHESES_UNNESTED = 2

# Regular expressions used to convert class names to snake-case table names.
# First regex handles acronym followed by word or initial lower-word followed
# by a capitalized word. e.g. APIResponse -> API_Response / fooBar -> foo_Bar.
+# Second regex handles the normal case of two title-cased words. +SNAKE_CASE_STEP1 = re.compile('(.)_*([A-Z][a-z]+)') +SNAKE_CASE_STEP2 = re.compile('([a-z0-9])_*([A-Z])') + +# Helper functions that are used in various parts of the codebase. +MODEL_BASE = '_metaclass_helper_' + +def with_metaclass(meta, base=object): + return meta(MODEL_BASE, (base,), {}) + +def merge_dict(source, overrides): + merged = source.copy() + if overrides: + merged.update(overrides) + return merged + +def quote(path, quote_chars): + if len(path) == 1: + return path[0].join(quote_chars) + return '.'.join([part.join(quote_chars) for part in path]) + +is_model = lambda o: isclass(o) and issubclass(o, Model) + +def ensure_tuple(value): + if value is not None: + return value if isinstance(value, (list, tuple)) else (value,) + +def ensure_entity(value): + if value is not None: + return value if isinstance(value, Node) else Entity(value) + +def make_snake_case(s): + first = SNAKE_CASE_STEP1.sub(r'\1_\2', s) + return SNAKE_CASE_STEP2.sub(r'\1_\2', first).lower() + +def chunked(it, n): + marker = object() + for group in (list(g) for g in izip_longest(*[iter(it)] * n, + fillvalue=marker)): + if group[-1] is marker: + del group[group.index(marker):] + yield group + + +class _callable_context_manager(object): + def __call__(self, fn): + @wraps(fn) + def inner(*args, **kwargs): + with self: + return fn(*args, **kwargs) + return inner + + +class Proxy(object): + """ + Create a proxy or placeholder for another object. 
+ """ + __slots__ = ('obj', '_callbacks') + + def __init__(self): + self._callbacks = [] + self.initialize(None) + + def initialize(self, obj): + self.obj = obj + for callback in self._callbacks: + callback(obj) + + def attach_callback(self, callback): + self._callbacks.append(callback) + return callback + + def passthrough(method): + def inner(self, *args, **kwargs): + if self.obj is None: + raise AttributeError('Cannot use uninitialized Proxy.') + return getattr(self.obj, method)(*args, **kwargs) + return inner + + # Allow proxy to be used as a context-manager. + __enter__ = passthrough('__enter__') + __exit__ = passthrough('__exit__') + + def __getattr__(self, attr): + if self.obj is None: + raise AttributeError('Cannot use uninitialized Proxy.') + return getattr(self.obj, attr) + + def __setattr__(self, attr, value): + if attr not in self.__slots__: + raise AttributeError('Cannot set attribute on proxy.') + return super(Proxy, self).__setattr__(attr, value) + + +class DatabaseProxy(Proxy): + """ + Proxy implementation specifically for proxying `Database` objects. + """ + __slots__ = ('obj', '_callbacks', '_Model') + + def connection_context(self): + return ConnectionContext(self) + def atomic(self, *args, **kwargs): + return _atomic(self, *args, **kwargs) + def manual_commit(self): + return _manual(self) + def transaction(self, *args, **kwargs): + return _transaction(self, *args, **kwargs) + def savepoint(self): + return _savepoint(self) + @property + def Model(self): + if not hasattr(self, '_Model'): + class Meta: database = self + self._Model = type('BaseModel', (Model,), {'Meta': Meta}) + return self._Model + + +class ModelDescriptor(object): pass + + +# SQL Generation. + + +class AliasManager(object): + __slots__ = ('_counter', '_current_index', '_mapping') + + def __init__(self): + # A list of dictionaries containing mappings at various depths. 
+ self._counter = 0 + self._current_index = 0 + self._mapping = [] + self.push() + + @property + def mapping(self): + return self._mapping[self._current_index - 1] + + def add(self, source): + if source not in self.mapping: + self._counter += 1 + self[source] = 't%d' % self._counter + return self.mapping[source] + + def get(self, source, any_depth=False): + if any_depth: + for idx in reversed(range(self._current_index)): + if source in self._mapping[idx]: + return self._mapping[idx][source] + return self.add(source) + + def __getitem__(self, source): + return self.get(source) + + def __setitem__(self, source, alias): + self.mapping[source] = alias + + def push(self): + self._current_index += 1 + if self._current_index > len(self._mapping): + self._mapping.append({}) + + def pop(self): + if self._current_index == 1: + raise ValueError('Cannot pop() from empty alias manager.') + self._current_index -= 1 + + +class State(collections.namedtuple('_State', ('scope', 'parentheses', + 'settings'))): + def __new__(cls, scope=SCOPE_NORMAL, parentheses=False, **kwargs): + return super(State, cls).__new__(cls, scope, parentheses, kwargs) + + def __call__(self, scope=None, parentheses=None, **kwargs): + # Scope and settings are "inherited" (parentheses is not, however). + scope = self.scope if scope is None else scope + + # Try to avoid unnecessary dict copying. + if kwargs and self.settings: + settings = self.settings.copy() # Copy original settings dict. + settings.update(kwargs) # Update copy with overrides. 
+ elif kwargs: + settings = kwargs + else: + settings = self.settings + return State(scope, parentheses, **settings) + + def __getattr__(self, attr_name): + return self.settings.get(attr_name) + + +def __scope_context__(scope): + @contextmanager + def inner(self, **kwargs): + with self(scope=scope, **kwargs): + yield self + return inner + + +class Context(object): + __slots__ = ('stack', '_sql', '_values', 'alias_manager', 'state') + + def __init__(self, **settings): + self.stack = [] + self._sql = [] + self._values = [] + self.alias_manager = AliasManager() + self.state = State(**settings) + + def as_new(self): + return Context(**self.state.settings) + + def column_sort_key(self, item): + return item[0].get_sort_key(self) + + @property + def scope(self): + return self.state.scope + + @property + def parentheses(self): + return self.state.parentheses + + @property + def subquery(self): + return self.state.subquery + + def __call__(self, **overrides): + if overrides and overrides.get('scope') == self.scope: + del overrides['scope'] + + self.stack.append(self.state) + self.state = self.state(**overrides) + return self + + scope_normal = __scope_context__(SCOPE_NORMAL) + scope_source = __scope_context__(SCOPE_SOURCE) + scope_values = __scope_context__(SCOPE_VALUES) + scope_cte = __scope_context__(SCOPE_CTE) + scope_column = __scope_context__(SCOPE_COLUMN) + + def __enter__(self): + if self.parentheses: + self.literal('(') + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.parentheses: + self.literal(')') + self.state = self.stack.pop() + + @contextmanager + def push_alias(self): + self.alias_manager.push() + yield + self.alias_manager.pop() + + def sql(self, obj): + if isinstance(obj, (Node, Context)): + return obj.__sql__(self) + elif is_model(obj): + return obj._meta.table.__sql__(self) + else: + return self.sql(Value(obj)) + + def literal(self, keyword): + self._sql.append(keyword) + return self + + def value(self, value, converter=None, 
add_param=True): + if converter: + value = converter(value) + elif converter is None and self.state.converter: + # Explicitly check for None so that "False" can be used to signify + # that no conversion should be applied. + value = self.state.converter(value) + + if isinstance(value, Node): + with self(converter=None): + return self.sql(value) + elif is_model(value): + # Under certain circumstances, we could end-up treating a model- + # class itself as a value. This check ensures that we drop the + # table alias into the query instead of trying to parameterize a + # model (for instance, passing a model as a function argument). + with self.scope_column(): + return self.sql(value) + + if self.state.value_literals: + return self.literal(_query_val_transform(value)) + + self._values.append(value) + return self.literal(self.state.param or '?') if add_param else self + + def __sql__(self, ctx): + ctx._sql.extend(self._sql) + ctx._values.extend(self._values) + return ctx + + def parse(self, node): + return self.sql(node).query() + + def query(self): + return ''.join(self._sql), self._values + + +def query_to_string(query): + # NOTE: this function is not exported by default as it might be misused -- + # and this misuse could lead to sql injection vulnerabilities. This + # function is intended for debugging or logging purposes ONLY. + db = getattr(query, '_database', None) + if db is not None: + ctx = db.get_sql_context() + else: + ctx = Context() + + sql, params = ctx.sql(query).query() + if not params: + return sql + + param = ctx.state.param or '?' + if param == '?': + sql = sql.replace('?', '%s') + + return sql % tuple(map(_query_val_transform, params)) + +def _query_val_transform(v): + # Interpolate parameters. 
+ if isinstance(v, (text_type, datetime.datetime, datetime.date, + datetime.time)): + v = "'%s'" % v + elif isinstance(v, bytes_type): + try: + v = v.decode('utf8') + except UnicodeDecodeError: + v = v.decode('raw_unicode_escape') + v = "'%s'" % v + elif isinstance(v, int): + v = '%s' % int(v) # Also handles booleans -> 1 or 0. + elif v is None: + v = 'NULL' + else: + v = str(v) + return v + + +# AST. + + +class Node(object): + _coerce = True + __isabstractmethod__ = False # Avoid issue w/abc and __getattr__, eg fn.X + + def clone(self): + obj = self.__class__.__new__(self.__class__) + obj.__dict__ = self.__dict__.copy() + return obj + + def __sql__(self, ctx): + raise NotImplementedError + + @staticmethod + def copy(method): + def inner(self, *args, **kwargs): + clone = self.clone() + method(clone, *args, **kwargs) + return clone + return inner + + def coerce(self, _coerce=True): + if _coerce != self._coerce: + clone = self.clone() + clone._coerce = _coerce + return clone + return self + + def is_alias(self): + return False + + def unwrap(self): + return self + + +class ColumnFactory(object): + __slots__ = ('node',) + + def __init__(self, node): + self.node = node + + def __getattr__(self, attr): + return Column(self.node, attr) + __getitem__ = __getattr__ + + +class _DynamicColumn(object): + __slots__ = () + + def __get__(self, instance, instance_type=None): + if instance is not None: + return ColumnFactory(instance) # Implements __getattr__(). + return self + + +class _ExplicitColumn(object): + __slots__ = () + + def __get__(self, instance, instance_type=None): + if instance is not None: + raise AttributeError( + '%s specifies columns explicitly, and does not support ' + 'dynamic column lookups.' 
% instance) + return self + + +class Source(Node): + c = _DynamicColumn() + + def __init__(self, alias=None): + super(Source, self).__init__() + self._alias = alias + + @Node.copy + def alias(self, name): + self._alias = name + + def select(self, *columns): + if not columns: + columns = (SQL('*'),) + return Select((self,), columns) + + @property + def star(self): + return NodeList((QualifiedNames(self), SQL('.*')), glue='') + + def join(self, dest, join_type=JOIN.INNER, on=None): + return Join(self, dest, join_type, on) + + def left_outer_join(self, dest, on=None): + return Join(self, dest, JOIN.LEFT_OUTER, on) + + def cte(self, name, recursive=False, columns=None, materialized=None): + return CTE(name, self, recursive=recursive, columns=columns, + materialized=materialized) + + def get_sort_key(self, ctx): + if self._alias: + return (self._alias,) + return (ctx.alias_manager[self],) + + def apply_alias(self, ctx): + # If we are defining the source, include the "AS alias" declaration. An + # alias is created for the source if one is not already defined. 
+ if ctx.scope == SCOPE_SOURCE: + if self._alias: + ctx.alias_manager[self] = self._alias + ctx.literal(' AS ').sql(Entity(ctx.alias_manager[self])) + return ctx + + def apply_column(self, ctx): + if self._alias: + ctx.alias_manager[self] = self._alias + return ctx.sql(Entity(ctx.alias_manager[self])) + + +class _HashableSource(object): + def __init__(self, *args, **kwargs): + super(_HashableSource, self).__init__(*args, **kwargs) + self._update_hash() + + @Node.copy + def alias(self, name): + self._alias = name + self._update_hash() + + def _update_hash(self): + self._hash = self._get_hash() + + def _get_hash(self): + return hash((self.__class__, self._path, self._alias)) + + def __hash__(self): + return self._hash + + def __eq__(self, other): + if isinstance(other, _HashableSource): + return self._hash == other._hash + return Expression(self, OP.EQ, other) + + def __ne__(self, other): + if isinstance(other, _HashableSource): + return self._hash != other._hash + return Expression(self, OP.NE, other) + + def _e(op): + def inner(self, rhs): + return Expression(self, op, rhs) + return inner + __lt__ = _e(OP.LT) + __le__ = _e(OP.LTE) + __gt__ = _e(OP.GT) + __ge__ = _e(OP.GTE) + + +def __bind_database__(meth): + @wraps(meth) + def inner(self, *args, **kwargs): + result = meth(self, *args, **kwargs) + if self._database: + return result.bind(self._database) + return result + return inner + + +def __join__(join_type=JOIN.INNER, inverted=False): + def method(self, other): + if inverted: + self, other = other, self + return Join(self, other, join_type=join_type) + return method + + +class BaseTable(Source): + __and__ = __join__(JOIN.INNER) + __add__ = __join__(JOIN.LEFT_OUTER) + __sub__ = __join__(JOIN.RIGHT_OUTER) + __or__ = __join__(JOIN.FULL_OUTER) + __mul__ = __join__(JOIN.CROSS) + __rand__ = __join__(JOIN.INNER, inverted=True) + __radd__ = __join__(JOIN.LEFT_OUTER, inverted=True) + __rsub__ = __join__(JOIN.RIGHT_OUTER, inverted=True) + __ror__ = 
__join__(JOIN.FULL_OUTER, inverted=True) + __rmul__ = __join__(JOIN.CROSS, inverted=True) + + +class _BoundTableContext(object): + def __init__(self, table, database): + self.table = table + self.database = database + + def __call__(self, fn): + @wraps(fn) + def inner(*args, **kwargs): + with _BoundTableContext(self.table, self.database): + return fn(*args, **kwargs) + return inner + + def __enter__(self): + self._orig_database = self.table._database + self.table.bind(self.database) + if self.table._model is not None: + self.table._model.bind(self.database) + return self.table + + def __exit__(self, exc_type, exc_val, exc_tb): + self.table.bind(self._orig_database) + if self.table._model is not None: + self.table._model.bind(self._orig_database) + + +class Table(_HashableSource, BaseTable): + def __init__(self, name, columns=None, primary_key=None, schema=None, + alias=None, _model=None, _database=None): + self.__name__ = name + self._columns = columns + self._primary_key = primary_key + self._schema = schema + self._path = (schema, name) if schema else (name,) + self._model = _model + self._database = _database + super(Table, self).__init__(alias=alias) + + # Allow tables to restrict what columns are available. + if columns is not None: + self.c = _ExplicitColumn() + for column in columns: + setattr(self, column, Column(self, column)) + + if primary_key: + col_src = self if self._columns else self.c + self.primary_key = getattr(col_src, primary_key) + else: + self.primary_key = None + + def clone(self): + # Ensure a deep copy of the column instances. 
+ return Table( + self.__name__, + columns=self._columns, + primary_key=self._primary_key, + schema=self._schema, + alias=self._alias, + _model=self._model, + _database=self._database) + + def bind(self, database=None): + self._database = database + return self + + def bind_ctx(self, database=None): + return _BoundTableContext(self, database) + + def _get_hash(self): + return hash((self.__class__, self._path, self._alias, self._model)) + + @__bind_database__ + def select(self, *columns): + if not columns and self._columns: + columns = [Column(self, column) for column in self._columns] + return Select((self,), columns) + + @__bind_database__ + def insert(self, insert=None, columns=None, **kwargs): + if kwargs: + insert = {} if insert is None else insert + src = self if self._columns else self.c + for key, value in kwargs.items(): + insert[getattr(src, key)] = value + return Insert(self, insert=insert, columns=columns) + + @__bind_database__ + def replace(self, insert=None, columns=None, **kwargs): + return (self + .insert(insert=insert, columns=columns) + .on_conflict('REPLACE')) + + @__bind_database__ + def update(self, update=None, **kwargs): + if kwargs: + update = {} if update is None else update + for key, value in kwargs.items(): + src = self if self._columns else self.c + update[getattr(src, key)] = value + return Update(self, update=update) + + @__bind_database__ + def delete(self): + return Delete(self) + + def __sql__(self, ctx): + if ctx.scope == SCOPE_VALUES: + # Return the quoted table name. + return ctx.sql(Entity(*self._path)) + + if self._alias: + ctx.alias_manager[self] = self._alias + + if ctx.scope == SCOPE_SOURCE: + # Define the table and its alias. + return self.apply_alias(ctx.sql(Entity(*self._path))) + else: + # Refer to the table using the alias. 
+ return self.apply_column(ctx) + + +class Join(BaseTable): + def __init__(self, lhs, rhs, join_type=JOIN.INNER, on=None, alias=None): + super(Join, self).__init__(alias=alias) + self.lhs = lhs + self.rhs = rhs + self.join_type = join_type + self._on = on + + def on(self, predicate): + self._on = predicate + return self + + def __sql__(self, ctx): + (ctx + .sql(self.lhs) + .literal(' %s ' % self.join_type) + .sql(self.rhs)) + if self._on is not None: + ctx.literal(' ON ').sql(self._on) + return ctx + + +class ValuesList(_HashableSource, BaseTable): + def __init__(self, values, columns=None, alias=None): + self._values = values + self._columns = columns + super(ValuesList, self).__init__(alias=alias) + + def _get_hash(self): + return hash((self.__class__, id(self._values), self._alias)) + + @Node.copy + def columns(self, *names): + self._columns = names + + def __sql__(self, ctx): + if self._alias: + ctx.alias_manager[self] = self._alias + + if ctx.scope == SCOPE_SOURCE or ctx.scope == SCOPE_NORMAL: + with ctx(parentheses=not ctx.parentheses): + ctx = (ctx + .literal('VALUES ') + .sql(CommaNodeList([ + EnclosedNodeList(row) for row in self._values]))) + + if ctx.scope == SCOPE_SOURCE: + ctx.literal(' AS ').sql(Entity(ctx.alias_manager[self])) + if self._columns: + entities = [Entity(c) for c in self._columns] + ctx.sql(EnclosedNodeList(entities)) + else: + ctx.sql(Entity(ctx.alias_manager[self])) + + return ctx + + +class CTE(_HashableSource, Source): + def __init__(self, name, query, recursive=False, columns=None, + materialized=None): + self._alias = name + self._query = query + self._recursive = recursive + self._materialized = materialized + if columns is not None: + columns = [Entity(c) if isinstance(c, basestring) else c + for c in columns] + self._columns = columns + query._cte_list = () + super(CTE, self).__init__(alias=name) + + def select_from(self, *columns): + if not columns: + raise ValueError('select_from() must specify one or more columns ' + 'from 
the CTE to select.') + + query = (Select((self,), columns) + .with_cte(self) + .bind(self._query._database)) + try: + query = query.objects(self._query.model) + except AttributeError: + pass + return query + + def _get_hash(self): + return hash((self.__class__, self._alias, id(self._query))) + + def union_all(self, rhs): + clone = self._query.clone() + return CTE(self._alias, clone + rhs, self._recursive, self._columns) + __add__ = union_all + + def union(self, rhs): + clone = self._query.clone() + return CTE(self._alias, clone | rhs, self._recursive, self._columns) + __or__ = union + + def __sql__(self, ctx): + if ctx.scope != SCOPE_CTE: + return ctx.sql(Entity(self._alias)) + + with ctx.push_alias(): + ctx.alias_manager[self] = self._alias + ctx.sql(Entity(self._alias)) + + if self._columns: + ctx.literal(' ').sql(EnclosedNodeList(self._columns)) + ctx.literal(' AS ') + + if self._materialized: + ctx.literal('MATERIALIZED ') + elif self._materialized is False: + ctx.literal('NOT MATERIALIZED ') + + with ctx.scope_normal(parentheses=True): + ctx.sql(self._query) + return ctx + + +class ColumnBase(Node): + _converter = None + + @Node.copy + def converter(self, converter=None): + self._converter = converter + + def alias(self, alias): + if alias: + return Alias(self, alias) + return self + + def unalias(self): + return self + + def bind_to(self, dest): + return BindTo(self, dest) + + def cast(self, as_type): + return Cast(self, as_type) + + def asc(self, collation=None, nulls=None): + return Asc(self, collation=collation, nulls=nulls) + __pos__ = asc + + def desc(self, collation=None, nulls=None): + return Desc(self, collation=collation, nulls=nulls) + __neg__ = desc + + def __invert__(self): + return Negated(self) + + def _e(op, inv=False): + """ + Lightweight factory which returns a method that builds an Expression + consisting of the left-hand and right-hand operands, using `op`. 
+ """ + def inner(self, rhs): + if inv: + return Expression(rhs, op, self) + return Expression(self, op, rhs) + return inner + __and__ = _e(OP.AND) + __or__ = _e(OP.OR) + + __add__ = _e(OP.ADD) + __sub__ = _e(OP.SUB) + __mul__ = _e(OP.MUL) + __div__ = __truediv__ = _e(OP.DIV) + __xor__ = _e(OP.XOR) + __radd__ = _e(OP.ADD, inv=True) + __rsub__ = _e(OP.SUB, inv=True) + __rmul__ = _e(OP.MUL, inv=True) + __rdiv__ = __rtruediv__ = _e(OP.DIV, inv=True) + __rand__ = _e(OP.AND, inv=True) + __ror__ = _e(OP.OR, inv=True) + __rxor__ = _e(OP.XOR, inv=True) + + def __eq__(self, rhs): + op = OP.IS if rhs is None else OP.EQ + return Expression(self, op, rhs) + def __ne__(self, rhs): + op = OP.IS_NOT if rhs is None else OP.NE + return Expression(self, op, rhs) + + __lt__ = _e(OP.LT) + __le__ = _e(OP.LTE) + __gt__ = _e(OP.GT) + __ge__ = _e(OP.GTE) + __lshift__ = _e(OP.IN) + __rshift__ = _e(OP.IS) + __mod__ = _e(OP.LIKE) + __pow__ = _e(OP.ILIKE) + + like = _e(OP.LIKE) + ilike = _e(OP.ILIKE) + + bin_and = _e(OP.BIN_AND) + bin_or = _e(OP.BIN_OR) + in_ = _e(OP.IN) + not_in = _e(OP.NOT_IN) + regexp = _e(OP.REGEXP) + iregexp = _e(OP.IREGEXP) + + # Special expressions. + def is_null(self, is_null=True): + op = OP.IS if is_null else OP.IS_NOT + return Expression(self, op, None) + + def _escape_like_expr(self, s, template): + if s.find('_') >= 0 or s.find('%') >= 0 or s.find('\\') >= 0: + s = s.replace('\\', '\\\\').replace('_', '\\_').replace('%', '\\%') + # Pass the expression and escape string as unconverted values, to + # avoid (e.g.) a Json field converter turning the escaped LIKE + # pattern into a Json-quoted string. 
+ return NodeList(( + Value(template % s, converter=False), + SQL('ESCAPE'), + Value('\\', converter=False))) + return template % s + def contains(self, rhs): + if isinstance(rhs, Node): + rhs = Expression('%', OP.CONCAT, + Expression(rhs, OP.CONCAT, '%')) + else: + rhs = self._escape_like_expr(rhs, '%%%s%%') + return Expression(self, OP.ILIKE, rhs) + def startswith(self, rhs): + if isinstance(rhs, Node): + rhs = Expression(rhs, OP.CONCAT, '%') + else: + rhs = self._escape_like_expr(rhs, '%s%%') + return Expression(self, OP.ILIKE, rhs) + def endswith(self, rhs): + if isinstance(rhs, Node): + rhs = Expression('%', OP.CONCAT, rhs) + else: + rhs = self._escape_like_expr(rhs, '%%%s') + return Expression(self, OP.ILIKE, rhs) + def between(self, lo, hi): + return Expression(self, OP.BETWEEN, NodeList((lo, SQL('AND'), hi))) + def concat(self, rhs): + return StringExpression(self, OP.CONCAT, rhs) + def __getitem__(self, item): + if isinstance(item, slice): + if item.start is None or item.stop is None: + raise ValueError('BETWEEN range must have both a start- and ' + 'end-point.') + return self.between(item.start, item.stop) + return self == item + __iter__ = None # Prevent infinite loop. 
+ + def distinct(self): + return NodeList((SQL('DISTINCT'), self)) + + def collate(self, collation): + return NodeList((self, SQL('COLLATE %s' % collation))) + + def get_sort_key(self, ctx): + return () + + +class Column(ColumnBase): + def __init__(self, source, name): + self.source = source + self.name = name + + def get_sort_key(self, ctx): + if ctx.scope == SCOPE_VALUES: + return (self.name,) + else: + return self.source.get_sort_key(ctx) + (self.name,) + + def __hash__(self): + return hash((self.source, self.name)) + + def __sql__(self, ctx): + if ctx.scope == SCOPE_VALUES: + return ctx.sql(Entity(self.name)) + else: + with ctx.scope_column(): + return ctx.sql(self.source).literal('.').sql(Entity(self.name)) + + +class WrappedNode(ColumnBase): + def __init__(self, node): + self.node = node + self._coerce = getattr(node, '_coerce', True) + self._converter = getattr(node, '_converter', None) + + def is_alias(self): + return self.node.is_alias() + + def unwrap(self): + return self.node.unwrap() + + +class EntityFactory(object): + __slots__ = ('node',) + def __init__(self, node): + self.node = node + def __getattr__(self, attr): + return Entity(self.node, attr) + + +class _DynamicEntity(object): + __slots__ = () + def __get__(self, instance, instance_type=None): + if instance is not None: + return EntityFactory(instance._alias) # Implements __getattr__(). 
+ return self + + +class Alias(WrappedNode): + c = _DynamicEntity() + + def __init__(self, node, alias): + super(Alias, self).__init__(node) + self._alias = alias + + def __hash__(self): + return hash(self._alias) + + @property + def name(self): + return self._alias + @name.setter + def name(self, value): + self._alias = value + + def alias(self, alias=None): + if alias is None: + return self.node + else: + return Alias(self.node, alias) + + def unalias(self): + return self.node + + def is_alias(self): + return True + + def __sql__(self, ctx): + if ctx.scope == SCOPE_SOURCE: + return (ctx + .sql(self.node) + .literal(' AS ') + .sql(Entity(self._alias))) + else: + return ctx.sql(Entity(self._alias)) + + +class BindTo(WrappedNode): + def __init__(self, node, dest): + super(BindTo, self).__init__(node) + self.dest = dest + + def __sql__(self, ctx): + return ctx.sql(self.node) + + +class Negated(WrappedNode): + def __invert__(self): + return self.node + + def __sql__(self, ctx): + return ctx.literal('NOT ').sql(self.node) + + +class BitwiseMixin(object): + def __and__(self, other): + return self.bin_and(other) + + def __or__(self, other): + return self.bin_or(other) + + def __sub__(self, other): + return self.bin_and(other.bin_negated()) + + def __invert__(self): + return BitwiseNegated(self) + + +class BitwiseNegated(BitwiseMixin, WrappedNode): + def __invert__(self): + return self.node + + def __sql__(self, ctx): + if ctx.state.operations: + op_sql = ctx.state.operations.get(self.op, self.op) + else: + op_sql = self.op + return ctx.literal(op_sql).sql(self.node) + + +class Value(ColumnBase): + def __init__(self, value, converter=None, unpack=True): + self.value = value + self.converter = converter + self.multi = unpack and isinstance(self.value, multi_types) + if self.multi: + self.values = [] + for item in self.value: + if isinstance(item, Node): + self.values.append(item) + else: + self.values.append(Value(item, self.converter)) + + def __sql__(self, ctx): + if 
self.multi: + # For multi-part values (e.g. lists of IDs). + return ctx.sql(EnclosedNodeList(self.values)) + + return ctx.value(self.value, self.converter) + + +class ValueLiterals(WrappedNode): + def __sql__(self, ctx): + with ctx(value_literals=True): + return ctx.sql(self.node) + + +def AsIs(value): + return Value(value, unpack=False) + + +class Cast(WrappedNode): + def __init__(self, node, cast): + super(Cast, self).__init__(node) + self._cast = cast + self._coerce = False + + def __sql__(self, ctx): + return (ctx + .literal('CAST(') + .sql(self.node) + .literal(' AS %s)' % self._cast)) + + +class Ordering(WrappedNode): + def __init__(self, node, direction, collation=None, nulls=None): + super(Ordering, self).__init__(node) + self.direction = direction + self.collation = collation + self.nulls = nulls + if nulls and nulls.lower() not in ('first', 'last'): + raise ValueError('Ordering nulls= parameter must be "first" or ' + '"last", got: %s' % nulls) + + def collate(self, collation=None): + return Ordering(self.node, self.direction, collation) + + def _null_ordering_case(self, nulls): + if nulls.lower() == 'last': + ifnull, notnull = 1, 0 + elif nulls.lower() == 'first': + ifnull, notnull = 0, 1 + else: + raise ValueError('unsupported value for nulls= ordering.') + return Case(None, ((self.node.is_null(), ifnull),), notnull) + + def __sql__(self, ctx): + if self.nulls and not ctx.state.nulls_ordering: + ctx.sql(self._null_ordering_case(self.nulls)).literal(', ') + + ctx.sql(self.node).literal(' %s' % self.direction) + if self.collation: + ctx.literal(' COLLATE %s' % self.collation) + if self.nulls and ctx.state.nulls_ordering: + ctx.literal(' NULLS %s' % self.nulls) + return ctx + + +def Asc(node, collation=None, nulls=None): + return Ordering(node, 'ASC', collation, nulls) + + +def Desc(node, collation=None, nulls=None): + return Ordering(node, 'DESC', collation, nulls) + + +class Expression(ColumnBase): + def __init__(self, lhs, op, rhs, flat=False): + 
self.lhs = lhs + self.op = op + self.rhs = rhs + self.flat = flat + + def __sql__(self, ctx): + overrides = {'parentheses': not self.flat, 'in_expr': True} + + # First attempt to unwrap the node on the left-hand-side, so that we + # can get at the underlying Field if one is present. + node = raw_node = self.lhs + if isinstance(raw_node, WrappedNode): + node = raw_node.unwrap() + + # Set up the appropriate converter if we have a field on the left side. + if isinstance(node, Field) and raw_node._coerce: + overrides['converter'] = node.db_value + overrides['is_fk_expr'] = isinstance(node, ForeignKeyField) + else: + overrides['converter'] = None + + if ctx.state.operations: + op_sql = ctx.state.operations.get(self.op, self.op) + else: + op_sql = self.op + + with ctx(**overrides): + # Postgresql reports an error for IN/NOT IN (), so convert to + # the equivalent boolean expression. + op_in = self.op == OP.IN or self.op == OP.NOT_IN + if op_in and ctx.as_new().parse(self.rhs)[0] == '()': + return ctx.literal('0 = 1' if self.op == OP.IN else '1 = 1') + + return (ctx + .sql(self.lhs) + .literal(' %s ' % op_sql) + .sql(self.rhs)) + + +class StringExpression(Expression): + def __add__(self, rhs): + return self.concat(rhs) + def __radd__(self, lhs): + return StringExpression(lhs, OP.CONCAT, self) + + +class Entity(ColumnBase): + def __init__(self, *path): + self._path = [part.replace('"', '""') for part in path if part] + + def __getattr__(self, attr): + return Entity(*self._path + [attr]) + + def get_sort_key(self, ctx): + return tuple(self._path) + + def __hash__(self): + return hash((self.__class__.__name__, tuple(self._path))) + + def __sql__(self, ctx): + return ctx.literal(quote(self._path, ctx.state.quote or '""')) + + +class SQL(ColumnBase): + def __init__(self, sql, params=None): + self.sql = sql + self.params = params + + def __sql__(self, ctx): + ctx.literal(self.sql) + if self.params: + for param in self.params: + ctx.value(param, False, add_param=False) + return 
ctx + + +def Check(constraint, name=None): + check = SQL('CHECK (%s)' % constraint) + if not name: + return check + return NodeList((SQL('CONSTRAINT'), Entity(name), check)) + + +class Function(ColumnBase): + no_coerce_functions = set(('sum', 'count', 'avg', 'cast', 'array_agg')) + + def __init__(self, name, arguments, coerce=True, python_value=None): + self.name = name + self.arguments = arguments + self._filter = None + self._order_by = None + self._python_value = python_value + if name and name.lower() in self.no_coerce_functions: + self._coerce = False + else: + self._coerce = coerce + + def __getattr__(self, attr): + def decorator(*args, **kwargs): + return Function(attr, args, **kwargs) + return decorator + + @Node.copy + def filter(self, where=None): + self._filter = where + + @Node.copy + def order_by(self, *ordering): + self._order_by = ordering + + @Node.copy + def python_value(self, func=None): + self._python_value = func + + def over(self, partition_by=None, order_by=None, start=None, end=None, + frame_type=None, window=None, exclude=None): + if isinstance(partition_by, Window) and window is None: + window = partition_by + + if window is not None: + node = WindowAlias(window) + else: + node = Window(partition_by=partition_by, order_by=order_by, + start=start, end=end, frame_type=frame_type, + exclude=exclude, _inline=True) + return NodeList((self, SQL('OVER'), node)) + + def __sql__(self, ctx): + ctx.literal(self.name) + if not len(self.arguments): + ctx.literal('()') + else: + args = self.arguments + + # If this is an ordered aggregate, then we will modify the last + # argument to append the ORDER BY ... clause. We do this to avoid + # double-wrapping any expression args in parentheses, as NodeList + # has a special check (hack) in place to work around this. 
+ if self._order_by: + args = list(args) + args[-1] = NodeList((args[-1], SQL('ORDER BY'), + CommaNodeList(self._order_by))) + + with ctx(in_function=True, function_arg_count=len(self.arguments)): + ctx.sql(EnclosedNodeList([ + (arg if isinstance(arg, Node) else Value(arg, False)) + for arg in args])) + + if self._filter: + ctx.literal(' FILTER (WHERE ').sql(self._filter).literal(')') + return ctx + + +fn = Function(None, None) + + +class Window(Node): + # Frame start/end and frame exclusion. + CURRENT_ROW = SQL('CURRENT ROW') + GROUP = SQL('GROUP') + TIES = SQL('TIES') + NO_OTHERS = SQL('NO OTHERS') + + # Frame types. + GROUPS = 'GROUPS' + RANGE = 'RANGE' + ROWS = 'ROWS' + + def __init__(self, partition_by=None, order_by=None, start=None, end=None, + frame_type=None, extends=None, exclude=None, alias=None, + _inline=False): + super(Window, self).__init__() + if start is not None and not isinstance(start, SQL): + start = SQL(start) + if end is not None and not isinstance(end, SQL): + end = SQL(end) + + self.partition_by = ensure_tuple(partition_by) + self.order_by = ensure_tuple(order_by) + self.start = start + self.end = end + if self.start is None and self.end is not None: + raise ValueError('Cannot specify WINDOW end without start.') + self._alias = alias or 'w' + self._inline = _inline + self.frame_type = frame_type + self._extends = extends + self._exclude = exclude + + def alias(self, alias=None): + self._alias = alias or 'w' + return self + + @Node.copy + def as_range(self): + self.frame_type = Window.RANGE + + @Node.copy + def as_rows(self): + self.frame_type = Window.ROWS + + @Node.copy + def as_groups(self): + self.frame_type = Window.GROUPS + + @Node.copy + def extends(self, window=None): + self._extends = window + + @Node.copy + def exclude(self, frame_exclusion=None): + if isinstance(frame_exclusion, basestring): + frame_exclusion = SQL(frame_exclusion) + self._exclude = frame_exclusion + + @staticmethod + def following(value=None): + if value is None: 
+ return SQL('UNBOUNDED FOLLOWING') + return SQL('%d FOLLOWING' % value) + + @staticmethod + def preceding(value=None): + if value is None: + return SQL('UNBOUNDED PRECEDING') + return SQL('%d PRECEDING' % value) + + def __sql__(self, ctx): + if ctx.scope != SCOPE_SOURCE and not self._inline: + ctx.literal(self._alias) + ctx.literal(' AS ') + + with ctx(parentheses=True): + parts = [] + if self._extends is not None: + ext = self._extends + if isinstance(ext, Window): + ext = SQL(ext._alias) + elif isinstance(ext, basestring): + ext = SQL(ext) + parts.append(ext) + if self.partition_by: + parts.extend(( + SQL('PARTITION BY'), + CommaNodeList(self.partition_by))) + if self.order_by: + parts.extend(( + SQL('ORDER BY'), + CommaNodeList(self.order_by))) + if self.start is not None and self.end is not None: + frame = self.frame_type or 'ROWS' + parts.extend(( + SQL('%s BETWEEN' % frame), + self.start, + SQL('AND'), + self.end)) + elif self.start is not None: + parts.extend((SQL(self.frame_type or 'ROWS'), self.start)) + elif self.frame_type is not None: + parts.append(SQL('%s UNBOUNDED PRECEDING' % self.frame_type)) + if self._exclude is not None: + parts.extend((SQL('EXCLUDE'), self._exclude)) + ctx.sql(NodeList(parts)) + return ctx + + +class WindowAlias(Node): + def __init__(self, window): + self.window = window + + def alias(self, window_alias): + self.window._alias = window_alias + return self + + def __sql__(self, ctx): + return ctx.literal(self.window._alias or 'w') + + +class ForUpdate(Node): + def __init__(self, expr, of=None, nowait=None): + expr = 'FOR UPDATE' if expr is True else expr + if expr.lower().endswith('nowait'): + expr = expr[:-7] # Strip off the "nowait" bit. 
+ nowait = True + + self._expr = expr + if of is not None and not isinstance(of, (list, set, tuple)): + of = (of,) + self._of = of + self._nowait = nowait + + def __sql__(self, ctx): + ctx.literal(self._expr) + if self._of is not None: + ctx.literal(' OF ').sql(CommaNodeList(self._of)) + if self._nowait: + ctx.literal(' NOWAIT') + return ctx + + +def Case(predicate, expression_tuples, default=None): + clauses = [SQL('CASE')] + if predicate is not None: + clauses.append(predicate) + for expr, value in expression_tuples: + clauses.extend((SQL('WHEN'), expr, SQL('THEN'), value)) + if default is not None: + clauses.extend((SQL('ELSE'), default)) + clauses.append(SQL('END')) + return NodeList(clauses) + + +class NodeList(ColumnBase): + def __init__(self, nodes, glue=' ', parens=False): + self.nodes = nodes + self.glue = glue + self.parens = parens + if parens and len(self.nodes) == 1 and \ + isinstance(self.nodes[0], Expression) and \ + not self.nodes[0].flat: + # Hack to avoid double-parentheses. 
+ self.nodes = (self.nodes[0].clone(),) + self.nodes[0].flat = True + + def __sql__(self, ctx): + n_nodes = len(self.nodes) + if n_nodes == 0: + return ctx.literal('()') if self.parens else ctx + with ctx(parentheses=self.parens): + for i in range(n_nodes - 1): + ctx.sql(self.nodes[i]) + ctx.literal(self.glue) + ctx.sql(self.nodes[n_nodes - 1]) + return ctx + + +def CommaNodeList(nodes): + return NodeList(nodes, ', ') + + +def EnclosedNodeList(nodes): + return NodeList(nodes, ', ', True) + + +class _Namespace(Node): + __slots__ = ('_name',) + def __init__(self, name): + self._name = name + def __getattr__(self, attr): + return NamespaceAttribute(self, attr) + __getitem__ = __getattr__ + +class NamespaceAttribute(ColumnBase): + def __init__(self, namespace, attribute): + self._namespace = namespace + self._attribute = attribute + + def __sql__(self, ctx): + return (ctx + .literal(self._namespace._name + '.') + .sql(Entity(self._attribute))) + +EXCLUDED = _Namespace('EXCLUDED') + + +class DQ(ColumnBase): + def __init__(self, **query): + super(DQ, self).__init__() + self.query = query + self._negated = False + + @Node.copy + def __invert__(self): + self._negated = not self._negated + + def clone(self): + node = DQ(**self.query) + node._negated = self._negated + return node + +#: Represent a row tuple. +Tuple = lambda *a: EnclosedNodeList(a) + + +class QualifiedNames(WrappedNode): + def __sql__(self, ctx): + with ctx.scope_column(): + return ctx.sql(self.node) + + +def qualify_names(node): + # Search a node heirarchy to ensure that any column-like objects are + # referenced using fully-qualified names. 
+ if isinstance(node, Expression): + return node.__class__(qualify_names(node.lhs), node.op, + qualify_names(node.rhs), node.flat) + elif isinstance(node, ColumnBase): + return QualifiedNames(node) + return node + + +class OnConflict(Node): + def __init__(self, action=None, update=None, preserve=None, where=None, + conflict_target=None, conflict_where=None, + conflict_constraint=None): + self._action = action + self._update = update + self._preserve = ensure_tuple(preserve) + self._where = where + if conflict_target is not None and conflict_constraint is not None: + raise ValueError('only one of "conflict_target" and ' + '"conflict_constraint" may be specified.') + self._conflict_target = ensure_tuple(conflict_target) + self._conflict_where = conflict_where + self._conflict_constraint = conflict_constraint + + def get_conflict_statement(self, ctx, query): + return ctx.state.conflict_statement(self, query) + + def get_conflict_update(self, ctx, query): + return ctx.state.conflict_update(self, query) + + @Node.copy + def preserve(self, *columns): + self._preserve = columns + + @Node.copy + def update(self, _data=None, **kwargs): + if _data and kwargs and not isinstance(_data, dict): + raise ValueError('Cannot mix data with keyword arguments in the ' + 'OnConflict update method.') + _data = _data or {} + if kwargs: + _data.update(kwargs) + self._update = _data + + @Node.copy + def where(self, *expressions): + if self._where is not None: + expressions = (self._where,) + expressions + self._where = reduce(operator.and_, expressions) + + @Node.copy + def conflict_target(self, *constraints): + self._conflict_constraint = None + self._conflict_target = constraints + + @Node.copy + def conflict_where(self, *expressions): + if self._conflict_where is not None: + expressions = (self._conflict_where,) + expressions + self._conflict_where = reduce(operator.and_, expressions) + + @Node.copy + def conflict_constraint(self, constraint): + self._conflict_constraint = constraint + 
self._conflict_target = None + + +def database_required(method): + @wraps(method) + def inner(self, database=None, *args, **kwargs): + database = self._database if database is None else database + if not database: + raise InterfaceError('Query must be bound to a database in order ' + 'to call "%s".' % method.__name__) + return method(self, database, *args, **kwargs) + return inner + +# BASE QUERY INTERFACE. + +class BaseQuery(Node): + default_row_type = ROW.DICT + + def __init__(self, _database=None, **kwargs): + self._database = _database + self._cursor_wrapper = None + self._row_type = None + self._constructor = None + super(BaseQuery, self).__init__(**kwargs) + + def bind(self, database=None): + self._database = database + return self + + def clone(self): + query = super(BaseQuery, self).clone() + query._cursor_wrapper = None + return query + + @Node.copy + def dicts(self, as_dict=True): + self._row_type = ROW.DICT if as_dict else None + return self + + @Node.copy + def tuples(self, as_tuple=True): + self._row_type = ROW.TUPLE if as_tuple else None + return self + + @Node.copy + def namedtuples(self, as_namedtuple=True): + self._row_type = ROW.NAMED_TUPLE if as_namedtuple else None + return self + + @Node.copy + def objects(self, constructor=None): + self._row_type = ROW.CONSTRUCTOR if constructor else None + self._constructor = constructor + return self + + def _get_cursor_wrapper(self, cursor): + row_type = self._row_type or self.default_row_type + + if row_type == ROW.DICT: + return DictCursorWrapper(cursor) + elif row_type == ROW.TUPLE: + return CursorWrapper(cursor) + elif row_type == ROW.NAMED_TUPLE: + return NamedTupleCursorWrapper(cursor) + elif row_type == ROW.CONSTRUCTOR: + return ObjectCursorWrapper(cursor, self._constructor) + else: + raise ValueError('Unrecognized row type: "%s".' 
% row_type) + + def __sql__(self, ctx): + raise NotImplementedError + + def sql(self): + if self._database: + context = self._database.get_sql_context() + else: + context = Context() + return context.parse(self) + + @database_required + def execute(self, database): + return self._execute(database) + + def _execute(self, database): + raise NotImplementedError + + def iterator(self, database=None): + return iter(self.execute(database).iterator()) + + def _ensure_execution(self): + if not self._cursor_wrapper: + if not self._database: + raise ValueError('Query has not been executed.') + self.execute() + + def __iter__(self): + self._ensure_execution() + return iter(self._cursor_wrapper) + + def __getitem__(self, value): + self._ensure_execution() + if isinstance(value, slice): + index = value.stop + else: + index = value + if index is not None: + index = index + 1 if index >= 0 else 0 + self._cursor_wrapper.fill_cache(index) + return self._cursor_wrapper.row_cache[value] + + def __len__(self): + self._ensure_execution() + return len(self._cursor_wrapper) + + def __str__(self): + return query_to_string(self) + + +class RawQuery(BaseQuery): + def __init__(self, sql=None, params=None, **kwargs): + super(RawQuery, self).__init__(**kwargs) + self._sql = sql + self._params = params + + def __sql__(self, ctx): + ctx.literal(self._sql) + if self._params: + for param in self._params: + ctx.value(param, add_param=False) + return ctx + + def _execute(self, database): + if self._cursor_wrapper is None: + cursor = database.execute(self) + self._cursor_wrapper = self._get_cursor_wrapper(cursor) + return self._cursor_wrapper + + +class Query(BaseQuery): + def __init__(self, where=None, order_by=None, limit=None, offset=None, + **kwargs): + super(Query, self).__init__(**kwargs) + self._where = where + self._order_by = order_by + self._limit = limit + self._offset = offset + + self._cte_list = None + + @Node.copy + def with_cte(self, *cte_list): + self._cte_list = cte_list + + 
@Node.copy + def where(self, *expressions): + if self._where is not None: + expressions = (self._where,) + expressions + self._where = reduce(operator.and_, expressions) + + @Node.copy + def orwhere(self, *expressions): + if self._where is not None: + expressions = (self._where,) + expressions + self._where = reduce(operator.or_, expressions) + + @Node.copy + def order_by(self, *values): + self._order_by = values + + @Node.copy + def order_by_extend(self, *values): + self._order_by = ((self._order_by or ()) + values) or None + + @Node.copy + def limit(self, value=None): + self._limit = value + + @Node.copy + def offset(self, value=None): + self._offset = value + + @Node.copy + def paginate(self, page, paginate_by=20): + if page > 0: + page -= 1 + self._limit = paginate_by + self._offset = page * paginate_by + + def _apply_ordering(self, ctx): + if self._order_by: + (ctx + .literal(' ORDER BY ') + .sql(CommaNodeList(self._order_by))) + if self._limit is not None or (self._offset is not None and + ctx.state.limit_max): + limit = ctx.state.limit_max if self._limit is None else self._limit + ctx.literal(' LIMIT ').sql(limit) + if self._offset is not None: + ctx.literal(' OFFSET ').sql(self._offset) + return ctx + + def __sql__(self, ctx): + if self._cte_list: + # The CTE scope is only used at the very beginning of the query, + # when we are describing the various CTEs we will be using. + recursive = any(cte._recursive for cte in self._cte_list) + + # Explicitly disable the "subquery" flag here, so as to avoid + # unnecessary parentheses around subsequent selects. 
+ with ctx.scope_cte(subquery=False): + (ctx + .literal('WITH RECURSIVE ' if recursive else 'WITH ') + .sql(CommaNodeList(self._cte_list)) + .literal(' ')) + return ctx + + +def __compound_select__(operation, inverted=False): + @__bind_database__ + def method(self, other): + if inverted: + self, other = other, self + return CompoundSelectQuery(self, operation, other) + return method + + +class SelectQuery(Query): + union_all = __add__ = __compound_select__('UNION ALL') + union = __or__ = __compound_select__('UNION') + intersect = __and__ = __compound_select__('INTERSECT') + except_ = __sub__ = __compound_select__('EXCEPT') + __radd__ = __compound_select__('UNION ALL', inverted=True) + __ror__ = __compound_select__('UNION', inverted=True) + __rand__ = __compound_select__('INTERSECT', inverted=True) + __rsub__ = __compound_select__('EXCEPT', inverted=True) + + def select_from(self, *columns): + if not columns: + raise ValueError('select_from() must specify one or more columns.') + + query = (Select((self,), columns) + .bind(self._database)) + if getattr(self, 'model', None) is not None: + # Bind to the sub-select's model type, if defined. 
+ query = query.objects(self.model) + return query + + +class SelectBase(_HashableSource, Source, SelectQuery): + def _get_hash(self): + return hash((self.__class__, self._alias or id(self))) + + def _execute(self, database): + if self._cursor_wrapper is None: + cursor = database.execute(self) + self._cursor_wrapper = self._get_cursor_wrapper(cursor) + return self._cursor_wrapper + + @database_required + def peek(self, database, n=1): + rows = self.execute(database)[:n] + if rows: + return rows[0] if n == 1 else rows + + @database_required + def first(self, database, n=1): + if self._limit != n: + self._limit = n + self._cursor_wrapper = None + return self.peek(database, n=n) + + @database_required + def scalar(self, database, as_tuple=False, as_dict=False): + if as_dict: + return self.dicts().peek(database) + row = self.tuples().peek(database) + return row[0] if row and not as_tuple else row + + @database_required + def scalars(self, database): + for row in self.tuples().execute(database): + yield row[0] + + @database_required + def count(self, database, clear_limit=False): + clone = self.order_by().alias('_wrapped') + if clear_limit: + clone._limit = clone._offset = None + try: + if clone._having is None and clone._group_by is None and \ + clone._windows is None and clone._distinct is None and \ + clone._simple_distinct is not True: + clone = clone.select(SQL('1')) + except AttributeError: + pass + return Select([clone], [fn.COUNT(SQL('1'))]).scalar(database) + + @database_required + def exists(self, database): + clone = self.columns(SQL('1')) + clone._limit = 1 + clone._offset = None + return bool(clone.scalar()) + + @database_required + def get(self, database): + self._cursor_wrapper = None + try: + return self.execute(database)[0] + except IndexError: + pass + + +# QUERY IMPLEMENTATIONS. 
+ + +class CompoundSelectQuery(SelectBase): + def __init__(self, lhs, op, rhs): + super(CompoundSelectQuery, self).__init__() + self.lhs = lhs + self.op = op + self.rhs = rhs + + @property + def _returning(self): + return self.lhs._returning + + @database_required + def exists(self, database): + query = Select((self.limit(1),), (SQL('1'),)).bind(database) + return bool(query.scalar()) + + def _get_query_key(self): + return (self.lhs.get_query_key(), self.rhs.get_query_key()) + + def _wrap_parens(self, ctx, subq): + csq_setting = ctx.state.compound_select_parentheses + + if not csq_setting or csq_setting == CSQ_PARENTHESES_NEVER: + return False + elif csq_setting == CSQ_PARENTHESES_ALWAYS: + return True + elif csq_setting == CSQ_PARENTHESES_UNNESTED: + if ctx.state.in_expr or ctx.state.in_function: + # If this compound select query is being used inside an + # expression, e.g., an IN or EXISTS(). + return False + + # If the query on the left or right is itself a compound select + # query, then we do not apply parentheses. However, if it is a + # regular SELECT query, we will apply parentheses. + return not isinstance(subq, CompoundSelectQuery) + + def __sql__(self, ctx): + if ctx.scope == SCOPE_COLUMN: + return self.apply_column(ctx) + + # Call parent method to handle any CTEs. + super(CompoundSelectQuery, self).__sql__(ctx) + + outer_parens = ctx.subquery or (ctx.scope == SCOPE_SOURCE) + with ctx(parentheses=outer_parens): + # Should the left-hand query be wrapped in parentheses? + lhs_parens = self._wrap_parens(ctx, self.lhs) + with ctx.scope_normal(parentheses=lhs_parens, subquery=False): + ctx.sql(self.lhs) + ctx.literal(' %s ' % self.op) + with ctx.push_alias(): + # Should the right-hand query be wrapped in parentheses? + rhs_parens = self._wrap_parens(ctx, self.rhs) + with ctx.scope_normal(parentheses=rhs_parens, subquery=False): + ctx.sql(self.rhs) + + # Apply ORDER BY, LIMIT, OFFSET. We use the "values" scope so that + # entity names are not fully-qualified. 
This is a bit of a hack, as + # we're relying on the logic in Column.__sql__() to not fully + # qualify column names. + with ctx.scope_values(): + self._apply_ordering(ctx) + + return self.apply_alias(ctx) + + +class Select(SelectBase): + def __init__(self, from_list=None, columns=None, group_by=None, + having=None, distinct=None, windows=None, for_update=None, + for_update_of=None, nowait=None, lateral=None, **kwargs): + super(Select, self).__init__(**kwargs) + self._from_list = (list(from_list) if isinstance(from_list, tuple) + else from_list) or [] + self._returning = columns + self._group_by = group_by + self._having = having + self._windows = None + self._for_update = for_update # XXX: consider reorganizing. + self._for_update_of = for_update_of + self._for_update_nowait = nowait + self._lateral = lateral + + self._distinct = self._simple_distinct = None + if distinct: + if isinstance(distinct, bool): + self._simple_distinct = distinct + else: + self._distinct = distinct + + self._cursor_wrapper = None + + def clone(self): + clone = super(Select, self).clone() + if clone._from_list: + clone._from_list = list(clone._from_list) + return clone + + @Node.copy + def columns(self, *columns, **kwargs): + self._returning = columns + select = columns + + @Node.copy + def select_extend(self, *columns): + self._returning = tuple(self._returning) + columns + + @property + def selected_columns(self): + return self._returning + @selected_columns.setter + def selected_columns(self, value): + self._returning = value + + @Node.copy + def from_(self, *sources): + self._from_list = list(sources) + + @Node.copy + def join(self, dest, join_type=JOIN.INNER, on=None): + if not self._from_list: + raise ValueError('No sources to join on.') + item = self._from_list.pop() + self._from_list.append(Join(item, dest, join_type, on)) + + def left_outer_join(self, dest, on=None): + return self.join(dest, JOIN.LEFT_OUTER, on) + + @Node.copy + def group_by(self, *columns): + grouping = [] + for 
column in columns: + if isinstance(column, Table): + if not column._columns: + raise ValueError('Cannot pass a table to group_by() that ' + 'does not have columns explicitly ' + 'declared.') + grouping.extend([getattr(column, col_name) + for col_name in column._columns]) + else: + grouping.append(column) + self._group_by = grouping + + def group_by_extend(self, *values): + """@Node.copy used from group_by() call""" + group_by = tuple(self._group_by or ()) + values + return self.group_by(*group_by) + + @Node.copy + def having(self, *expressions): + if self._having is not None: + expressions = (self._having,) + expressions + self._having = reduce(operator.and_, expressions) + + @Node.copy + def distinct(self, *columns): + if len(columns) == 1 and (columns[0] is True or columns[0] is False): + self._simple_distinct = columns[0] + else: + self._simple_distinct = False + self._distinct = columns + + @Node.copy + def window(self, *windows): + self._windows = windows if windows else None + + @Node.copy + def for_update(self, for_update=True, of=None, nowait=None): + if not for_update and (of is not None or nowait): + for_update = True + self._for_update = for_update + self._for_update_of = of + self._for_update_nowait = nowait + + @Node.copy + def lateral(self, lateral=True): + self._lateral = lateral + + def _get_query_key(self): + return self._alias + + def __sql_selection__(self, ctx, is_subquery=False): + return ctx.sql(CommaNodeList(self._returning)) + + def __sql__(self, ctx): + if ctx.scope == SCOPE_COLUMN: + return self.apply_column(ctx) + + if self._lateral and ctx.scope == SCOPE_SOURCE: + ctx.literal('LATERAL ') + + is_subquery = ctx.subquery + state = { + 'converter': None, + 'in_function': False, + 'parentheses': is_subquery or (ctx.scope == SCOPE_SOURCE), + 'subquery': True, + } + if ctx.state.in_function and ctx.state.function_arg_count == 1: + state['parentheses'] = False + + with ctx.scope_normal(**state): + # Defer calling parent SQL until here. 
This ensures that any CTEs + # for this query will be properly nested if this query is a + # sub-select or is used in an expression. See GH#1809 for example. + super(Select, self).__sql__(ctx) + + ctx.literal('SELECT ') + if self._simple_distinct or self._distinct is not None: + ctx.literal('DISTINCT ') + if self._distinct: + (ctx + .literal('ON ') + .sql(EnclosedNodeList(self._distinct)) + .literal(' ')) + + with ctx.scope_source(): + ctx = self.__sql_selection__(ctx, is_subquery) + + if self._from_list: + with ctx.scope_source(parentheses=False): + ctx.literal(' FROM ').sql(CommaNodeList(self._from_list)) + + if self._where is not None: + ctx.literal(' WHERE ').sql(self._where) + + if self._group_by: + ctx.literal(' GROUP BY ').sql(CommaNodeList(self._group_by)) + + if self._having is not None: + ctx.literal(' HAVING ').sql(self._having) + + if self._windows is not None: + ctx.literal(' WINDOW ') + ctx.sql(CommaNodeList(self._windows)) + + # Apply ORDER BY, LIMIT, OFFSET. + self._apply_ordering(ctx) + + if self._for_update: + if not ctx.state.for_update: + raise ValueError('FOR UPDATE specified but not supported ' + 'by database.') + ctx.literal(' ') + ctx.sql(ForUpdate(self._for_update, self._for_update_of, + self._for_update_nowait)) + + # If the subquery is inside a function -or- we are evaluating a + # subquery on either side of an expression w/o an explicit alias, do + # not generate an alias + AS clause. 
+ if ctx.state.in_function or (ctx.state.in_expr and + self._alias is None): + return ctx + + return self.apply_alias(ctx) + + +class _WriteQuery(Query): + def __init__(self, table, returning=None, **kwargs): + self.table = table + self._returning = returning + self._return_cursor = True if returning else False + super(_WriteQuery, self).__init__(**kwargs) + + def cte(self, name, recursive=False, columns=None, materialized=None): + return CTE(name, self, recursive=recursive, columns=columns, + materialized=materialized) + + @Node.copy + def returning(self, *returning): + self._returning = returning + self._return_cursor = True if returning else False + + def apply_returning(self, ctx): + if self._returning: + with ctx.scope_source(): + ctx.literal(' RETURNING ').sql(CommaNodeList(self._returning)) + return ctx + + def _execute(self, database): + if self._returning: + cursor = self.execute_returning(database) + else: + cursor = database.execute(self) + return self.handle_result(database, cursor) + + def execute_returning(self, database): + if self._cursor_wrapper is None: + cursor = database.execute(self) + self._cursor_wrapper = self._get_cursor_wrapper(cursor) + return self._cursor_wrapper + + def handle_result(self, database, cursor): + if self._return_cursor: + return cursor + return database.rows_affected(cursor) + + def _set_table_alias(self, ctx): + ctx.alias_manager[self.table] = self.table.__name__ + + def __sql__(self, ctx): + super(_WriteQuery, self).__sql__(ctx) + # We explicitly set the table alias to the table's name, which ensures + # that if a sub-select references a column on the outer table, we won't + # assign it a new alias (e.g. t2) but will refer to it as table.column. 
+ self._set_table_alias(ctx) + return ctx + + +class Update(_WriteQuery): + def __init__(self, table, update=None, **kwargs): + super(Update, self).__init__(table, **kwargs) + self._update = update + self._from = None + + @Node.copy + def from_(self, *sources): + self._from = sources + + def __sql__(self, ctx): + super(Update, self).__sql__(ctx) + + with ctx.scope_values(subquery=True): + ctx.literal('UPDATE ') + + expressions = [] + for k, v in sorted(self._update.items(), key=ctx.column_sort_key): + if not isinstance(v, Node): + if isinstance(k, Field): + v = k.to_value(v) + else: + v = Value(v, unpack=False) + elif isinstance(v, Model) and isinstance(k, ForeignKeyField): + # NB: we want to ensure that when passed a model instance + # in the context of a foreign-key, we apply the fk-specific + # adaptation of the model. + v = k.to_value(v) + + if not isinstance(v, Value): + v = qualify_names(v) + + expressions.append(NodeList((k, SQL('='), v))) + + (ctx + .sql(self.table) + .literal(' SET ') + .sql(CommaNodeList(expressions))) + + if self._from: + with ctx.scope_source(parentheses=False): + ctx.literal(' FROM ').sql(CommaNodeList(self._from)) + + if self._where: + with ctx.scope_normal(): + ctx.literal(' WHERE ').sql(self._where) + self._apply_ordering(ctx) + return self.apply_returning(ctx) + + +class Insert(_WriteQuery): + SIMPLE = 0 + QUERY = 1 + MULTI = 2 + class DefaultValuesException(Exception): pass + + def __init__(self, table, insert=None, columns=None, on_conflict=None, + **kwargs): + super(Insert, self).__init__(table, **kwargs) + self._insert = insert + self._columns = columns + self._on_conflict = on_conflict + self._query_type = None + self._as_rowcount = False + + def where(self, *expressions): + raise NotImplementedError('INSERT queries cannot have a WHERE clause.') + + @Node.copy + def as_rowcount(self, _as_rowcount=True): + self._as_rowcount = _as_rowcount + + @Node.copy + def on_conflict_ignore(self, ignore=True): + self._on_conflict = 
OnConflict('IGNORE') if ignore else None + + @Node.copy + def on_conflict_replace(self, replace=True): + self._on_conflict = OnConflict('REPLACE') if replace else None + + @Node.copy + def on_conflict(self, *args, **kwargs): + self._on_conflict = (OnConflict(*args, **kwargs) if (args or kwargs) + else None) + + def _simple_insert(self, ctx): + if not self._insert: + raise self.DefaultValuesException('Error: no data to insert.') + return self._generate_insert((self._insert,), ctx) + + def get_default_data(self): + return {} + + def get_default_columns(self): + if self.table._columns: + return [getattr(self.table, col) for col in self.table._columns + if col != self.table._primary_key] + + def _generate_insert(self, insert, ctx): + rows_iter = iter(insert) + columns = self._columns + + # Load and organize column defaults (if provided). + defaults = self.get_default_data() + + # First figure out what columns are being inserted (if they weren't + # specified explicitly). Resulting columns are normalized and ordered. + if not columns: + try: + row = next(rows_iter) + except StopIteration: + raise self.DefaultValuesException('Error: no rows to insert.') + + if not isinstance(row, Mapping): + columns = self.get_default_columns() + if columns is None: + raise ValueError('Bulk insert must specify columns.') + else: + # Infer column names from the dict of data being inserted. + accum = [] + for column in row: + if isinstance(column, basestring): + column = getattr(self.table, column) + accum.append(column) + + # Add any columns present in the default data that are not + # accounted for by the dictionary of row data. 
+ column_set = set(accum) + for col in (set(defaults) - column_set): + accum.append(col) + + columns = sorted(accum, key=lambda obj: obj.get_sort_key(ctx)) + rows_iter = itertools.chain(iter((row,)), rows_iter) + else: + clean_columns = [] + seen = set() + for column in columns: + if isinstance(column, basestring): + column_obj = getattr(self.table, column) + else: + column_obj = column + clean_columns.append(column_obj) + seen.add(column_obj) + + columns = clean_columns + for col in sorted(defaults, key=lambda obj: obj.get_sort_key(ctx)): + if col not in seen: + columns.append(col) + + fk_fields = set() + nullable_columns = set() + value_lookups = {} + for column in columns: + lookups = [column, column.name] + if isinstance(column, Field): + if column.name != column.column_name: + lookups.append(column.column_name) + if column.null: + nullable_columns.add(column) + if isinstance(column, ForeignKeyField): + fk_fields.add(column) + value_lookups[column] = lookups + + ctx.sql(EnclosedNodeList(columns)).literal(' VALUES ') + columns_converters = [ + (column, column.db_value if isinstance(column, Field) else None) + for column in columns] + + all_values = [] + for row in rows_iter: + values = [] + is_dict = isinstance(row, Mapping) + for i, (column, converter) in enumerate(columns_converters): + try: + if is_dict: + # The logic is a bit convoluted, but in order to be + # flexible in what we accept (dict keyed by + # column/field, field name, or underlying column name), + # we try accessing the row data dict using each + # possible key. If no match is found, throw an error. + for lookup in value_lookups[column]: + try: + val = row[lookup] + except KeyError: pass + else: break + else: + raise KeyError + else: + val = row[i] + except (KeyError, IndexError): + if column in defaults: + val = defaults[column] + if callable_(val): + val = val() + elif column in nullable_columns: + val = None + else: + raise ValueError('Missing value for %s.' 
% column.name) + + if not isinstance(val, Node) or (isinstance(val, Model) and + column in fk_fields): + val = Value(val, converter=converter, unpack=False) + values.append(val) + + all_values.append(EnclosedNodeList(values)) + + if not all_values: + raise self.DefaultValuesException('Error: no data to insert.') + + with ctx.scope_values(subquery=True): + return ctx.sql(CommaNodeList(all_values)) + + def _query_insert(self, ctx): + return (ctx + .sql(EnclosedNodeList(self._columns)) + .literal(' ') + .sql(self._insert)) + + def _default_values(self, ctx): + if not self._database: + return ctx.literal('DEFAULT VALUES') + return self._database.default_values_insert(ctx) + + def __sql__(self, ctx): + super(Insert, self).__sql__(ctx) + with ctx.scope_values(): + stmt = None + if self._on_conflict is not None: + stmt = self._on_conflict.get_conflict_statement(ctx, self) + + (ctx + .sql(stmt or SQL('INSERT')) + .literal(' INTO ') + .sql(self.table) + .literal(' ')) + + if isinstance(self._insert, Mapping) and not self._columns: + try: + self._simple_insert(ctx) + except self.DefaultValuesException: + self._default_values(ctx) + self._query_type = Insert.SIMPLE + elif isinstance(self._insert, (SelectQuery, SQL)): + self._query_insert(ctx) + self._query_type = Insert.QUERY + else: + self._generate_insert(self._insert, ctx) + self._query_type = Insert.MULTI + + if self._on_conflict is not None: + update = self._on_conflict.get_conflict_update(ctx, self) + if update is not None: + ctx.literal(' ').sql(update) + + return self.apply_returning(ctx) + + def _execute(self, database): + if self._returning is None and database.returning_clause \ + and self.table._primary_key: + self._returning = (self.table._primary_key,) + try: + return super(Insert, self)._execute(database) + except self.DefaultValuesException: + pass + + def handle_result(self, database, cursor): + if self._return_cursor: + return cursor + if self._as_rowcount: + return database.rows_affected(cursor) + return 
database.last_insert_id(cursor, self._query_type) + + +class Delete(_WriteQuery): + def __sql__(self, ctx): + super(Delete, self).__sql__(ctx) + + with ctx.scope_values(subquery=True): + ctx.literal('DELETE FROM ').sql(self.table) + if self._where is not None: + with ctx.scope_normal(): + ctx.literal(' WHERE ').sql(self._where) + + self._apply_ordering(ctx) + return self.apply_returning(ctx) + + +class Index(Node): + def __init__(self, name, table, expressions, unique=False, safe=False, + where=None, using=None): + self._name = name + self._table = Entity(table) if not isinstance(table, Table) else table + self._expressions = expressions + self._where = where + self._unique = unique + self._safe = safe + self._using = using + + @Node.copy + def safe(self, _safe=True): + self._safe = _safe + + @Node.copy + def where(self, *expressions): + if self._where is not None: + expressions = (self._where,) + expressions + self._where = reduce(operator.and_, expressions) + + @Node.copy + def using(self, _using=None): + self._using = _using + + def __sql__(self, ctx): + statement = 'CREATE UNIQUE INDEX ' if self._unique else 'CREATE INDEX ' + with ctx.scope_values(subquery=True): + ctx.literal(statement) + if self._safe: + ctx.literal('IF NOT EXISTS ') + + # Sqlite uses CREATE INDEX . ON , whereas most + # others use: CREATE INDEX ON .
. + if ctx.state.index_schema_prefix and \ + isinstance(self._table, Table) and self._table._schema: + index_name = Entity(self._table._schema, self._name) + table_name = Entity(self._table.__name__) + else: + index_name = Entity(self._name) + table_name = self._table + + ctx.sql(index_name) + if self._using is not None and \ + ctx.state.index_using_precedes_table: + ctx.literal(' USING %s' % self._using) # MySQL style. + + (ctx + .literal(' ON ') + .sql(table_name) + .literal(' ')) + + if self._using is not None and not \ + ctx.state.index_using_precedes_table: + ctx.literal('USING %s ' % self._using) # Postgres/default. + + ctx.sql(EnclosedNodeList([ + SQL(expr) if isinstance(expr, basestring) else expr + for expr in self._expressions])) + if self._where is not None: + ctx.literal(' WHERE ').sql(self._where) + + return ctx + + +class ModelIndex(Index): + def __init__(self, model, fields, unique=False, safe=True, where=None, + using=None, name=None): + self._model = model + if name is None: + name = self._generate_name_from_fields(model, fields) + if using is None: + for field in fields: + if isinstance(field, Field) and hasattr(field, 'index_type'): + using = field.index_type + super(ModelIndex, self).__init__( + name=name, + table=model._meta.table, + expressions=fields, + unique=unique, + safe=safe, + where=where, + using=using) + + def _generate_name_from_fields(self, model, fields): + accum = [] + for field in fields: + if isinstance(field, basestring): + accum.append(field.split()[0]) + else: + if isinstance(field, Node) and not isinstance(field, Field): + field = field.unwrap() + if isinstance(field, Field): + accum.append(field.column_name) + + if not accum: + raise ValueError('Unable to generate a name for the index, please ' + 'explicitly specify a name.') + + clean_field_names = re.sub(r'[^\w]+', '', '_'.join(accum)) + meta = model._meta + prefix = meta.name if meta.legacy_table_names else meta.table_name + return 
_truncate_constraint_name('_'.join((prefix, clean_field_names))) + + +def _truncate_constraint_name(constraint, maxlen=64): + if len(constraint) > maxlen: + name_hash = hashlib.md5(constraint.encode('utf-8')).hexdigest() + constraint = '%s_%s' % (constraint[:(maxlen - 8)], name_hash[:7]) + return constraint + + +# DB-API 2.0 EXCEPTIONS. + + +class PeeweeException(Exception): + def __init__(self, *args): + if args and isinstance(args[0], Exception): + self.orig, args = args[0], args[1:] + super(PeeweeException, self).__init__(*args) +class ImproperlyConfigured(PeeweeException): pass +class DatabaseError(PeeweeException): pass +class DataError(DatabaseError): pass +class IntegrityError(DatabaseError): pass +class InterfaceError(PeeweeException): pass +class InternalError(DatabaseError): pass +class NotSupportedError(DatabaseError): pass +class OperationalError(DatabaseError): pass +class ProgrammingError(DatabaseError): pass + + +class ExceptionWrapper(object): + __slots__ = ('exceptions',) + def __init__(self, exceptions): + self.exceptions = exceptions + def __enter__(self): pass + def __exit__(self, exc_type, exc_value, traceback): + if exc_type is None: + return + # psycopg2.8 shits out a million cute error types. Try to catch em all. 
+ if pg_errors is not None and exc_type.__name__ not in self.exceptions \ + and issubclass(exc_type, pg_errors.Error): + exc_type = exc_type.__bases__[0] + if exc_type.__name__ in self.exceptions: + new_type = self.exceptions[exc_type.__name__] + exc_args = exc_value.args + reraise(new_type, new_type(exc_value, *exc_args), traceback) + + +EXCEPTIONS = { + 'ConstraintError': IntegrityError, + 'DatabaseError': DatabaseError, + 'DataError': DataError, + 'IntegrityError': IntegrityError, + 'InterfaceError': InterfaceError, + 'InternalError': InternalError, + 'NotSupportedError': NotSupportedError, + 'OperationalError': OperationalError, + 'ProgrammingError': ProgrammingError, + 'TransactionRollbackError': OperationalError} + +__exception_wrapper__ = ExceptionWrapper(EXCEPTIONS) + + +# DATABASE INTERFACE AND CONNECTION MANAGEMENT. + + +IndexMetadata = collections.namedtuple( + 'IndexMetadata', + ('name', 'sql', 'columns', 'unique', 'table')) +ColumnMetadata = collections.namedtuple( + 'ColumnMetadata', + ('name', 'data_type', 'null', 'primary_key', 'table', 'default')) +ForeignKeyMetadata = collections.namedtuple( + 'ForeignKeyMetadata', + ('column', 'dest_table', 'dest_column', 'table')) +ViewMetadata = collections.namedtuple('ViewMetadata', ('name', 'sql')) + + +class _ConnectionState(object): + def __init__(self, **kwargs): + super(_ConnectionState, self).__init__(**kwargs) + self.reset() + + def reset(self): + self.closed = True + self.conn = None + self.ctx = [] + self.transactions = [] + + def set_connection(self, conn): + self.conn = conn + self.closed = False + self.ctx = [] + self.transactions = [] + + +class _ConnectionLocal(_ConnectionState, threading.local): pass +class _NoopLock(object): + __slots__ = () + def __enter__(self): return self + def __exit__(self, exc_type, exc_val, exc_tb): pass + + +class ConnectionContext(object): + __slots__ = ('db',) + def __init__(self, db): self.db = db + def __enter__(self): + if self.db.is_closed(): + self.db.connect() 
+ def __exit__(self, exc_type, exc_val, exc_tb): self.db.close() + def __call__(self, fn): + @wraps(fn) + def inner(*args, **kwargs): + with ConnectionContext(self.db): + return fn(*args, **kwargs) + return inner + + +class Database(_callable_context_manager): + context_class = Context + field_types = {} + operations = {} + param = '?' + quote = '""' + server_version = None + + # Feature toggles. + compound_select_parentheses = CSQ_PARENTHESES_NEVER + for_update = False + index_schema_prefix = False + index_using_precedes_table = False + limit_max = None + nulls_ordering = False + returning_clause = False + safe_create_index = True + safe_drop_index = True + sequences = False + truncate_table = True + + def __init__(self, database, thread_safe=True, autorollback=False, + field_types=None, operations=None, autocommit=None, + autoconnect=True, **kwargs): + self._field_types = merge_dict(FIELD, self.field_types) + self._operations = merge_dict(OP, self.operations) + if field_types: + self._field_types.update(field_types) + if operations: + self._operations.update(operations) + + self.autoconnect = autoconnect + self.thread_safe = thread_safe + if thread_safe: + self._state = _ConnectionLocal() + self._lock = threading.RLock() + else: + self._state = _ConnectionState() + self._lock = _NoopLock() + + if autorollback: + __deprecated__('Peewee no longer uses the "autorollback" option, ' + 'as we always run in autocommit-mode now. This ' + 'changes psycopg2\'s semantics so that the conn ' + 'is not left in a transaction-aborted state.') + + if autocommit is not None: + __deprecated__('Peewee no longer uses the "autocommit" option, as ' + 'the semantics now require it to always be True. 
' + 'Because some database-drivers also use the ' + '"autocommit" parameter, you are receiving a ' + 'warning so you may update your code and remove ' + 'the parameter, as in the future, specifying ' + 'autocommit could impact the behavior of the ' + 'database driver you are using.') + + self.connect_params = {} + self.init(database, **kwargs) + + def init(self, database, **kwargs): + if not self.is_closed(): + self.close() + self.database = database + self.connect_params.update(kwargs) + self.deferred = not bool(database) + + def __enter__(self): + if self.is_closed(): + self.connect() + ctx = self.atomic() + self._state.ctx.append(ctx) + ctx.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + ctx = self._state.ctx.pop() + try: + ctx.__exit__(exc_type, exc_val, exc_tb) + finally: + if not self._state.ctx: + self.close() + + def connection_context(self): + return ConnectionContext(self) + + def _connect(self): + raise NotImplementedError + + def connect(self, reuse_if_open=False): + with self._lock: + if self.deferred: + raise InterfaceError('Error, database must be initialized ' + 'before opening a connection.') + if not self._state.closed: + if reuse_if_open: + return False + raise OperationalError('Connection already opened.') + + self._state.reset() + with __exception_wrapper__: + self._state.set_connection(self._connect()) + if self.server_version is None: + self._set_server_version(self._state.conn) + self._initialize_connection(self._state.conn) + return True + + def _initialize_connection(self, conn): + pass + + def _set_server_version(self, conn): + self.server_version = 0 + + def close(self): + with self._lock: + if self.deferred: + raise InterfaceError('Error, database must be initialized ' + 'before opening a connection.') + if self.in_transaction(): + raise OperationalError('Attempting to close database while ' + 'transaction is open.') + is_open = not self._state.closed + try: + if is_open: + with __exception_wrapper__: + 
self._close(self._state.conn) + finally: + self._state.reset() + return is_open + + def _close(self, conn): + conn.close() + + def is_closed(self): + return self._state.closed + + def is_connection_usable(self): + return not self._state.closed + + def connection(self): + if self.is_closed(): + self.connect() + return self._state.conn + + def cursor(self, commit=None, named_cursor=None): + if commit is not None: + __deprecated__('"commit" has been deprecated and is a no-op.') + if self.is_closed(): + if self.autoconnect: + self.connect() + else: + raise InterfaceError('Error, database connection not opened.') + return self._state.conn.cursor() + + def execute_sql(self, sql, params=None, commit=None): + if commit is not None: + __deprecated__('"commit" has been deprecated and is a no-op.') + logger.debug((sql, params)) + with __exception_wrapper__: + cursor = self.cursor() + cursor.execute(sql, params or ()) + return cursor + + def execute(self, query, commit=None, **context_options): + if commit is not None: + __deprecated__('"commit" has been deprecated and is a no-op.') + ctx = self.get_sql_context(**context_options) + sql, params = ctx.sql(query).query() + return self.execute_sql(sql, params) + + def get_context_options(self): + return { + 'field_types': self._field_types, + 'operations': self._operations, + 'param': self.param, + 'quote': self.quote, + 'compound_select_parentheses': self.compound_select_parentheses, + 'conflict_statement': self.conflict_statement, + 'conflict_update': self.conflict_update, + 'for_update': self.for_update, + 'index_schema_prefix': self.index_schema_prefix, + 'index_using_precedes_table': self.index_using_precedes_table, + 'limit_max': self.limit_max, + 'nulls_ordering': self.nulls_ordering, + } + + def get_sql_context(self, **context_options): + context = self.get_context_options() + if context_options: + context.update(context_options) + return self.context_class(**context) + + def conflict_statement(self, on_conflict, query): + 
raise NotImplementedError + + def conflict_update(self, on_conflict, query): + raise NotImplementedError + + def _build_on_conflict_update(self, on_conflict, query): + if on_conflict._conflict_target: + stmt = SQL('ON CONFLICT') + target = EnclosedNodeList([ + Entity(col) if isinstance(col, basestring) else col + for col in on_conflict._conflict_target]) + if on_conflict._conflict_where is not None: + target = NodeList([target, SQL('WHERE'), + on_conflict._conflict_where]) + else: + stmt = SQL('ON CONFLICT ON CONSTRAINT') + target = on_conflict._conflict_constraint + if isinstance(target, basestring): + target = Entity(target) + + updates = [] + if on_conflict._preserve: + for column in on_conflict._preserve: + excluded = NodeList((SQL('EXCLUDED'), ensure_entity(column)), + glue='.') + expression = NodeList((ensure_entity(column), SQL('='), + excluded)) + updates.append(expression) + + if on_conflict._update: + for k, v in on_conflict._update.items(): + if not isinstance(v, Node): + # Attempt to resolve string field-names to their respective + # field object, to apply data-type conversions. 
+ if isinstance(k, basestring): + k = getattr(query.table, k) + if isinstance(k, Field): + v = k.to_value(v) + else: + v = Value(v, unpack=False) + else: + v = QualifiedNames(v) + updates.append(NodeList((ensure_entity(k), SQL('='), v))) + + parts = [stmt, target, SQL('DO UPDATE SET'), CommaNodeList(updates)] + if on_conflict._where: + parts.extend((SQL('WHERE'), QualifiedNames(on_conflict._where))) + + return NodeList(parts) + + def last_insert_id(self, cursor, query_type=None): + return cursor.lastrowid + + def rows_affected(self, cursor): + return cursor.rowcount + + def default_values_insert(self, ctx): + return ctx.literal('DEFAULT VALUES') + + def session_start(self): + with self._lock: + return self.transaction().__enter__() + + def session_commit(self): + with self._lock: + try: + txn = self.pop_transaction() + except IndexError: + return False + txn.commit(begin=self.in_transaction()) + return True + + def session_rollback(self): + with self._lock: + try: + txn = self.pop_transaction() + except IndexError: + return False + txn.rollback(begin=self.in_transaction()) + return True + + def in_transaction(self): + return bool(self._state.transactions) + + def push_transaction(self, transaction): + self._state.transactions.append(transaction) + + def pop_transaction(self): + return self._state.transactions.pop() + + def transaction_depth(self): + return len(self._state.transactions) + + def top_transaction(self): + if self._state.transactions: + return self._state.transactions[-1] + + def atomic(self, *args, **kwargs): + return _atomic(self, *args, **kwargs) + + def manual_commit(self): + return _manual(self) + + def transaction(self, *args, **kwargs): + return _transaction(self, *args, **kwargs) + + def savepoint(self): + return _savepoint(self) + + def begin(self): + if self.is_closed(): + self.connect() + with __exception_wrapper__: + self.cursor().execute('BEGIN') + + def rollback(self): + with __exception_wrapper__: + self.cursor().execute('ROLLBACK') + + 
def commit(self): + with __exception_wrapper__: + self.cursor().execute('COMMIT') + + def batch_commit(self, it, n): + for group in chunked(it, n): + with self.atomic(): + for obj in group: + yield obj + + def table_exists(self, table_name, schema=None): + if is_model(table_name): + model = table_name + table_name = model._meta.table_name + schema = model._meta.schema + return table_name in self.get_tables(schema=schema) + + def get_tables(self, schema=None): + raise NotImplementedError + + def get_indexes(self, table, schema=None): + raise NotImplementedError + + def get_columns(self, table, schema=None): + raise NotImplementedError + + def get_primary_keys(self, table, schema=None): + raise NotImplementedError + + def get_foreign_keys(self, table, schema=None): + raise NotImplementedError + + def sequence_exists(self, seq): + raise NotImplementedError + + def create_tables(self, models, **options): + for model in sort_models(models): + model.create_table(**options) + + def drop_tables(self, models, **kwargs): + for model in reversed(sort_models(models)): + model.drop_table(**kwargs) + + def extract_date(self, date_part, date_field): + raise NotImplementedError + + def truncate_date(self, date_part, date_field): + raise NotImplementedError + + def to_timestamp(self, date_field): + raise NotImplementedError + + def from_timestamp(self, date_field): + raise NotImplementedError + + def random(self): + return fn.random() + + def bind(self, models, bind_refs=True, bind_backrefs=True): + for model in models: + model.bind(self, bind_refs=bind_refs, bind_backrefs=bind_backrefs) + + def bind_ctx(self, models, bind_refs=True, bind_backrefs=True): + return _BoundModelsContext(models, self, bind_refs, bind_backrefs) + + def get_noop_select(self, ctx): + return ctx.sql(Select().columns(SQL('0')).where(SQL('0'))) + + @property + def Model(self): + if not hasattr(self, '_Model'): + class Meta: database = self + self._Model = type('BaseModel', (Model,), {'Meta': Meta}) + return 
self._Model + + +def __pragma__(name): + def __get__(self): + return self.pragma(name) + def __set__(self, value): + return self.pragma(name, value) + return property(__get__, __set__) + + +class SqliteDatabase(Database): + field_types = { + 'BIGAUTO': FIELD.AUTO, + 'BIGINT': FIELD.INT, + 'BOOL': FIELD.INT, + 'DOUBLE': FIELD.FLOAT, + 'SMALLINT': FIELD.INT, + 'UUID': FIELD.TEXT} + operations = { + 'LIKE': 'GLOB', + 'ILIKE': 'LIKE'} + index_schema_prefix = True + limit_max = -1 + server_version = __sqlite_version__ + truncate_table = False + + def __init__(self, database, *args, **kwargs): + self._pragmas = kwargs.pop('pragmas', ()) + super(SqliteDatabase, self).__init__(database, *args, **kwargs) + self._aggregates = {} + self._collations = {} + self._functions = {} + self._window_functions = {} + self._table_functions = [] + self._extensions = set() + self._attached = {} + self.register_function(_sqlite_date_part, 'date_part', 2) + self.register_function(_sqlite_date_trunc, 'date_trunc', 2) + self.nulls_ordering = self.server_version >= (3, 30, 0) + + def init(self, database, pragmas=None, timeout=5, returning_clause=None, + **kwargs): + if pragmas is not None: + self._pragmas = pragmas + if isinstance(self._pragmas, dict): + self._pragmas = list(self._pragmas.items()) + if returning_clause is not None: + if __sqlite_version__ < (3, 35, 0): + warnings.warn('RETURNING clause requires Sqlite 3.35 or newer') + self.returning_clause = returning_clause + self._timeout = timeout + super(SqliteDatabase, self).init(database, **kwargs) + + def _set_server_version(self, conn): + pass + + def _connect(self): + if sqlite3 is None: + raise ImproperlyConfigured('SQLite driver not installed!') + conn = sqlite3.connect(self.database, timeout=self._timeout, + isolation_level=None, **self.connect_params) + try: + self._add_conn_hooks(conn) + except: + conn.close() + raise + return conn + + def _add_conn_hooks(self, conn): + if self._attached: + self._attach_databases(conn) + if 
self._pragmas: + self._set_pragmas(conn) + self._load_aggregates(conn) + self._load_collations(conn) + self._load_functions(conn) + if self.server_version >= (3, 25, 0): + self._load_window_functions(conn) + if self._table_functions: + for table_function in self._table_functions: + table_function.register(conn) + if self._extensions: + self._load_extensions(conn) + + def _set_pragmas(self, conn): + cursor = conn.cursor() + for pragma, value in self._pragmas: + cursor.execute('PRAGMA %s = %s;' % (pragma, value)) + cursor.close() + + def _attach_databases(self, conn): + cursor = conn.cursor() + for name, db in self._attached.items(): + cursor.execute('ATTACH DATABASE "%s" AS "%s"' % (db, name)) + cursor.close() + + def pragma(self, key, value=SENTINEL, permanent=False, schema=None): + if schema is not None: + key = '"%s".%s' % (schema, key) + sql = 'PRAGMA %s' % key + if value is not SENTINEL: + sql += ' = %s' % (value or 0) + if permanent: + pragmas = dict(self._pragmas or ()) + pragmas[key] = value + self._pragmas = list(pragmas.items()) + elif permanent: + raise ValueError('Cannot specify a permanent pragma without value') + row = self.execute_sql(sql).fetchone() + if row: + return row[0] + + cache_size = __pragma__('cache_size') + foreign_keys = __pragma__('foreign_keys') + journal_mode = __pragma__('journal_mode') + journal_size_limit = __pragma__('journal_size_limit') + mmap_size = __pragma__('mmap_size') + page_size = __pragma__('page_size') + read_uncommitted = __pragma__('read_uncommitted') + synchronous = __pragma__('synchronous') + wal_autocheckpoint = __pragma__('wal_autocheckpoint') + application_id = __pragma__('application_id') + user_version = __pragma__('user_version') + data_version = __pragma__('data_version') + + @property + def timeout(self): + return self._timeout + + @timeout.setter + def timeout(self, seconds): + if self._timeout == seconds: + return + + self._timeout = seconds + if not self.is_closed(): + # PySQLite multiplies user timeout by 
1000, but the unit of the + # timeout PRAGMA is actually milliseconds. + self.execute_sql('PRAGMA busy_timeout=%d;' % (seconds * 1000)) + + def _load_aggregates(self, conn): + for name, (klass, num_params) in self._aggregates.items(): + conn.create_aggregate(name, num_params, klass) + + def _load_collations(self, conn): + for name, fn in self._collations.items(): + conn.create_collation(name, fn) + + def _load_functions(self, conn): + for name, (fn, n_params, deterministic) in self._functions.items(): + kwargs = {'deterministic': deterministic} if deterministic else {} + conn.create_function(name, n_params, fn, **kwargs) + + def _load_window_functions(self, conn): + for name, (klass, num_params) in self._window_functions.items(): + conn.create_window_function(name, num_params, klass) + + def register_aggregate(self, klass, name=None, num_params=-1): + self._aggregates[name or klass.__name__.lower()] = (klass, num_params) + if not self.is_closed(): + self._load_aggregates(self.connection()) + + def aggregate(self, name=None, num_params=-1): + def decorator(klass): + self.register_aggregate(klass, name, num_params) + return klass + return decorator + + def register_collation(self, fn, name=None): + name = name or fn.__name__ + def _collation(*args): + expressions = args + (SQL('collate %s' % name),) + return NodeList(expressions) + fn.collation = _collation + self._collations[name] = fn + if not self.is_closed(): + self._load_collations(self.connection()) + + def collation(self, name=None): + def decorator(fn): + self.register_collation(fn, name) + return fn + return decorator + + def register_function(self, fn, name=None, num_params=-1, + deterministic=None): + self._functions[name or fn.__name__] = (fn, num_params, deterministic) + if not self.is_closed(): + self._load_functions(self.connection()) + + def func(self, name=None, num_params=-1, deterministic=None): + def decorator(fn): + self.register_function(fn, name, num_params, deterministic) + return fn + return 
decorator + + def register_window_function(self, klass, name=None, num_params=-1): + name = name or klass.__name__.lower() + self._window_functions[name] = (klass, num_params) + if not self.is_closed(): + self._load_window_functions(self.connection()) + + def window_function(self, name=None, num_params=-1): + def decorator(klass): + self.register_window_function(klass, name, num_params) + return klass + return decorator + + def register_table_function(self, klass, name=None): + if name is not None: + klass.name = name + self._table_functions.append(klass) + if not self.is_closed(): + klass.register(self.connection()) + + def table_function(self, name=None): + def decorator(klass): + self.register_table_function(klass, name) + return klass + return decorator + + def unregister_aggregate(self, name): + del(self._aggregates[name]) + + def unregister_collation(self, name): + del(self._collations[name]) + + def unregister_function(self, name): + del(self._functions[name]) + + def unregister_window_function(self, name): + del(self._window_functions[name]) + + def unregister_table_function(self, name): + for idx, klass in enumerate(self._table_functions): + if klass.name == name: + break + else: + return False + self._table_functions.pop(idx) + return True + + def _load_extensions(self, conn): + conn.enable_load_extension(True) + for extension in self._extensions: + conn.load_extension(extension) + + def load_extension(self, extension): + self._extensions.add(extension) + if not self.is_closed(): + conn = self.connection() + conn.enable_load_extension(True) + conn.load_extension(extension) + + def unload_extension(self, extension): + self._extensions.remove(extension) + + def attach(self, filename, name): + if name in self._attached: + if self._attached[name] == filename: + return False + raise OperationalError('schema "%s" already attached.' 
% name) + + self._attached[name] = filename + if not self.is_closed(): + self.execute_sql('ATTACH DATABASE "%s" AS "%s"' % (filename, name)) + return True + + def detach(self, name): + if name not in self._attached: + return False + + del self._attached[name] + if not self.is_closed(): + self.execute_sql('DETACH DATABASE "%s"' % name) + return True + + def last_insert_id(self, cursor, query_type=None): + if not self.returning_clause: + return cursor.lastrowid + elif query_type == Insert.SIMPLE: + try: + return cursor[0][0] + except (IndexError, KeyError, TypeError): + pass + return cursor + + def rows_affected(self, cursor): + try: + return cursor.rowcount + except AttributeError: + return cursor.cursor.rowcount # This was a RETURNING query. + + def begin(self, lock_type=None): + statement = 'BEGIN %s' % lock_type if lock_type else 'BEGIN' + self.execute_sql(statement) + + def commit(self): + with __exception_wrapper__: + return self._state.conn.commit() + + def rollback(self): + with __exception_wrapper__: + return self._state.conn.rollback() + + def get_tables(self, schema=None): + schema = schema or 'main' + cursor = self.execute_sql('SELECT name FROM "%s".sqlite_master WHERE ' + 'type=? ORDER BY name' % schema, ('table',)) + return [row for row, in cursor.fetchall()] + + def get_views(self, schema=None): + sql = ('SELECT name, sql FROM "%s".sqlite_master WHERE type=? ' + 'ORDER BY name') % (schema or 'main') + return [ViewMetadata(*row) for row in self.execute_sql(sql, ('view',))] + + def get_indexes(self, table, schema=None): + schema = schema or 'main' + query = ('SELECT name, sql FROM "%s".sqlite_master ' + 'WHERE tbl_name = ? AND type = ? ORDER BY name') % schema + cursor = self.execute_sql(query, (table, 'index')) + index_to_sql = dict(cursor.fetchall()) + + # Determine which indexes have a unique constraint. 
+ unique_indexes = set() + cursor = self.execute_sql('PRAGMA "%s".index_list("%s")' % + (schema, table)) + for row in cursor.fetchall(): + name = row[1] + is_unique = int(row[2]) == 1 + if is_unique: + unique_indexes.add(name) + + # Retrieve the indexed columns. + index_columns = {} + for index_name in sorted(index_to_sql): + cursor = self.execute_sql('PRAGMA "%s".index_info("%s")' % + (schema, index_name)) + index_columns[index_name] = [row[2] for row in cursor.fetchall()] + + return [ + IndexMetadata( + name, + index_to_sql[name], + index_columns[name], + name in unique_indexes, + table) + for name in sorted(index_to_sql)] + + def get_columns(self, table, schema=None): + cursor = self.execute_sql('PRAGMA "%s".table_info("%s")' % + (schema or 'main', table)) + return [ColumnMetadata(r[1], r[2], not r[3], bool(r[5]), table, r[4]) + for r in cursor.fetchall()] + + def get_primary_keys(self, table, schema=None): + cursor = self.execute_sql('PRAGMA "%s".table_info("%s")' % + (schema or 'main', table)) + return [row[1] for row in filter(lambda r: r[-1], cursor.fetchall())] + + def get_foreign_keys(self, table, schema=None): + cursor = self.execute_sql('PRAGMA "%s".foreign_key_list("%s")' % + (schema or 'main', table)) + return [ForeignKeyMetadata(row[3], row[2], row[4], table) + for row in cursor.fetchall()] + + def get_binary_type(self): + return sqlite3.Binary + + def conflict_statement(self, on_conflict, query): + action = on_conflict._action.lower() if on_conflict._action else '' + if action and action not in ('nothing', 'update'): + return SQL('INSERT OR %s' % on_conflict._action.upper()) + + def conflict_update(self, oc, query): + # Sqlite prior to 3.24.0 does not support Postgres-style upsert. 
+ if self.server_version < (3, 24, 0) and \ + any((oc._preserve, oc._update, oc._where, oc._conflict_target, + oc._conflict_constraint)): + raise ValueError('SQLite does not support specifying which values ' + 'to preserve or update.') + + action = oc._action.lower() if oc._action else '' + if action and action not in ('nothing', 'update', ''): + return + + if action == 'nothing': + return SQL('ON CONFLICT DO NOTHING') + elif not oc._update and not oc._preserve: + raise ValueError('If you are not performing any updates (or ' + 'preserving any INSERTed values), then the ' + 'conflict resolution action should be set to ' + '"NOTHING".') + elif oc._conflict_constraint: + raise ValueError('SQLite does not support specifying named ' + 'constraints for conflict resolution.') + elif not oc._conflict_target: + raise ValueError('SQLite requires that a conflict target be ' + 'specified when doing an upsert.') + + return self._build_on_conflict_update(oc, query) + + def extract_date(self, date_part, date_field): + return fn.date_part(date_part, date_field, python_value=int) + + def truncate_date(self, date_part, date_field): + return fn.date_trunc(date_part, date_field, + python_value=simple_date_time) + + def to_timestamp(self, date_field): + return fn.strftime('%s', date_field).cast('integer') + + def from_timestamp(self, date_field): + return fn.datetime(date_field, 'unixepoch') + + +class PostgresqlDatabase(Database): + field_types = { + 'AUTO': 'SERIAL', + 'BIGAUTO': 'BIGSERIAL', + 'BLOB': 'BYTEA', + 'BOOL': 'BOOLEAN', + 'DATETIME': 'TIMESTAMP', + 'DECIMAL': 'NUMERIC', + 'DOUBLE': 'DOUBLE PRECISION', + 'UUID': 'UUID', + 'UUIDB': 'BYTEA'} + operations = {'REGEXP': '~', 'IREGEXP': '~*'} + param = '%s' + + compound_select_parentheses = CSQ_PARENTHESES_ALWAYS + for_update = True + nulls_ordering = True + returning_clause = True + safe_create_index = False + sequences = True + + def init(self, database, register_unicode=True, encoding=None, + isolation_level=None, **kwargs): 
+ self._register_unicode = register_unicode + self._encoding = encoding + self._isolation_level = isolation_level + super(PostgresqlDatabase, self).init(database, **kwargs) + + def _connect(self): + if psycopg2 is None: + raise ImproperlyConfigured('Postgres driver not installed!') + + # Handle connection-strings nicely, since psycopg2 will accept them, + # and they may be easier when lots of parameters are specified. + params = self.connect_params.copy() + if self.database.startswith('postgresql://'): + params.setdefault('dsn', self.database) + else: + params.setdefault('dbname', self.database) + + conn = psycopg2.connect(**params) + if self._register_unicode: + pg_extensions.register_type(pg_extensions.UNICODE, conn) + pg_extensions.register_type(pg_extensions.UNICODEARRAY, conn) + if self._encoding: + conn.set_client_encoding(self._encoding) + if self._isolation_level: + conn.set_isolation_level(self._isolation_level) + conn.autocommit = True + return conn + + def _set_server_version(self, conn): + self.server_version = conn.server_version + if self.server_version >= 90600: + self.safe_create_index = True + + def is_connection_usable(self): + if self._state.closed: + return False + + # Returns True if we are idle, running a command, or in an active + # connection. If the connection is in an error state or the connection + # is otherwise unusable, return False. 
+ txn_status = self._state.conn.get_transaction_status() + return txn_status < pg_extensions.TRANSACTION_STATUS_INERROR + + def last_insert_id(self, cursor, query_type=None): + try: + return cursor if query_type != Insert.SIMPLE else cursor[0][0] + except (IndexError, KeyError, TypeError): + pass + + def rows_affected(self, cursor): + try: + return cursor.rowcount + except AttributeError: + return cursor.cursor.rowcount + + def begin(self, isolation_level=None): + if self.is_closed(): + self.connect() + if isolation_level: + stmt = 'BEGIN TRANSACTION ISOLATION LEVEL %s' % isolation_level + else: + stmt = 'BEGIN' + with __exception_wrapper__: + self.cursor().execute(stmt) + + def get_tables(self, schema=None): + query = ('SELECT tablename FROM pg_catalog.pg_tables ' + 'WHERE schemaname = %s ORDER BY tablename') + cursor = self.execute_sql(query, (schema or 'public',)) + return [table for table, in cursor.fetchall()] + + def get_views(self, schema=None): + query = ('SELECT viewname, definition FROM pg_catalog.pg_views ' + 'WHERE schemaname = %s ORDER BY viewname') + cursor = self.execute_sql(query, (schema or 'public',)) + return [ViewMetadata(view_name, sql.strip(' \t;')) + for (view_name, sql) in cursor.fetchall()] + + def get_indexes(self, table, schema=None): + query = """ + SELECT + i.relname, idxs.indexdef, idx.indisunique, + array_to_string(ARRAY( + SELECT pg_get_indexdef(idx.indexrelid, k + 1, TRUE) + FROM generate_subscripts(idx.indkey, 1) AS k + ORDER BY k), ',') + FROM pg_catalog.pg_class AS t + INNER JOIN pg_catalog.pg_index AS idx ON t.oid = idx.indrelid + INNER JOIN pg_catalog.pg_class AS i ON idx.indexrelid = i.oid + INNER JOIN pg_catalog.pg_indexes AS idxs ON + (idxs.tablename = t.relname AND idxs.indexname = i.relname) + WHERE t.relname = %s AND t.relkind = %s AND idxs.schemaname = %s + ORDER BY idx.indisunique DESC, i.relname;""" + cursor = self.execute_sql(query, (table, 'r', schema or 'public')) + return [IndexMetadata(name, sql.rstrip(' ;'), 
columns.split(','), + is_unique, table) + for name, sql, is_unique, columns in cursor.fetchall()] + + def get_columns(self, table, schema=None): + query = """ + SELECT column_name, is_nullable, data_type, column_default + FROM information_schema.columns + WHERE table_name = %s AND table_schema = %s + ORDER BY ordinal_position""" + cursor = self.execute_sql(query, (table, schema or 'public')) + pks = set(self.get_primary_keys(table, schema)) + return [ColumnMetadata(name, dt, null == 'YES', name in pks, table, df) + for name, null, dt, df in cursor.fetchall()] + + def get_primary_keys(self, table, schema=None): + query = """ + SELECT kc.column_name + FROM information_schema.table_constraints AS tc + INNER JOIN information_schema.key_column_usage AS kc ON ( + tc.table_name = kc.table_name AND + tc.table_schema = kc.table_schema AND + tc.constraint_name = kc.constraint_name) + WHERE + tc.constraint_type = %s AND + tc.table_name = %s AND + tc.table_schema = %s""" + ctype = 'PRIMARY KEY' + cursor = self.execute_sql(query, (ctype, table, schema or 'public')) + return [pk for pk, in cursor.fetchall()] + + def get_foreign_keys(self, table, schema=None): + sql = """ + SELECT DISTINCT + kcu.column_name, ccu.table_name, ccu.column_name + FROM information_schema.table_constraints AS tc + JOIN information_schema.key_column_usage AS kcu + ON (tc.constraint_name = kcu.constraint_name AND + tc.constraint_schema = kcu.constraint_schema AND + tc.table_name = kcu.table_name AND + tc.table_schema = kcu.table_schema) + JOIN information_schema.constraint_column_usage AS ccu + ON (ccu.constraint_name = tc.constraint_name AND + ccu.constraint_schema = tc.constraint_schema) + WHERE + tc.constraint_type = 'FOREIGN KEY' AND + tc.table_name = %s AND + tc.table_schema = %s""" + cursor = self.execute_sql(sql, (table, schema or 'public')) + return [ForeignKeyMetadata(row[0], row[1], row[2], table) + for row in cursor.fetchall()] + + def sequence_exists(self, sequence): + res = 
self.execute_sql(""" + SELECT COUNT(*) FROM pg_class, pg_namespace + WHERE relkind='S' + AND pg_class.relnamespace = pg_namespace.oid + AND relname=%s""", (sequence,)) + return bool(res.fetchone()[0]) + + def get_binary_type(self): + return psycopg2.Binary + + def conflict_statement(self, on_conflict, query): + return + + def conflict_update(self, oc, query): + action = oc._action.lower() if oc._action else '' + if action in ('ignore', 'nothing'): + parts = [SQL('ON CONFLICT')] + if oc._conflict_target: + parts.append(EnclosedNodeList([ + Entity(col) if isinstance(col, basestring) else col + for col in oc._conflict_target])) + parts.append(SQL('DO NOTHING')) + return NodeList(parts) + elif action and action != 'update': + raise ValueError('The only supported actions for conflict ' + 'resolution with Postgresql are "ignore" or ' + '"update".') + elif not oc._update and not oc._preserve: + raise ValueError('If you are not performing any updates (or ' + 'preserving any INSERTed values), then the ' + 'conflict resolution action should be set to ' + '"IGNORE".') + elif not (oc._conflict_target or oc._conflict_constraint): + raise ValueError('Postgres requires that a conflict target be ' + 'specified when doing an upsert.') + + return self._build_on_conflict_update(oc, query) + + def extract_date(self, date_part, date_field): + return fn.EXTRACT(NodeList((date_part, SQL('FROM'), date_field))) + + def truncate_date(self, date_part, date_field): + return fn.DATE_TRUNC(date_part, date_field) + + def to_timestamp(self, date_field): + return self.extract_date('EPOCH', date_field) + + def from_timestamp(self, date_field): + # Ironically, here, Postgres means "to the Postgresql timestamp type". 
+ return fn.to_timestamp(date_field) + + def get_noop_select(self, ctx): + return ctx.sql(Select().columns(SQL('0')).where(SQL('false'))) + + def set_time_zone(self, timezone): + self.execute_sql('set time zone "%s";' % timezone) + + +class MySQLDatabase(Database): + field_types = { + 'AUTO': 'INTEGER AUTO_INCREMENT', + 'BIGAUTO': 'BIGINT AUTO_INCREMENT', + 'BOOL': 'BOOL', + 'DECIMAL': 'NUMERIC', + 'DOUBLE': 'DOUBLE PRECISION', + 'FLOAT': 'FLOAT', + 'UUID': 'VARCHAR(40)', + 'UUIDB': 'VARBINARY(16)'} + operations = { + 'LIKE': 'LIKE BINARY', + 'ILIKE': 'LIKE', + 'REGEXP': 'REGEXP BINARY', + 'IREGEXP': 'REGEXP', + 'XOR': 'XOR'} + param = '%s' + quote = '``' + + compound_select_parentheses = CSQ_PARENTHESES_UNNESTED + for_update = True + index_using_precedes_table = True + limit_max = 2 ** 64 - 1 + safe_create_index = False + safe_drop_index = False + sql_mode = 'PIPES_AS_CONCAT' + + def init(self, database, **kwargs): + params = { + 'charset': 'utf8', + 'sql_mode': self.sql_mode, + 'use_unicode': True} + params.update(kwargs) + if 'password' in params and mysql_passwd: + params['passwd'] = params.pop('password') + super(MySQLDatabase, self).init(database, **params) + + def _connect(self): + if mysql is None: + raise ImproperlyConfigured('MySQL driver not installed!') + conn = mysql.connect(db=self.database, autocommit=True, + **self.connect_params) + return conn + + def _set_server_version(self, conn): + try: + version_raw = conn.server_version + except AttributeError: + version_raw = conn.get_server_info() + self.server_version = self._extract_server_version(version_raw) + + def _extract_server_version(self, version): + version = version.lower() + if 'maria' in version: + match_obj = re.search(r'(1\d\.\d+\.\d+)', version) + else: + match_obj = re.search(r'(\d\.\d+\.\d+)', version) + if match_obj is not None: + return tuple(int(num) for num in match_obj.groups()[0].split('.')) + + warnings.warn('Unable to determine MySQL version: "%s"' % version) + return (0, 0, 0) # 
Unable to determine version! + + def is_connection_usable(self): + if self._state.closed: + return False + + conn = self._state.conn + if hasattr(conn, 'ping'): + try: + conn.ping(False) + except Exception: + return False + return True + + def default_values_insert(self, ctx): + return ctx.literal('() VALUES ()') + + def begin(self, isolation_level=None): + if self.is_closed(): + self.connect() + with __exception_wrapper__: + curs = self.cursor() + if isolation_level: + curs.execute('SET TRANSACTION ISOLATION LEVEL %s' % + isolation_level) + curs.execute('BEGIN') + + def get_tables(self, schema=None): + query = ('SELECT table_name FROM information_schema.tables ' + 'WHERE table_schema = DATABASE() AND table_type != %s ' + 'ORDER BY table_name') + return [table for table, in self.execute_sql(query, ('VIEW',))] + + def get_views(self, schema=None): + query = ('SELECT table_name, view_definition ' + 'FROM information_schema.views ' + 'WHERE table_schema = DATABASE() ORDER BY table_name') + cursor = self.execute_sql(query) + return [ViewMetadata(*row) for row in cursor.fetchall()] + + def get_indexes(self, table, schema=None): + cursor = self.execute_sql('SHOW INDEX FROM `%s`' % table) + unique = set() + indexes = {} + for row in cursor.fetchall(): + if not row[1]: + unique.add(row[2]) + indexes.setdefault(row[2], []) + indexes[row[2]].append(row[4]) + return [IndexMetadata(name, None, indexes[name], name in unique, table) + for name in indexes] + + def get_columns(self, table, schema=None): + sql = """ + SELECT column_name, is_nullable, data_type, column_default + FROM information_schema.columns + WHERE table_name = %s AND table_schema = DATABASE() + ORDER BY ordinal_position""" + cursor = self.execute_sql(sql, (table,)) + pks = set(self.get_primary_keys(table)) + return [ColumnMetadata(name, dt, null == 'YES', name in pks, table, df) + for name, null, dt, df in cursor.fetchall()] + + def get_primary_keys(self, table, schema=None): + cursor = self.execute_sql('SHOW 
INDEX FROM `%s`' % table) + return [row[4] for row in + filter(lambda row: row[2] == 'PRIMARY', cursor.fetchall())] + + def get_foreign_keys(self, table, schema=None): + query = """ + SELECT column_name, referenced_table_name, referenced_column_name + FROM information_schema.key_column_usage + WHERE table_name = %s + AND table_schema = DATABASE() + AND referenced_table_name IS NOT NULL + AND referenced_column_name IS NOT NULL""" + cursor = self.execute_sql(query, (table,)) + return [ + ForeignKeyMetadata(column, dest_table, dest_column, table) + for column, dest_table, dest_column in cursor.fetchall()] + + def get_binary_type(self): + return mysql.Binary + + def conflict_statement(self, on_conflict, query): + if not on_conflict._action: return + + action = on_conflict._action.lower() + if action == 'replace': + return SQL('REPLACE') + elif action == 'ignore': + return SQL('INSERT IGNORE') + elif action != 'update': + raise ValueError('Un-supported action for conflict resolution. ' + 'MySQL supports REPLACE, IGNORE and UPDATE.') + + def conflict_update(self, on_conflict, query): + if on_conflict._where or on_conflict._conflict_target or \ + on_conflict._conflict_constraint: + raise ValueError('MySQL does not support the specification of ' + 'where clauses or conflict targets for conflict ' + 'resolution.') + + updates = [] + if on_conflict._preserve: + # Here we need to determine which function to use, which varies + # depending on the MySQL server version. MySQL and MariaDB prior to + # 10.3.3 use "VALUES", while MariaDB 10.3.3+ use "VALUE". 
+ version = self.server_version or (0,) + if version[0] == 10 and version >= (10, 3, 3): + VALUE_FN = fn.VALUE + else: + VALUE_FN = fn.VALUES + + for column in on_conflict._preserve: + entity = ensure_entity(column) + expression = NodeList(( + ensure_entity(column), + SQL('='), + VALUE_FN(entity))) + updates.append(expression) + + if on_conflict._update: + for k, v in on_conflict._update.items(): + if not isinstance(v, Node): + # Attempt to resolve string field-names to their respective + # field object, to apply data-type conversions. + if isinstance(k, basestring): + k = getattr(query.table, k) + if isinstance(k, Field): + v = k.to_value(v) + else: + v = Value(v, unpack=False) + updates.append(NodeList((ensure_entity(k), SQL('='), v))) + + if updates: + return NodeList((SQL('ON DUPLICATE KEY UPDATE'), + CommaNodeList(updates))) + + def extract_date(self, date_part, date_field): + return fn.EXTRACT(NodeList((SQL(date_part), SQL('FROM'), date_field))) + + def truncate_date(self, date_part, date_field): + return fn.DATE_FORMAT(date_field, __mysql_date_trunc__[date_part], + python_value=simple_date_time) + + def to_timestamp(self, date_field): + return fn.UNIX_TIMESTAMP(date_field) + + def from_timestamp(self, date_field): + return fn.FROM_UNIXTIME(date_field) + + def random(self): + return fn.rand() + + def get_noop_select(self, ctx): + return ctx.literal('DO 0') + + +# TRANSACTION CONTROL. 
class _manual(object):
    """Context manager / decorator placing the database in manual-commit mode.

    While a ``_manual`` block is on the transaction stack, ``_atomic``
    refuses to start (see ``_atomic.__enter__``); the caller issues
    commits and rollbacks explicitly.
    """

    def __init__(self, db):
        # db: owning database object (provides the transaction stack).
        self.db = db

    def __call__(self, fn):
        # Decorator form: run the wrapped function inside a fresh
        # manual-commit block.
        @wraps(fn)
        def inner(*args, **kwargs):
            with _manual(self.db):
                return fn(*args, **kwargs)
        return inner

    def __enter__(self):
        # Only legal at the top level or nested inside another manual
        # block -- never inside an active transaction.
        top = self.db.top_transaction()
        if top is not None and not isinstance(top, _manual):
            raise ValueError('Cannot enter manual commit block while a '
                             'transaction is active.')
        self.db.push_transaction(self)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Sanity check: the object we pop must be ourselves, otherwise
        # the stack was corrupted by mismatched enter/exit calls.
        if self.db.pop_transaction() is not self:
            raise ValueError('Transaction stack corrupted while exiting '
                             'manual commit block.')


class _atomic(object):
    """Context manager / decorator for atomic execution.

    Opens a real transaction at the outermost level and a savepoint when
    already inside a transaction, so ``atomic`` blocks nest safely.
    """

    def __init__(self, db, *args, **kwargs):
        self.db = db
        # Extra args/kwargs are forwarded to db.transaction() when this
        # block turns out to be the outermost one.
        self._transaction_args = (args, kwargs)

    def __call__(self, fn):
        # Decorator form: each invocation gets its own atomic block.
        @wraps(fn)
        def inner(*args, **kwargs):
            a, k = self._transaction_args
            with _atomic(self.db, *a, **k):
                return fn(*args, **kwargs)
        return inner

    def __enter__(self):
        if self.db.transaction_depth() == 0:
            # Outermost block: delegate to a full transaction.
            args, kwargs = self._transaction_args
            self._helper = self.db.transaction(*args, **kwargs)
        elif isinstance(self.db.top_transaction(), _manual):
            raise ValueError('Cannot enter atomic commit block while in '
                             'manual commit mode.')
        else:
            # Nested block: use a savepoint so inner failures can roll
            # back without aborting the outer transaction.
            self._helper = self.db.savepoint()
        return self._helper.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Commit/rollback behavior is entirely delegated to the chosen
        # helper (_transaction or _savepoint).
        return self._helper.__exit__(exc_type, exc_val, exc_tb)


class _transaction(object):
    """Context manager / decorator wrapping an explicit BEGIN/COMMIT pair."""

    def __init__(self, db, *args, **kwargs):
        self.db = db
        # Extra args/kwargs are forwarded to db.begin() (e.g. lock or
        # isolation-level options on the concrete database class).
        self._begin_args = (args, kwargs)

    def __call__(self, fn):
        # Decorator form: run the wrapped function in its own transaction.
        @wraps(fn)
        def inner(*args, **kwargs):
            a, k = self._begin_args
            with _transaction(self.db, *a, **k):
                return fn(*args, **kwargs)
        return inner

    def _begin(self):
        args, kwargs = self._begin_args
        self.db.begin(*args, **kwargs)

    def commit(self, begin=True):
        """Commit, then by default immediately BEGIN a new transaction."""
        self.db.commit()
        if begin:
            self._begin()

    def rollback(self, begin=True):
        """Roll back, then by default immediately BEGIN a new transaction."""
        self.db.rollback()
        if begin:
            self._begin()

    def __enter__(self):
        # Only issue BEGIN at the outermost level; nested use relies on
        # the already-open transaction.
        if self.db.transaction_depth() == 0:
            self._begin()
        self.db.push_transaction(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        depth = self.db.transaction_depth()
        try:
            if exc_type and depth == 1:
                # Outermost frame with a pending exception: roll back
                # without re-opening a transaction.
                self.rollback(False)
            elif depth == 1:
                try:
                    self.commit(False)
                except:
                    # Commit failed: roll back, then re-raise the
                    # original commit error.
                    self.rollback(False)
                    raise
            # depth > 1: nested frame, outermost owner will finalize.
        finally:
            self.db.pop_transaction()


class _savepoint(object):
    """Context manager / decorator wrapping a SQL SAVEPOINT.

    Used by ``_atomic`` for nested atomic blocks; commits RELEASE the
    savepoint, failures ROLLBACK TO it.
    """

    def __init__(self, db, sid=None):
        self.db = db
        # Generate a unique savepoint name unless one was supplied.
        self.sid = sid or 's' + uuid.uuid4().hex
        # str.join places the two quote characters around the name
        # (db.quote is a 2-character open/close pair).
        self.quoted_sid = self.sid.join(self.db.quote)

    def __call__(self, fn):
        # Decorator form: wrap each call in its own savepoint.
        @wraps(fn)
        def inner(*args, **kwargs):
            with _savepoint(self.db):
                return fn(*args, **kwargs)
        return inner

    def _begin(self):
        self.db.execute_sql('SAVEPOINT %s;' % self.quoted_sid)

    def commit(self, begin=True):
        """RELEASE the savepoint; optionally re-create it right away."""
        self.db.execute_sql('RELEASE SAVEPOINT %s;' % self.quoted_sid)
        if begin: self._begin()

    def rollback(self):
        self.db.execute_sql('ROLLBACK TO SAVEPOINT %s;' % self.quoted_sid)

    def __enter__(self):
        self._begin()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type:
            self.rollback()
        else:
            try:
                self.commit(begin=False)
            except:
                # RELEASE failed: roll back to the savepoint, then
                # re-raise the original error.
                self.rollback()
                raise


# CURSOR REPRESENTATIONS.
+ + +class CursorWrapper(object): + def __init__(self, cursor): + self.cursor = cursor + self.count = 0 + self.index = 0 + self.initialized = False + self.populated = False + self.row_cache = [] + + def __iter__(self): + if self.populated: + return iter(self.row_cache) + return ResultIterator(self) + + def __getitem__(self, item): + if isinstance(item, slice): + stop = item.stop + if stop is None or stop < 0: + self.fill_cache() + else: + self.fill_cache(stop) + return self.row_cache[item] + elif isinstance(item, int): + self.fill_cache(item if item > 0 else 0) + return self.row_cache[item] + else: + raise ValueError('CursorWrapper only supports integer and slice ' + 'indexes.') + + def __len__(self): + self.fill_cache() + return self.count + + def initialize(self): + pass + + def iterate(self, cache=True): + row = self.cursor.fetchone() + if row is None: + self.populated = True + self.cursor.close() + raise StopIteration + elif not self.initialized: + self.initialize() # Lazy initialization. 
+ self.initialized = True + self.count += 1 + result = self.process_row(row) + if cache: + self.row_cache.append(result) + return result + + def process_row(self, row): + return row + + def iterator(self): + """Efficient one-pass iteration over the result set.""" + while True: + try: + yield self.iterate(False) + except StopIteration: + return + + def fill_cache(self, n=0): + n = n or float('Inf') + if n < 0: + raise ValueError('Negative values are not supported.') + + iterator = ResultIterator(self) + iterator.index = self.count + while not self.populated and (n > self.count): + try: + iterator.next() + except StopIteration: + break + + +class DictCursorWrapper(CursorWrapper): + def _initialize_columns(self): + description = self.cursor.description + self.columns = [t[0][t[0].rfind('.') + 1:].strip('()"`') + for t in description] + self.ncols = len(description) + + initialize = _initialize_columns + + def _row_to_dict(self, row): + result = {} + for i in range(self.ncols): + result.setdefault(self.columns[i], row[i]) # Do not overwrite. 
+ return result + + process_row = _row_to_dict + + +class NamedTupleCursorWrapper(CursorWrapper): + def initialize(self): + description = self.cursor.description + self.tuple_class = collections.namedtuple('Row', [ + t[0][t[0].rfind('.') + 1:].strip('()"`') for t in description]) + + def process_row(self, row): + return self.tuple_class(*row) + + +class ObjectCursorWrapper(DictCursorWrapper): + def __init__(self, cursor, constructor): + super(ObjectCursorWrapper, self).__init__(cursor) + self.constructor = constructor + + def process_row(self, row): + row_dict = self._row_to_dict(row) + return self.constructor(**row_dict) + + +class ResultIterator(object): + def __init__(self, cursor_wrapper): + self.cursor_wrapper = cursor_wrapper + self.index = 0 + + def __iter__(self): + return self + + def next(self): + if self.index < self.cursor_wrapper.count: + obj = self.cursor_wrapper.row_cache[self.index] + elif not self.cursor_wrapper.populated: + self.cursor_wrapper.iterate() + obj = self.cursor_wrapper.row_cache[self.index] + else: + raise StopIteration + self.index += 1 + return obj + + __next__ = next + +# FIELDS + +class FieldAccessor(object): + def __init__(self, model, field, name): + self.model = model + self.field = field + self.name = name + + def __get__(self, instance, instance_type=None): + if instance is not None: + return instance.__data__.get(self.name) + return self.field + + def __set__(self, instance, value): + instance.__data__[self.name] = value + instance._dirty.add(self.name) + + +class ForeignKeyAccessor(FieldAccessor): + def __init__(self, model, field, name): + super(ForeignKeyAccessor, self).__init__(model, field, name) + self.rel_model = field.rel_model + + def get_rel_instance(self, instance): + value = instance.__data__.get(self.name) + if value is not None or self.name in instance.__rel__: + if self.name not in instance.__rel__ and self.field.lazy_load: + obj = self.rel_model.get(self.field.rel_field == value) + instance.__rel__[self.name] 
= obj + return instance.__rel__.get(self.name, value) + elif not self.field.null and self.field.lazy_load: + raise self.rel_model.DoesNotExist + return value + + def __get__(self, instance, instance_type=None): + if instance is not None: + return self.get_rel_instance(instance) + return self.field + + def __set__(self, instance, obj): + if isinstance(obj, self.rel_model): + instance.__data__[self.name] = getattr(obj, self.field.rel_field.name) + instance.__rel__[self.name] = obj + else: + fk_value = instance.__data__.get(self.name) + instance.__data__[self.name] = obj + if (obj != fk_value or obj is None) and \ + self.name in instance.__rel__: + del instance.__rel__[self.name] + instance._dirty.add(self.name) + + +class BackrefAccessor(object): + def __init__(self, field): + self.field = field + self.model = field.rel_model + self.rel_model = field.model + + def __get__(self, instance, instance_type=None): + if instance is not None: + dest = self.field.rel_field.name + return (self.rel_model + .select() + .where(self.field == getattr(instance, dest))) + return self + + +class ObjectIdAccessor(object): + """Gives direct access to the underlying id""" + def __init__(self, field): + self.field = field + + def __get__(self, instance, instance_type=None): + if instance is not None: + value = instance.__data__.get(self.field.name) + # Pull the object-id from the related object if it is not set. 
+ if value is None and self.field.name in instance.__rel__: + rel_obj = instance.__rel__[self.field.name] + value = getattr(rel_obj, self.field.rel_field.name) + return value + return self.field + + def __set__(self, instance, value): + setattr(instance, self.field.name, value) + + +class Field(ColumnBase): + _field_counter = 0 + _order = 0 + accessor_class = FieldAccessor + auto_increment = False + default_index_type = None + field_type = 'DEFAULT' + unpack = True + + def __init__(self, null=False, index=False, unique=False, column_name=None, + default=None, primary_key=False, constraints=None, + sequence=None, collation=None, unindexed=False, choices=None, + help_text=None, verbose_name=None, index_type=None, + db_column=None, _hidden=False): + if db_column is not None: + __deprecated__('"db_column" has been deprecated in favor of ' + '"column_name" for Field objects.') + column_name = db_column + + self.null = null + self.index = index + self.unique = unique + self.column_name = column_name + self.default = default + self.primary_key = primary_key + self.constraints = constraints # List of column constraints. + self.sequence = sequence # Name of sequence, e.g. foo_id_seq. + self.collation = collation + self.unindexed = unindexed + self.choices = choices + self.help_text = help_text + self.verbose_name = verbose_name + self.index_type = index_type or self.default_index_type + self._hidden = _hidden + + # Used internally for recovering the order in which Fields were defined + # on the Model class. + Field._field_counter += 1 + self._order = Field._field_counter + self._sort_key = (self.primary_key and 1 or 2), self._order + + def __hash__(self): + return hash(self.name + '.' 
+ self.model.__name__) + + def __repr__(self): + if hasattr(self, 'model') and getattr(self, 'name', None): + return '<%s: %s.%s>' % (type(self).__name__, + self.model.__name__, + self.name) + return '<%s: (unbound)>' % type(self).__name__ + + def bind(self, model, name, set_attribute=True): + self.model = model + self.name = self.safe_name = name + self.column_name = self.column_name or name + if set_attribute: + setattr(model, name, self.accessor_class(model, self, name)) + + @property + def column(self): + return Column(self.model._meta.table, self.column_name) + + def adapt(self, value): + return value + + def db_value(self, value): + return value if value is None else self.adapt(value) + + def python_value(self, value): + return value if value is None else self.adapt(value) + + def to_value(self, value): + return Value(value, self.db_value, unpack=False) + + def get_sort_key(self, ctx): + return self._sort_key + + def __sql__(self, ctx): + return ctx.sql(self.column) + + def get_modifiers(self): + pass + + def ddl_datatype(self, ctx): + if ctx and ctx.state.field_types: + column_type = ctx.state.field_types.get(self.field_type, + self.field_type) + else: + column_type = self.field_type + + modifiers = self.get_modifiers() + if column_type and modifiers: + modifier_literal = ', '.join([str(m) for m in modifiers]) + return SQL('%s(%s)' % (column_type, modifier_literal)) + else: + return SQL(column_type) + + def ddl(self, ctx): + accum = [Entity(self.column_name)] + data_type = self.ddl_datatype(ctx) + if data_type: + accum.append(data_type) + if self.unindexed: + accum.append(SQL('UNINDEXED')) + if not self.null: + accum.append(SQL('NOT NULL')) + if self.primary_key: + accum.append(SQL('PRIMARY KEY')) + if self.sequence: + accum.append(SQL("DEFAULT NEXTVAL('%s')" % self.sequence)) + if self.constraints: + accum.extend(self.constraints) + if self.collation: + accum.append(SQL('COLLATE %s' % self.collation)) + return NodeList(accum) + + +class AnyField(Field): + 
field_type = 'ANY' + + +class IntegerField(Field): + field_type = 'INT' + + def adapt(self, value): + try: + return int(value) + except ValueError: + return value + + +class BigIntegerField(IntegerField): + field_type = 'BIGINT' + + +class SmallIntegerField(IntegerField): + field_type = 'SMALLINT' + + +class AutoField(IntegerField): + auto_increment = True + field_type = 'AUTO' + + def __init__(self, *args, **kwargs): + if kwargs.get('primary_key') is False: + raise ValueError('%s must always be a primary key.' % type(self)) + kwargs['primary_key'] = True + super(AutoField, self).__init__(*args, **kwargs) + + +class BigAutoField(AutoField): + field_type = 'BIGAUTO' + + +class IdentityField(AutoField): + field_type = 'INT GENERATED BY DEFAULT AS IDENTITY' + + def __init__(self, generate_always=False, **kwargs): + if generate_always: + self.field_type = 'INT GENERATED ALWAYS AS IDENTITY' + super(IdentityField, self).__init__(**kwargs) + + +class PrimaryKeyField(AutoField): + def __init__(self, *args, **kwargs): + __deprecated__('"PrimaryKeyField" has been renamed to "AutoField". 
' + 'Please update your code accordingly as this will be ' + 'completely removed in a subsequent release.') + super(PrimaryKeyField, self).__init__(*args, **kwargs) + + +class FloatField(Field): + field_type = 'FLOAT' + + def adapt(self, value): + try: + return float(value) + except ValueError: + return value + + +class DoubleField(FloatField): + field_type = 'DOUBLE' + + +class DecimalField(Field): + field_type = 'DECIMAL' + + def __init__(self, max_digits=10, decimal_places=5, auto_round=False, + rounding=None, *args, **kwargs): + self.max_digits = max_digits + self.decimal_places = decimal_places + self.auto_round = auto_round + self.rounding = rounding or decimal.DefaultContext.rounding + self._exp = decimal.Decimal(10) ** (-self.decimal_places) + super(DecimalField, self).__init__(*args, **kwargs) + + def get_modifiers(self): + return [self.max_digits, self.decimal_places] + + def db_value(self, value): + D = decimal.Decimal + if not value: + return value if value is None else D(0) + if self.auto_round: + decimal_value = D(text_type(value)) + return decimal_value.quantize(self._exp, rounding=self.rounding) + return value + + def python_value(self, value): + if value is not None: + if isinstance(value, decimal.Decimal): + return value + return decimal.Decimal(text_type(value)) + + +class _StringField(Field): + def adapt(self, value): + if isinstance(value, text_type): + return value + elif isinstance(value, bytes_type): + return value.decode('utf-8') + return text_type(value) + + def __add__(self, other): return StringExpression(self, OP.CONCAT, other) + def __radd__(self, other): return StringExpression(other, OP.CONCAT, self) + + +class CharField(_StringField): + field_type = 'VARCHAR' + + def __init__(self, max_length=255, *args, **kwargs): + self.max_length = max_length + super(CharField, self).__init__(*args, **kwargs) + + def get_modifiers(self): + return self.max_length and [self.max_length] or None + + +class FixedCharField(CharField): + field_type = 
'CHAR' + + def python_value(self, value): + value = super(FixedCharField, self).python_value(value) + if value: + value = value.strip() + return value + + +class TextField(_StringField): + field_type = 'TEXT' + + +class BlobField(Field): + field_type = 'BLOB' + + def _db_hook(self, database): + if database is None: + self._constructor = bytearray + else: + self._constructor = database.get_binary_type() + + def bind(self, model, name, set_attribute=True): + self._constructor = bytearray + if model._meta.database: + if isinstance(model._meta.database, Proxy): + model._meta.database.attach_callback(self._db_hook) + else: + self._db_hook(model._meta.database) + + # Attach a hook to the model metadata; in the event the database is + # changed or set at run-time, we will be sure to apply our callback and + # use the proper data-type for our database driver. + model._meta._db_hooks.append(self._db_hook) + return super(BlobField, self).bind(model, name, set_attribute) + + def db_value(self, value): + if isinstance(value, text_type): + value = value.encode('raw_unicode_escape') + if isinstance(value, bytes_type): + return self._constructor(value) + return value + + +class BitField(BitwiseMixin, BigIntegerField): + def __init__(self, *args, **kwargs): + kwargs.setdefault('default', 0) + super(BitField, self).__init__(*args, **kwargs) + self.__current_flag = 1 + + def flag(self, value=None): + if value is None: + value = self.__current_flag + self.__current_flag <<= 1 + else: + self.__current_flag = value << 1 + + class FlagDescriptor(ColumnBase): + def __init__(self, field, value): + self._field = field + self._value = value + super(FlagDescriptor, self).__init__() + def clear(self): + return self._field.bin_and(~self._value) + def set(self): + return self._field.bin_or(self._value) + def __get__(self, instance, instance_type=None): + if instance is None: + return self + value = getattr(instance, self._field.name) or 0 + return (value & self._value) != 0 + def __set__(self, 
instance, is_set): + if is_set not in (True, False): + raise ValueError('Value must be either True or False') + value = getattr(instance, self._field.name) or 0 + if is_set: + value |= self._value + else: + value &= ~self._value + setattr(instance, self._field.name, value) + def __sql__(self, ctx): + return ctx.sql(self._field.bin_and(self._value) != 0) + return FlagDescriptor(self, value) + + +class BigBitFieldData(object): + def __init__(self, instance, name): + self.instance = instance + self.name = name + value = self.instance.__data__.get(self.name) + if not value: + value = bytearray() + elif not isinstance(value, bytearray): + value = bytearray(value) + self._buffer = self.instance.__data__[self.name] = value + + def _ensure_length(self, idx): + byte_num, byte_offset = divmod(idx, 8) + cur_size = len(self._buffer) + if cur_size <= byte_num: + self._buffer.extend(b'\x00' * ((byte_num + 1) - cur_size)) + return byte_num, byte_offset + + def set_bit(self, idx): + byte_num, byte_offset = self._ensure_length(idx) + self._buffer[byte_num] |= (1 << byte_offset) + + def clear_bit(self, idx): + byte_num, byte_offset = self._ensure_length(idx) + self._buffer[byte_num] &= ~(1 << byte_offset) + + def toggle_bit(self, idx): + byte_num, byte_offset = self._ensure_length(idx) + self._buffer[byte_num] ^= (1 << byte_offset) + return bool(self._buffer[byte_num] & (1 << byte_offset)) + + def is_set(self, idx): + byte_num, byte_offset = self._ensure_length(idx) + return bool(self._buffer[byte_num] & (1 << byte_offset)) + + def __repr__(self): + return repr(self._buffer) + if sys.version_info[0] < 3: + def __str__(self): + return bytes_type(self._buffer) + else: + def __bytes__(self): + return bytes_type(self._buffer) + + +class BigBitFieldAccessor(FieldAccessor): + def __get__(self, instance, instance_type=None): + if instance is None: + return self.field + return BigBitFieldData(instance, self.name) + def __set__(self, instance, value): + if isinstance(value, memoryview): + 
value = value.tobytes() + elif isinstance(value, buffer_type): + value = bytes(value) + elif isinstance(value, bytearray): + value = bytes_type(value) + elif isinstance(value, BigBitFieldData): + value = bytes_type(value._buffer) + elif isinstance(value, text_type): + value = value.encode('utf-8') + elif not isinstance(value, bytes_type): + raise ValueError('Value must be either a bytes, memoryview or ' + 'BigBitFieldData instance.') + super(BigBitFieldAccessor, self).__set__(instance, value) + + +class BigBitField(BlobField): + accessor_class = BigBitFieldAccessor + + def __init__(self, *args, **kwargs): + kwargs.setdefault('default', bytes_type) + super(BigBitField, self).__init__(*args, **kwargs) + + def db_value(self, value): + return bytes_type(value) if value is not None else value + + +class UUIDField(Field): + field_type = 'UUID' + + def db_value(self, value): + if isinstance(value, basestring) and len(value) == 32: + # Hex string. No transformation is necessary. + return value + elif isinstance(value, bytes) and len(value) == 16: + # Allow raw binary representation. + value = uuid.UUID(bytes=value) + if isinstance(value, uuid.UUID): + return value.hex + try: + return uuid.UUID(value).hex + except: + return value + + def python_value(self, value): + if isinstance(value, uuid.UUID): + return value + return uuid.UUID(value) if value is not None else None + + +class BinaryUUIDField(BlobField): + field_type = 'UUIDB' + + def db_value(self, value): + if isinstance(value, bytes) and len(value) == 16: + # Raw binary value. No transformation is necessary. + return self._constructor(value) + elif isinstance(value, basestring) and len(value) == 32: + # Allow hex string representation. 
+ value = uuid.UUID(hex=value) + if isinstance(value, uuid.UUID): + return self._constructor(value.bytes) + elif value is not None: + raise ValueError('value for binary UUID field must be UUID(), ' + 'a hexadecimal string, or a bytes object.') + + def python_value(self, value): + if isinstance(value, uuid.UUID): + return value + elif isinstance(value, memoryview): + value = value.tobytes() + elif value and not isinstance(value, bytes): + value = bytes(value) + return uuid.UUID(bytes=value) if value is not None else None + + +def _date_part(date_part): + def dec(self): + return self.model._meta.database.extract_date(date_part, self) + return dec + +def format_date_time(value, formats, post_process=None): + post_process = post_process or (lambda x: x) + for fmt in formats: + try: + return post_process(datetime.datetime.strptime(value, fmt)) + except ValueError: + pass + return value + +def simple_date_time(value): + try: + return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S') + except (TypeError, ValueError): + return value + + +class _BaseFormattedField(Field): + formats = None + + def __init__(self, formats=None, *args, **kwargs): + if formats is not None: + self.formats = formats + super(_BaseFormattedField, self).__init__(*args, **kwargs) + + +class DateTimeField(_BaseFormattedField): + field_type = 'DATETIME' + formats = [ + '%Y-%m-%d %H:%M:%S.%f', + '%Y-%m-%d %H:%M:%S', + '%Y-%m-%d', + ] + + def adapt(self, value): + if value and isinstance(value, basestring): + return format_date_time(value, self.formats) + return value + + def to_timestamp(self): + return self.model._meta.database.to_timestamp(self) + + def truncate(self, part): + return self.model._meta.database.truncate_date(part, self) + + year = property(_date_part('year')) + month = property(_date_part('month')) + day = property(_date_part('day')) + hour = property(_date_part('hour')) + minute = property(_date_part('minute')) + second = property(_date_part('second')) + + +class 
DateField(_BaseFormattedField): + field_type = 'DATE' + formats = [ + '%Y-%m-%d', + '%Y-%m-%d %H:%M:%S', + '%Y-%m-%d %H:%M:%S.%f', + ] + + def adapt(self, value): + if value and isinstance(value, basestring): + pp = lambda x: x.date() + return format_date_time(value, self.formats, pp) + elif value and isinstance(value, datetime.datetime): + return value.date() + return value + + def to_timestamp(self): + return self.model._meta.database.to_timestamp(self) + + def truncate(self, part): + return self.model._meta.database.truncate_date(part, self) + + year = property(_date_part('year')) + month = property(_date_part('month')) + day = property(_date_part('day')) + + +class TimeField(_BaseFormattedField): + field_type = 'TIME' + formats = [ + '%H:%M:%S.%f', + '%H:%M:%S', + '%H:%M', + '%Y-%m-%d %H:%M:%S.%f', + '%Y-%m-%d %H:%M:%S', + ] + + def adapt(self, value): + if value: + if isinstance(value, basestring): + pp = lambda x: x.time() + return format_date_time(value, self.formats, pp) + elif isinstance(value, datetime.datetime): + return value.time() + if value is not None and isinstance(value, datetime.timedelta): + return (datetime.datetime.min + value).time() + return value + + hour = property(_date_part('hour')) + minute = property(_date_part('minute')) + second = property(_date_part('second')) + + +def _timestamp_date_part(date_part): + def dec(self): + db = self.model._meta.database + expr = ((self / Value(self.resolution, converter=False)) + if self.resolution > 1 else self) + return db.extract_date(date_part, db.from_timestamp(expr)) + return dec + + +class TimestampField(BigIntegerField): + # Support second -> microsecond resolution. 
+ valid_resolutions = [10**i for i in range(7)] + + def __init__(self, *args, **kwargs): + self.resolution = kwargs.pop('resolution', None) + + if not self.resolution: + self.resolution = 1 + elif self.resolution in range(2, 7): + self.resolution = 10 ** self.resolution + elif self.resolution not in self.valid_resolutions: + raise ValueError('TimestampField resolution must be one of: %s' % + ', '.join(str(i) for i in self.valid_resolutions)) + self.ticks_to_microsecond = 1000000 // self.resolution + + self.utc = kwargs.pop('utc', False) or False + dflt = datetime.datetime.utcnow if self.utc else datetime.datetime.now + kwargs.setdefault('default', dflt) + super(TimestampField, self).__init__(*args, **kwargs) + + def local_to_utc(self, dt): + # Convert naive local datetime into naive UTC, e.g.: + # 2019-03-01T12:00:00 (local=US/Central) -> 2019-03-01T18:00:00. + # 2019-05-01T12:00:00 (local=US/Central) -> 2019-05-01T17:00:00. + # 2019-03-01T12:00:00 (local=UTC) -> 2019-03-01T12:00:00. + return datetime.datetime(*time.gmtime(time.mktime(dt.timetuple()))[:6]) + + def utc_to_local(self, dt): + # Convert a naive UTC datetime into local time, e.g.: + # 2019-03-01T18:00:00 (local=US/Central) -> 2019-03-01T12:00:00. + # 2019-05-01T17:00:00 (local=US/Central) -> 2019-05-01T12:00:00. + # 2019-03-01T12:00:00 (local=UTC) -> 2019-03-01T12:00:00. + ts = calendar.timegm(dt.utctimetuple()) + return datetime.datetime.fromtimestamp(ts) + + def get_timestamp(self, value): + if self.utc: + # If utc-mode is on, then we assume all naive datetimes are in UTC. 
+ return calendar.timegm(value.utctimetuple()) + else: + return time.mktime(value.timetuple()) + + def db_value(self, value): + if value is None: + return + + if isinstance(value, datetime.datetime): + pass + elif isinstance(value, datetime.date): + value = datetime.datetime(value.year, value.month, value.day) + else: + return int(round(value * self.resolution)) + + timestamp = self.get_timestamp(value) + if self.resolution > 1: + timestamp += (value.microsecond * .000001) + timestamp *= self.resolution + return int(round(timestamp)) + + def python_value(self, value): + if value is not None and isinstance(value, (int, float, long)): + if self.resolution > 1: + value, ticks = divmod(value, self.resolution) + microseconds = int(ticks * self.ticks_to_microsecond) + else: + microseconds = 0 + + if self.utc: + value = datetime.datetime.utcfromtimestamp(value) + else: + value = datetime.datetime.fromtimestamp(value) + + if microseconds: + value = value.replace(microsecond=microseconds) + + return value + + def from_timestamp(self): + expr = ((self / Value(self.resolution, converter=False)) + if self.resolution > 1 else self) + return self.model._meta.database.from_timestamp(expr) + + year = property(_timestamp_date_part('year')) + month = property(_timestamp_date_part('month')) + day = property(_timestamp_date_part('day')) + hour = property(_timestamp_date_part('hour')) + minute = property(_timestamp_date_part('minute')) + second = property(_timestamp_date_part('second')) + + +class IPField(BigIntegerField): + def db_value(self, val): + if val is not None: + return struct.unpack('!I', socket.inet_aton(val))[0] + + def python_value(self, val): + if val is not None: + return socket.inet_ntoa(struct.pack('!I', val)) + + +class BooleanField(Field): + field_type = 'BOOL' + adapt = bool + + +class BareField(Field): + def __init__(self, adapt=None, *args, **kwargs): + super(BareField, self).__init__(*args, **kwargs) + if adapt is not None: + self.adapt = adapt + + def 
ddl_datatype(self, ctx): + return + + +class ForeignKeyField(Field): + accessor_class = ForeignKeyAccessor + backref_accessor_class = BackrefAccessor + + def __init__(self, model, field=None, backref=None, on_delete=None, + on_update=None, deferrable=None, _deferred=None, + rel_model=None, to_field=None, object_id_name=None, + lazy_load=True, constraint_name=None, related_name=None, + *args, **kwargs): + kwargs.setdefault('index', True) + + super(ForeignKeyField, self).__init__(*args, **kwargs) + + if rel_model is not None: + __deprecated__('"rel_model" has been deprecated in favor of ' + '"model" for ForeignKeyField objects.') + model = rel_model + if to_field is not None: + __deprecated__('"to_field" has been deprecated in favor of ' + '"field" for ForeignKeyField objects.') + field = to_field + if related_name is not None: + __deprecated__('"related_name" has been deprecated in favor of ' + '"backref" for Field objects.') + backref = related_name + + self._is_self_reference = model == 'self' + self.rel_model = model + self.rel_field = field + self.declared_backref = backref + self.backref = None + self.on_delete = on_delete + self.on_update = on_update + self.deferrable = deferrable + self.deferred = _deferred + self.object_id_name = object_id_name + self.lazy_load = lazy_load + self.constraint_name = constraint_name + + @property + def field_type(self): + if not isinstance(self.rel_field, AutoField): + return self.rel_field.field_type + elif isinstance(self.rel_field, BigAutoField): + return BigIntegerField.field_type + return IntegerField.field_type + + def get_modifiers(self): + if not isinstance(self.rel_field, AutoField): + return self.rel_field.get_modifiers() + return super(ForeignKeyField, self).get_modifiers() + + def adapt(self, value): + return self.rel_field.adapt(value) + + def db_value(self, value): + if isinstance(value, self.rel_model): + value = getattr(value, self.rel_field.name) + return self.rel_field.db_value(value) + + def 
python_value(self, value): + if isinstance(value, self.rel_model): + return value + return self.rel_field.python_value(value) + + def bind(self, model, name, set_attribute=True): + if not self.column_name: + self.column_name = name if name.endswith('_id') else name + '_id' + if not self.object_id_name: + self.object_id_name = self.column_name + if self.object_id_name == name: + self.object_id_name += '_id' + elif self.object_id_name == name: + raise ValueError('ForeignKeyField "%s"."%s" specifies an ' + 'object_id_name that conflicts with its field ' + 'name.' % (model._meta.name, name)) + if self._is_self_reference: + self.rel_model = model + if isinstance(self.rel_field, basestring): + self.rel_field = getattr(self.rel_model, self.rel_field) + elif self.rel_field is None: + self.rel_field = self.rel_model._meta.primary_key + + # Bind field before assigning backref, so field is bound when + # calling declared_backref() (if callable). + super(ForeignKeyField, self).bind(model, name, set_attribute) + self.safe_name = self.object_id_name + + if callable_(self.declared_backref): + self.backref = self.declared_backref(self) + else: + self.backref, self.declared_backref = self.declared_backref, None + if not self.backref: + self.backref = '%s_set' % model._meta.name + + if set_attribute: + setattr(model, self.object_id_name, ObjectIdAccessor(self)) + if self.backref not in '!+': + setattr(self.rel_model, self.backref, + self.backref_accessor_class(self)) + + def foreign_key_constraint(self): + parts = [] + if self.constraint_name: + parts.extend((SQL('CONSTRAINT'), Entity(self.constraint_name))) + parts.extend([ + SQL('FOREIGN KEY'), + EnclosedNodeList((self,)), + SQL('REFERENCES'), + self.rel_model, + EnclosedNodeList((self.rel_field,))]) + if self.on_delete: + parts.append(SQL('ON DELETE %s' % self.on_delete)) + if self.on_update: + parts.append(SQL('ON UPDATE %s' % self.on_update)) + if self.deferrable: + parts.append(SQL('DEFERRABLE %s' % self.deferrable)) + return 
NodeList(parts) + + def __getattr__(self, attr): + if attr.startswith('__'): + # Prevent recursion error when deep-copying. + raise AttributeError('Cannot look-up non-existant "__" methods.') + if attr in self.rel_model._meta.fields: + return self.rel_model._meta.fields[attr] + raise AttributeError('Foreign-key has no attribute %s, nor is it a ' + 'valid field on the related model.' % attr) + + +class DeferredForeignKey(Field): + _unresolved = set() + + def __init__(self, rel_model_name, **kwargs): + self.field_kwargs = kwargs + self.rel_model_name = rel_model_name.lower() + DeferredForeignKey._unresolved.add(self) + super(DeferredForeignKey, self).__init__( + column_name=kwargs.get('column_name'), + null=kwargs.get('null'), + primary_key=kwargs.get('primary_key')) + + __hash__ = object.__hash__ + + def __deepcopy__(self, memo=None): + return DeferredForeignKey(self.rel_model_name, **self.field_kwargs) + + def set_model(self, rel_model): + field = ForeignKeyField(rel_model, _deferred=True, **self.field_kwargs) + if field.primary_key: + # NOTE: this calls add_field() under-the-hood. 
+ self.model._meta.set_primary_key(self.name, field) + else: + self.model._meta.add_field(self.name, field) + + @staticmethod + def resolve(model_cls): + unresolved = sorted(DeferredForeignKey._unresolved, + key=operator.attrgetter('_order')) + for dr in unresolved: + if dr.rel_model_name == model_cls.__name__.lower(): + dr.set_model(model_cls) + DeferredForeignKey._unresolved.discard(dr) + + +class DeferredThroughModel(object): + def __init__(self): + self._refs = [] + + def set_field(self, model, field, name): + self._refs.append((model, field, name)) + + def set_model(self, through_model): + for src_model, m2mfield, name in self._refs: + m2mfield.through_model = through_model + src_model._meta.add_field(name, m2mfield) + + +class MetaField(Field): + column_name = default = model = name = None + primary_key = False + + +class ManyToManyFieldAccessor(FieldAccessor): + def __init__(self, model, field, name): + super(ManyToManyFieldAccessor, self).__init__(model, field, name) + self.model = field.model + self.rel_model = field.rel_model + self.through_model = field.through_model + src_fks = self.through_model._meta.model_refs[self.model] + dest_fks = self.through_model._meta.model_refs[self.rel_model] + if not src_fks: + raise ValueError('Cannot find foreign-key to "%s" on "%s" model.' % + (self.model, self.through_model)) + elif not dest_fks: + raise ValueError('Cannot find foreign-key to "%s" on "%s" model.' 
class ManyToManyField(MetaField):
    """Virtual field expressing a many-to-many relationship, mediated by a
    "through" model holding a foreign-key to each side.

    The through model may be supplied explicitly (a Model subclass or a
    DeferredThroughModel); otherwise one is generated on demand by
    ``_create_through_model()``.
    """

    accessor_class = ManyToManyFieldAccessor

    def __init__(self, model, backref=None, through_model=None, on_delete=None,
                 on_update=None, prevent_unsaved=True, _is_backref=False):
        if through_model is not None:
            if not (isinstance(through_model, DeferredThroughModel) or
                    is_model(through_model)):
                raise TypeError('Unexpected value for through_model. Expected '
                                'Model or DeferredThroughModel.')
            # on_delete/on_update only apply to the FKs of the implicitly
            # generated through model, so they conflict with an explicit one.
            if not _is_backref and (on_delete is not None or on_update is not None):
                raise ValueError('Cannot specify on_delete or on_update when '
                                 'through_model is specified.')
        self.rel_model = model
        self.backref = backref
        self._through_model = through_model
        self._on_delete = on_delete
        self._on_update = on_update
        self._prevent_unsaved = prevent_unsaved
        self._is_backref = _is_backref

    def _get_descriptor(self):
        # NOTE(review): ManyToManyFieldAccessor.__init__ takes
        # (model, field, name); this single-argument call looks stale or
        # unused -- confirm before relying on it.
        return ManyToManyFieldAccessor(self)

    def bind(self, model, name, set_attribute=True):
        """Attach this field to ``model`` under ``name``.

        With a DeferredThroughModel, binding is postponed until the through
        model is supplied. For the forward declaration a mirror field
        (``_is_backref=True``) is installed on the related model.
        """
        if isinstance(self._through_model, DeferredThroughModel):
            self._through_model.set_field(model, self, name)
            return

        super(ManyToManyField, self).bind(model, name, set_attribute)

        if not self._is_backref:
            many_to_many_field = ManyToManyField(
                self.model,
                backref=name,
                through_model=self.through_model,
                on_delete=self._on_delete,
                on_update=self._on_update,
                _is_backref=True)
            # Default backref name: source model name + 's'.
            self.backref = self.backref or model._meta.name + 's'
            self.rel_model._meta.add_field(self.backref, many_to_many_field)

    def get_models(self):
        """Return [declaring model, related model] in a stable order
        regardless of which side of the relationship this field is on.

        The boolean sort keys are always (False, True), so the model
        classes themselves are never compared.
        """
        return [model for _, model in sorted((
            (self._is_backref, self.model),
            (not self._is_backref, self.rel_model)))]

    @property
    def through_model(self):
        # Lazily generate the through model on first access.
        if self._through_model is None:
            self._through_model = self._create_through_model()
        return self._through_model

    @through_model.setter
    def through_model(self, value):
        self._through_model = value

    def _create_through_model(self):
        """Generate the intermediary model with a ForeignKeyField to each
        side and a unique composite index over the pair."""
        lhs, rhs = self.get_models()
        tables = [model._meta.table_name for model in (lhs, rhs)]

        class Meta:
            database = self.model._meta.database
            schema = self.model._meta.schema
            table_name = '%s_%s_through' % tuple(tables)
            # Unique together: one row per (lhs, rhs) pair.
            indexes = (
                ((lhs._meta.name, rhs._meta.name),
                 True),)

        params = {'on_delete': self._on_delete, 'on_update': self._on_update}
        attrs = {
            lhs._meta.name: ForeignKeyField(lhs, **params),
            rhs._meta.name: ForeignKeyField(rhs, **params),
            'Meta': Meta}

        klass_name = '%s%sThrough' % (lhs.__name__, rhs.__name__)
        return type(klass_name, (Model,), attrs)

    def get_through_model(self):
        # XXX: Deprecated. Just use the "through_model" property.
        return self.through_model
self.field_names[idx], field_value) + + def __eq__(self, other): + expressions = [(self.model._meta.fields[field] == value) + for field, value in zip(self.field_names, other)] + return reduce(operator.and_, expressions) + + def __ne__(self, other): + return ~(self == other) + + def __hash__(self): + return hash((self.model.__name__, self.field_names)) + + def __sql__(self, ctx): + # If the composite PK is being selected, do not use parens. Elsewhere, + # such as in an expression, we want to use parentheses and treat it as + # a row value. + parens = ctx.scope != SCOPE_SOURCE + return ctx.sql(NodeList([self.model._meta.fields[field] + for field in self.field_names], ', ', parens)) + + def bind(self, model, name, set_attribute=True): + self.model = model + self.column_name = self.name = self.safe_name = name + setattr(model, self.name, self) + + +class _SortedFieldList(object): + __slots__ = ('_keys', '_items') + + def __init__(self): + self._keys = [] + self._items = [] + + def __getitem__(self, i): + return self._items[i] + + def __iter__(self): + return iter(self._items) + + def __contains__(self, item): + k = item._sort_key + i = bisect_left(self._keys, k) + j = bisect_right(self._keys, k) + return item in self._items[i:j] + + def index(self, field): + return self._keys.index(field._sort_key) + + def insert(self, item): + k = item._sort_key + i = bisect_left(self._keys, k) + self._keys.insert(i, k) + self._items.insert(i, item) + + def remove(self, item): + idx = self.index(item) + del self._items[idx] + del self._keys[idx] + + +# MODELS + + +class SchemaManager(object): + def __init__(self, model, database=None, **context_options): + self.model = model + self._database = database + context_options.setdefault('scope', SCOPE_VALUES) + self.context_options = context_options + + @property + def database(self): + db = self._database or self.model._meta.database + if db is None: + raise ImproperlyConfigured('database attribute does not appear to ' + 'be set on the 
    def _create_table(self, safe=True, **options):
        """Build (without executing) the CREATE TABLE context for the model.

        :param safe: emit ``IF NOT EXISTS``.
        :param options: extra table options appended after the column list;
            ``temporary`` is popped here to select CREATE TEMPORARY TABLE.
        :returns: the populated SQL context.
        """
        is_temp = options.pop('temporary', False)
        ctx = self._create_context()
        ctx.literal('CREATE TEMPORARY TABLE ' if is_temp else 'CREATE TABLE ')
        if safe:
            ctx.literal('IF NOT EXISTS ')
        ctx.sql(self.model).literal(' ')

        columns = []
        constraints = []
        meta = self.model._meta
        if meta.composite_key:
            # Composite primary keys are emitted as a table-level
            # PRIMARY KEY (...) constraint rather than per-column.
            pk_columns = [meta.fields[field_name].column
                          for field_name in meta.primary_key.field_names]
            constraints.append(NodeList((SQL('PRIMARY KEY'),
                                         EnclosedNodeList(pk_columns))))

        for field in meta.sorted_fields:
            columns.append(field.ddl(ctx))
            # Deferred FK constraints are not inlined here; they can be
            # added after creation (see _create_foreign_key).
            if isinstance(field, ForeignKeyField) and not field.deferred:
                constraints.append(field.foreign_key_constraint())

        if meta.constraints:
            constraints.extend(meta.constraints)

        constraints.extend(self._create_table_option_sql(options))
        ctx.sql(EnclosedNodeList(columns + constraints))

        # Raw table settings (strings) are appended verbatim.
        if meta.table_settings is not None:
            table_settings = ensure_tuple(meta.table_settings)
            for setting in table_settings:
                if not isinstance(setting, basestring):
                    raise ValueError('table_settings must be strings')
                ctx.literal(' ').literal(setting)

        # Trailing options (SQLite syntax -- presumably only meaningful
        # there; the flags are plain metadata attributes).
        extra_opts = []
        if meta.strict_tables: extra_opts.append('STRICT')
        if meta.without_rowid: extra_opts.append('WITHOUT ROWID')
        if extra_opts:
            ctx.literal(' %s' % ', '.join(extra_opts))
        return ctx
    def _truncate_table(self, restart_identity=False, cascade=False):
        """Build the TRUNCATE TABLE context for the model.

        Falls back to ``DELETE FROM`` when the database reports no TRUNCATE
        support (``db.truncate_table`` falsy); in that case the
        restart_identity/cascade flags are ignored.
        """
        db = self.database
        if not db.truncate_table:
            return (self._create_context()
                    .literal('DELETE FROM ').sql(self.model))

        ctx = self._create_context().literal('TRUNCATE TABLE ').sql(self.model)
        if restart_identity:
            ctx = ctx.literal(' RESTART IDENTITY')
        if cascade:
            ctx = ctx.literal(' CASCADE')
        return ctx
    def _drop_index(self, index, safe):
        """Build the DROP INDEX context for ``index``.

        ``IF EXISTS`` is emitted only when both requested (``safe``) and
        supported by the database. The index name is schema-qualified when
        the index's table carries a schema.
        """
        statement = 'DROP INDEX '
        if safe and self.database.safe_drop_index:
            statement += 'IF EXISTS '
        if isinstance(index._table, Table) and index._table._schema:
            index_name = Entity(index._table._schema, index._name)
        else:
            index_name = Entity(index._name)
        return (self
                ._create_context()
                .literal(statement)
                .sql(index_name))
    def _create_foreign_key(self, field):
        """Build an ``ALTER TABLE ... ADD CONSTRAINT`` context that adds
        the foreign-key constraint for ``field`` after table creation
        (used for deferred foreign keys).

        The generated name ("fk_<table>_<column>_refs_<table>") is passed
        through _truncate_constraint_name -- presumably to fit database
        identifier-length limits.
        """
        name = 'fk_%s_%s_refs_%s' % (field.model._meta.table_name,
                                     field.column_name,
                                     field.rel_model._meta.table_name)
        return (self
                ._create_context()
                .literal('ALTER TABLE ')
                .sql(field.model)
                .literal(' ADD CONSTRAINT ')
                .sql(Entity(_truncate_constraint_name(name)))
                .literal(' ')
                .sql(field.foreign_key_constraint()))
self.drop_sequences() + + +class Metadata(object): + def __init__(self, model, database=None, table_name=None, indexes=None, + primary_key=None, constraints=None, schema=None, + only_save_dirty=False, depends_on=None, options=None, + db_table=None, table_function=None, table_settings=None, + without_rowid=False, temporary=False, strict_tables=None, + legacy_table_names=True, **kwargs): + if db_table is not None: + __deprecated__('"db_table" has been deprecated in favor of ' + '"table_name" for Models.') + table_name = db_table + self.model = model + self.database = database + + self.fields = {} + self.columns = {} + self.combined = {} + + self._sorted_field_list = _SortedFieldList() + self.sorted_fields = [] + self.sorted_field_names = [] + + self.defaults = {} + self._default_by_name = {} + self._default_dict = {} + self._default_callables = {} + self._default_callable_list = [] + + self.name = model.__name__.lower() + self.table_function = table_function + self.legacy_table_names = legacy_table_names + if not table_name: + table_name = (self.table_function(model) + if self.table_function + else self.make_table_name()) + self.table_name = table_name + self._table = None + + self.indexes = list(indexes) if indexes else [] + self.constraints = constraints + self._schema = schema + self.primary_key = primary_key + self.composite_key = self.auto_increment = None + self.only_save_dirty = only_save_dirty + self.depends_on = depends_on + self.table_settings = table_settings + self.without_rowid = without_rowid + self.strict_tables = strict_tables + self.temporary = temporary + + self.refs = {} + self.backrefs = {} + self.model_refs = collections.defaultdict(list) + self.model_backrefs = collections.defaultdict(list) + self.manytomany = {} + + self.options = options or {} + for key, value in kwargs.items(): + setattr(self, key, value) + self._additional_keys = set(kwargs.keys()) + + # Allow objects to register hooks that are called if the model is bound + # to a different 
    def model_graph(self, refs=True, backrefs=True, depth_first=True):
        """Traverse the foreign-key graph starting from this model.

        :param refs: follow outgoing foreign-keys.
        :param backrefs: follow incoming foreign-keys (back-references).
        :param depth_first: LIFO traversal when True, FIFO otherwise.
        :returns: list of ``(foreign_key_field, model, is_backref)``
            tuples; the first entry is ``(None, self.model, None)``.
        :raises ValueError: if both ``refs`` and ``backrefs`` are False.
        """
        if not refs and not backrefs:
            raise ValueError('One of `refs` or `backrefs` must be True.')

        accum = [(None, self.model, None)]
        seen = set()
        queue = collections.deque((self,))
        # pop() -> stack (depth-first); popleft() -> queue (breadth-first).
        method = queue.pop if depth_first else queue.popleft

        while queue:
            curr = method()
            if curr in seen: continue
            seen.add(curr)

            if refs:
                for fk, model in curr.refs.items():
                    accum.append((fk, model, False))
                    queue.append(model._meta)
            if backrefs:
                for fk, model in curr.backrefs.items():
                    accum.append((fk, model, True))
                    queue.append(model._meta)

        return accum
None + + @property + def schema(self): + return self._schema + + @schema.setter + def schema(self, value): + self._schema = value + del self.table + + @property + def entity(self): + if self._schema: + return Entity(self._schema, self.table_name) + else: + return Entity(self.table_name) + + def _update_sorted_fields(self): + self.sorted_fields = list(self._sorted_field_list) + self.sorted_field_names = [f.name for f in self.sorted_fields] + + def get_rel_for_model(self, model): + if isinstance(model, ModelAlias): + model = model.model + forwardrefs = self.model_refs.get(model, []) + backrefs = self.model_backrefs.get(model, []) + return (forwardrefs, backrefs) + + def add_field(self, field_name, field, set_attribute=True): + if field_name in self.fields: + self.remove_field(field_name) + elif field_name in self.manytomany: + self.remove_manytomany(self.manytomany[field_name]) + + if not isinstance(field, MetaField): + del self.table + field.bind(self.model, field_name, set_attribute) + self.fields[field.name] = field + self.columns[field.column_name] = field + self.combined[field.name] = field + self.combined[field.column_name] = field + + self._sorted_field_list.insert(field) + self._update_sorted_fields() + + if field.default is not None: + # This optimization helps speed up model instance construction. 
+ self.defaults[field] = field.default + if callable_(field.default): + self._default_callables[field] = field.default + self._default_callable_list.append((field.name, + field.default)) + else: + self._default_dict[field] = field.default + self._default_by_name[field.name] = field.default + else: + field.bind(self.model, field_name, set_attribute) + + if isinstance(field, ForeignKeyField): + self.add_ref(field) + elif isinstance(field, ManyToManyField) and field.name: + self.add_manytomany(field) + + def remove_field(self, field_name): + if field_name not in self.fields: + return + + del self.table + original = self.fields.pop(field_name) + del self.columns[original.column_name] + del self.combined[field_name] + try: + del self.combined[original.column_name] + except KeyError: + pass + self._sorted_field_list.remove(original) + self._update_sorted_fields() + + if original.default is not None: + del self.defaults[original] + if self._default_callables.pop(original, None): + for i, (name, _) in enumerate(self._default_callable_list): + if name == field_name: + self._default_callable_list.pop(i) + break + else: + self._default_dict.pop(original, None) + self._default_by_name.pop(original.name, None) + + if isinstance(original, ForeignKeyField): + self.remove_ref(original) + + def set_primary_key(self, name, field): + self.composite_key = isinstance(field, CompositeKey) + self.add_field(name, field) + self.primary_key = field + self.auto_increment = ( + field.auto_increment or + bool(field.sequence)) + + def get_primary_keys(self): + if self.composite_key: + return tuple([self.fields[field_name] + for field_name in self.primary_key.field_names]) + else: + return (self.primary_key,) if self.primary_key is not False else () + + def get_default_dict(self): + dd = self._default_by_name.copy() + for field_name, default in self._default_callable_list: + dd[field_name] = default() + return dd + + def fields_to_index(self): + indexes = [] + for f in self.sorted_fields: + if 
class SubclassAwareMetadata(Metadata):
    """Metadata variant that records every model class it is attached to,
    so an operation can later be applied across all of them."""

    # Intentionally a *shared* class-level list: each model using this
    # metadata class appends itself here.
    models = []

    def __init__(self, model, *args, **kwargs):
        super(SubclassAwareMetadata, self).__init__(model, *args, **kwargs)
        self.models.append(model)

    def map_models(self, fn):
        """Apply callable ``fn`` to every registered model class."""
        for model in self.models:
            fn(model)
k.startswith('_'): + meta_options[k] = v + + pk = getattr(meta, 'primary_key', None) + pk_name = parent_pk = None + + # Inherit any field descriptors by deep copying the underlying field + # into the attrs of the new model, additionally see if the bases define + # inheritable model options and swipe them. + for b in bases: + if not hasattr(b, '_meta'): + continue + + base_meta = b._meta + if parent_pk is None: + parent_pk = deepcopy(base_meta.primary_key) + all_inheritable = cls.inheritable | base_meta._additional_keys + for k in base_meta.__dict__: + if k in all_inheritable and k not in meta_options: + meta_options[k] = base_meta.__dict__[k] + meta_options.setdefault('database', base_meta.database) + meta_options.setdefault('schema', base_meta.schema) + + for (k, v) in b.__dict__.items(): + if k in attrs: continue + + if isinstance(v, FieldAccessor) and not v.field.primary_key: + attrs[k] = deepcopy(v.field) + + sopts = meta_options.pop('schema_options', None) or {} + Meta = meta_options.get('model_metadata_class', Metadata) + Schema = meta_options.get('schema_manager_class', SchemaManager) + + # Construct the new class. + cls = super(ModelBase, cls).__new__(cls, name, bases, attrs, **kwargs) + cls.__data__ = cls.__rel__ = None + + cls._meta = Meta(cls, **meta_options) + cls._schema = Schema(cls, **sopts) + + fields = [] + for key, value in cls.__dict__.items(): + if isinstance(value, Field): + if value.primary_key and pk: + raise ValueError('over-determined primary key %s.' 
% name) + elif value.primary_key: + pk, pk_name = value, key + else: + fields.append((key, value)) + + if pk is None: + if parent_pk is not False: + pk, pk_name = ((parent_pk, parent_pk.name) + if parent_pk is not None else + (AutoField(), 'id')) + else: + pk = False + elif isinstance(pk, CompositeKey): + pk_name = '__composite_key__' + cls._meta.composite_key = True + + if pk is not False: + cls._meta.set_primary_key(pk_name, pk) + + for name, field in fields: + cls._meta.add_field(name, field) + + # Create a repr and error class before finalizing. + if hasattr(cls, '__str__') and '__repr__' not in attrs: + setattr(cls, '__repr__', lambda self: '<%s: %s>' % ( + cls.__name__, self.__str__())) + + exc_name = '%sDoesNotExist' % cls.__name__ + exc_attrs = {'__module__': cls.__module__} + exception_class = type(exc_name, (DoesNotExist,), exc_attrs) + cls.DoesNotExist = exception_class + + # Call validation hook, allowing additional model validation. + cls.validate_model() + DeferredForeignKey.resolve(cls) + return cls + + def __repr__(self): + return '' % self.__name__ + + def __iter__(self): + return iter(self.select()) + + def __getitem__(self, key): + return self.get_by_id(key) + + def __setitem__(self, key, value): + self.set_by_id(key, value) + + def __delitem__(self, key): + self.delete_by_id(key) + + def __contains__(self, key): + try: + self.get_by_id(key) + except self.DoesNotExist: + return False + else: + return True + + def __len__(self): + return self.select().count() + def __bool__(self): return True + __nonzero__ = __bool__ # Python 2. 
class _BoundModelsContext(object):
    """Context manager that temporarily binds a group of models to another
    database, restoring each model's original database on exit.

    ``_exclude=set(self.models)`` is forwarded to ``Model.bind`` --
    presumably so (back)reference rebinding does not recurse into the
    models already handled here; confirm against Model.bind.
    """

    def __init__(self, models, database, bind_refs, bind_backrefs):
        self.models = models
        self.database = database
        self.bind_refs = bind_refs
        self.bind_backrefs = bind_backrefs

    def __enter__(self):
        # Remember each model's current database so __exit__ can restore it.
        self._orig_database = []
        for model in self.models:
            self._orig_database.append(model._meta.database)
            model.bind(self.database, self.bind_refs, self.bind_backrefs,
                       _exclude=set(self.models))
        return self.models

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Rebind each model to the database it had before __enter__.
        for model, db in zip(self.models, self._orig_database):
            model.bind(db, self.bind_refs, self.bind_backrefs,
                       _exclude=set(self.models))
    @classmethod
    def create(cls, **query):
        """Instantiate the model with ``query`` as attribute values, INSERT
        it (``force_insert=True``), and return the saved instance."""
        inst = cls(**query)
        inst.save(force_insert=True)
        return inst
    @classmethod
    def bulk_update(cls, model_list, fields, batch_size=None):
        """UPDATE the given ``fields`` for every instance in ``model_list``.

        Each batch issues a single UPDATE whose SET values are CASE
        expressions keyed on the primary key, filtered with
        ``pk IN (...)``.

        :param model_list: model instances (each must have its primary key
            set, since it drives the CASE/IN clauses).
        :param fields: field instances or field names to update.
        :param batch_size: optional chunk size; by default everything goes
            in one batch.
        :returns: total number of rows updated.
        :raises ValueError: for models with a composite primary key.
        """
        if isinstance(cls._meta.primary_key, CompositeKey):
            raise ValueError('bulk_update() is not supported for models with '
                             'a composite primary key.')

        # First normalize list of fields so all are field instances.
        fields = [cls._meta.fields[f] if isinstance(f, basestring) else f
                  for f in fields]
        # Now collect list of attribute names to use for values; FK fields
        # read the raw id attribute rather than the related instance.
        attrs = [field.object_id_name if isinstance(field, ForeignKeyField)
                 else field.name for field in fields]

        if batch_size is not None:
            batches = chunked(model_list, batch_size)
        else:
            batches = [model_list]

        n = 0
        pk = cls._meta.primary_key

        for batch in batches:
            id_list = [model._pk for model in batch]
            update = {}
            for field, attr in zip(fields, attrs):
                accum = []
                for model in batch:
                    value = getattr(model, attr)
                    # Node values (expressions) pass through unconverted.
                    if not isinstance(value, Node):
                        value = field.to_value(value)
                    accum.append((pk.to_value(model._pk), value))
                case = Case(pk, accum)
                update[field] = case

            n += (cls.update(update)
                  .where(cls._meta.primary_key.in_(id_list))
                  .execute())
        return n
+ if len(query) == 1 and isinstance(query[0], int): + sq = sq.where(cls._meta.primary_key == query[0]) + else: + sq = sq.where(*query) + if filters: + sq = sq.filter(**filters) + return sq.get() + + @classmethod + def get_or_none(cls, *query, **filters): + try: + return cls.get(*query, **filters) + except DoesNotExist: + pass + + @classmethod + def get_by_id(cls, pk): + return cls.get(cls._meta.primary_key == pk) + + @classmethod + def set_by_id(cls, key, value): + if key is None: + return cls.insert(value).execute() + else: + return (cls.update(value) + .where(cls._meta.primary_key == key).execute()) + + @classmethod + def delete_by_id(cls, pk): + return cls.delete().where(cls._meta.primary_key == pk).execute() + + @classmethod + def get_or_create(cls, **kwargs): + defaults = kwargs.pop('defaults', {}) + query = cls.select() + for field, value in kwargs.items(): + query = query.where(getattr(cls, field) == value) + + try: + return query.get(), False + except cls.DoesNotExist: + try: + if defaults: + kwargs.update(defaults) + with cls._meta.database.atomic(): + return cls.create(**kwargs), True + except IntegrityError as exc: + try: + return query.get(), False + except cls.DoesNotExist: + raise exc + + @classmethod + def filter(cls, *dq_nodes, **filters): + return cls.select().filter(*dq_nodes, **filters) + + def get_id(self): + # Using getattr(self, pk-name) could accidentally trigger a query if + # the primary-key is a foreign-key. So we use the safe_name attribute, + # which defaults to the field-name, but will be the object_id_name for + # foreign-key fields. 
        if self._meta.primary_key is not False:
            return getattr(self, self._meta.primary_key.safe_name)

    # "_pk" exposes the primary-key value without FK-dereferencing queries.
    _pk = property(get_id)

    @_pk.setter
    def _pk(self, value):
        setattr(self, self._meta.primary_key.name, value)

    def _pk_expr(self):
        # WHERE expression identifying this row by its primary key.
        return self._meta.primary_key == self._pk

    def _prune_fields(self, field_dict, only):
        # Restrict field_dict to the fields named in "only".
        new_data = {}
        for field in only:
            if isinstance(field, basestring):
                field = self._meta.combined[field]
            if field.name in field_dict:
                new_data[field.name] = field_dict[field.name]
        return new_data

    def _populate_unsaved_relations(self, field_dict):
        # If a related object was assigned but not yet saved when this
        # instance was created, re-resolve its (now available) primary key.
        for foreign_key_field in self._meta.refs:
            foreign_key = foreign_key_field.name
            conditions = (
                foreign_key in field_dict and
                field_dict[foreign_key] is None and
                self.__rel__.get(foreign_key) is not None)
            if conditions:
                setattr(self, foreign_key, getattr(self, foreign_key))
                field_dict[foreign_key] = self.__data__[foreign_key]

    def save(self, force_insert=False, only=None):
        # Persist this instance: UPDATE when a pk value is set (unless
        # force_insert), otherwise INSERT. Returns the number of rows
        # modified (or False when a dirty-only save had nothing to do).
        field_dict = self.__data__.copy()
        if self._meta.primary_key is not False:
            pk_field = self._meta.primary_key
            pk_value = self._pk
        else:
            pk_field = pk_value = None
        if only is not None:
            field_dict = self._prune_fields(field_dict, only)
        elif self._meta.only_save_dirty and not force_insert:
            field_dict = self._prune_fields(field_dict, self.dirty_fields)
            if not field_dict:
                self._dirty.clear()
                return False

        self._populate_unsaved_relations(field_dict)
        rows = 1

        if self._meta.auto_increment and pk_value is None:
            field_dict.pop(pk_field.name, None)

        if pk_value is not None and not force_insert:
            # Never include the primary key in the SET clause.
            if self._meta.composite_key:
                for pk_part_name in pk_field.field_names:
                    field_dict.pop(pk_part_name, None)
            else:
                field_dict.pop(pk_field.name, None)
            if not field_dict:
                raise ValueError('no data to save!')
            rows = self.update(**field_dict).where(self._pk_expr()).execute()
        elif pk_field is not None:
            pk = self.insert(**field_dict).execute()
            if pk is not None and (self._meta.auto_increment or
                                   pk_value is None):
                self._pk = pk
                # Although we set the primary-key, do not mark it as dirty.
                self._dirty.discard(pk_field.name)
        else:
            self.insert(**field_dict).execute()

        self._dirty -= set(field_dict)  # Remove any fields we saved.
        return rows

    def is_dirty(self):
        return bool(self._dirty)

    @property
    def dirty_fields(self):
        # Fields modified since load/save, in schema order.
        return [f for f in self._meta.sorted_fields if f.name in self._dirty]

    def dependencies(self, search_nullable=False):
        # Yield (expression, fk) pairs for every row that references this
        # instance, directly or transitively (breadth via subqueries).
        model_class = type(self)
        stack = [(type(self), None)]
        seen = set()

        while stack:
            klass, query = stack.pop()
            if klass in seen:
                continue
            seen.add(klass)
            for fk, rel_model in klass._meta.backrefs.items():
                if rel_model is model_class or query is None:
                    node = (fk == self.__data__[fk.rel_field.name])
                else:
                    node = fk << query
                subquery = (rel_model.select(rel_model._meta.primary_key)
                            .where(node))
                if not fk.null or search_nullable:
                    stack.append((rel_model, subquery))
                yield (node, fk)

    def delete_instance(self, recursive=False, delete_nullable=False):
        # Delete this row; optionally cascade to (or null-out) dependents.
        if recursive:
            dependencies = self.dependencies(delete_nullable)
            for query, fk in reversed(list(dependencies)):
                model = fk.model
                if fk.null and not delete_nullable:
                    model.update(**{fk.name: None}).where(query).execute()
                else:
                    model.delete().where(query).execute()
        return type(self).delete().where(self._pk_expr()).execute()

    def __hash__(self):
        return hash((self.__class__, self._pk))

    def __eq__(self, other):
        # Two instances are equal when same class and same non-None pk.
        return (
            other.__class__ == self.__class__ and
            self._pk is not None and
            self._pk == other._pk)

    def __ne__(self, other):
        return not self == other

    def __sql__(self, ctx):
        # NOTE: when comparing a foreign-key field whose related-field is not a
        # primary-key, then doing an equality test for the foreign-key with a
        # model instance will return the wrong value; since we would return
        # the primary key for a given model instance.
        #
        # This checks to see if we have a converter in the scope, and that we
        # are converting a foreign-key expression. If so, we hand the model
        # instance to the converter rather than blindly grabbing the primary-
        # key. In the event the provided converter fails to handle the model
        # instance, then we will return the primary-key.
        if ctx.state.converter is not None and ctx.state.is_fk_expr:
            try:
                return ctx.sql(Value(self, converter=ctx.state.converter))
            except (TypeError, ValueError):
                pass

        return ctx.sql(Value(getattr(self, self._meta.primary_key.name),
                             converter=self._meta.primary_key.db_value))

    @classmethod
    def bind(cls, database, bind_refs=True, bind_backrefs=True, _exclude=None):
        # Permanently (re)bind this model - and optionally its graph of
        # related models - to the given database. Returns True if changed.
        is_different = cls._meta.database is not database
        cls._meta.set_database(database)
        if bind_refs or bind_backrefs:
            if _exclude is None:
                _exclude = set()
            G = cls._meta.model_graph(refs=bind_refs, backrefs=bind_backrefs)
            for _, model, is_backref in G:
                if model not in _exclude:
                    model._meta.set_database(database)
                    _exclude.add(model)
        return is_different

    @classmethod
    def bind_ctx(cls, database, bind_refs=True, bind_backrefs=True):
        # Context-manager variant of bind(); restores bindings on exit.
        return _BoundModelsContext((cls,), database, bind_refs, bind_backrefs)

    @classmethod
    def table_exists(cls):
        M = cls._meta
        return cls._schema.database.table_exists(M.table.__name__, M.schema)

    @classmethod
    def create_table(cls, safe=True, **options):
        if 'fail_silently' in options:
            __deprecated__('"fail_silently" has been deprecated in favor of '
                           '"safe" for the create_table() method.')
            safe = options.pop('fail_silently')

        # If the database cannot express IF NOT EXISTS for indexes, check
        # for the table up-front instead.
        if safe and not cls._schema.database.safe_create_index \
           and cls.table_exists():
            return
        if cls._meta.temporary:
            options.setdefault('temporary', cls._meta.temporary)
        cls._schema.create_all(safe, **options)

    @classmethod
    def drop_table(cls, safe=True, drop_sequences=True, **options):
        if safe and not cls._schema.database.safe_drop_index \
           and not
               cls.table_exists():
            return
        if cls._meta.temporary:
            options.setdefault('temporary', cls._meta.temporary)
        cls._schema.drop_all(safe, drop_sequences, **options)

    @classmethod
    def truncate_table(cls, **options):
        # Options are passed through to the schema manager unchanged.
        cls._schema.truncate_table(**options)

    @classmethod
    def index(cls, *fields, **kwargs):
        # Build (but do not register) an index over the given fields.
        return ModelIndex(cls, fields, **kwargs)

    @classmethod
    def add_index(cls, *fields, **kwargs):
        # Register an index on the model metadata; accepts either a single
        # pre-built SQL/Index object or a list of fields to index.
        if len(fields) == 1 and isinstance(fields[0], (SQL, Index)):
            cls._meta.indexes.append(fields[0])
        else:
            cls._meta.indexes.append(ModelIndex(cls, fields, **kwargs))


class ModelAlias(Node):
    """Provide a separate reference to a model in a query."""
    def __init__(self, model, alias=None):
        # Assign via __dict__ because __setattr__ is disabled below.
        self.__dict__['model'] = model
        self.__dict__['alias'] = alias

    def __getattr__(self, attr):
        # Hack to work-around the fact that properties or other objects
        # implementing the descriptor protocol (on the model being aliased),
        # will not work correctly when we use getattr(). So we explicitly pass
        # the model alias to the descriptor's getter.
        try:
            obj = self.model.__dict__[attr]
        except KeyError:
            pass
        else:
            if isinstance(obj, ModelDescriptor):
                return obj.__get__(None, self)

        model_attr = getattr(self.model, attr)
        if isinstance(model_attr, Field):
            # Cache the FieldAlias so repeated access returns the same object.
            self.__dict__[attr] = FieldAlias.create(self, model_attr)
            return self.__dict__[attr]
        return model_attr

    def __setattr__(self, attr, value):
        raise AttributeError('Cannot set attributes on model aliases.')

    def get_field_aliases(self):
        return [getattr(self, n) for n in self.model._meta.sorted_field_names]

    def select(self, *selection):
        if not selection:
            selection = self.get_field_aliases()
        return ModelSelect(self, selection)

    def __call__(self, **kwargs):
        # Instantiating the alias instantiates the underlying model.
        return self.model(**kwargs)

    def __sql__(self, ctx):
        if ctx.scope == SCOPE_VALUES:
            # Return the quoted table name.
            return ctx.sql(self.model)

        if self.alias:
            ctx.alias_manager[self] = self.alias

        if ctx.scope == SCOPE_SOURCE:
            # Define the table and its alias.
            return (ctx
                    .sql(self.model._meta.entity)
                    .literal(' AS ')
                    .sql(Entity(ctx.alias_manager[self])))
        else:
            # Refer to the table using the alias.
            return ctx.sql(Entity(ctx.alias_manager[self]))


class FieldAlias(Field):
    """A field as seen through a ModelAlias; delegates to the real field."""
    def __init__(self, source, field):
        self.source = source
        self.model = source.model
        self.field = field

    @classmethod
    def create(cls, source, field):
        # Subclass the concrete field type so isinstance() checks still work.
        class _FieldAlias(cls, type(field)):
            pass
        return _FieldAlias(source, field)

    def clone(self):
        return FieldAlias(self.source, self.field)

    def adapt(self, value): return self.field.adapt(value)
    def python_value(self, value): return self.field.python_value(value)
    def db_value(self, value): return self.field.db_value(value)
    def __getattr__(self, attr):
        # "model" resolves to the alias; everything else to the real field.
        return self.source if attr == 'model' else getattr(self.field, attr)

    def __sql__(self, ctx):
        return ctx.sql(Column(self.source, self.field.column_name))


def sort_models(models):
    """Topologically sort models so referenced tables come first."""
    models = set(models)
    seen = set()
    ordering = []
    def dfs(model):
        if model in models and model not in seen:
            seen.add(model)
            for foreign_key, rel_model in model._meta.refs.items():
                # Do not depth-first search deferred foreign-keys as this can
                # cause tables to be created in the incorrect order.
                if not foreign_key.deferred:
                    dfs(rel_model)
            if model._meta.depends_on:
                for dependency in model._meta.depends_on:
                    dfs(dependency)
            ordering.append(model)

    # Visit models in a deterministic order so the output is stable.
    names = lambda m: (m._meta.name, m._meta.table_name)
    for m in sorted(models, key=names):
        dfs(m)
    return ordering


class _ModelQueryHelper(object):
    # Mixin binding a query to its model's database and choosing the
    # cursor-wrapper appropriate for the requested row type.
    default_row_type = ROW.MODEL

    def __init__(self, *args, **kwargs):
        super(_ModelQueryHelper, self).__init__(*args, **kwargs)
        if not self._database:
            self._database = self.model._meta.database

    @Node.copy
    def objects(self, constructor=None):
        # Return rows as flat objects (no join-graph reconstruction).
        self._row_type = ROW.CONSTRUCTOR
        self._constructor = self.model if constructor is None else constructor

    def _get_cursor_wrapper(self, cursor):
        row_type = self._row_type or self.default_row_type
        if row_type == ROW.MODEL:
            return self._get_model_cursor_wrapper(cursor)
        elif row_type == ROW.DICT:
            return ModelDictCursorWrapper(cursor, self.model, self._returning)
        elif row_type == ROW.TUPLE:
            return ModelTupleCursorWrapper(cursor, self.model, self._returning)
        elif row_type == ROW.NAMED_TUPLE:
            return ModelNamedTupleCursorWrapper(cursor, self.model,
                                                self._returning)
        elif row_type == ROW.CONSTRUCTOR:
            return ModelObjectCursorWrapper(cursor, self.model,
                                            self._returning, self._constructor)
        else:
            raise ValueError('Unrecognized row type: "%s".' % row_type)

    def _get_model_cursor_wrapper(self, cursor):
        return ModelObjectCursorWrapper(cursor, self.model, [], self.model)


class ModelRaw(_ModelQueryHelper, RawQuery):
    """Raw SQL query whose rows are mapped onto a model class."""
    def __init__(self, model, sql, params, **kwargs):
        self.model = model
        self._returning = ()
        super(ModelRaw, self).__init__(sql=sql, params=params, **kwargs)

    def get(self):
        try:
            return self.execute()[0]
        except IndexError:
            sql, params = self.sql()
            raise self.model.DoesNotExist('%s instance matching query does '
                                          'not exist:\nSQL: %s\nParams: %s' %
                                          (self.model, sql, params))


class BaseModelSelect(_ModelQueryHelper):
    """Shared SELECT behavior: compound queries, iteration, get()."""
    def union_all(self, rhs):
        return ModelCompoundSelectQuery(self.model, self, 'UNION ALL', rhs)
    __add__ = union_all

    def union(self, rhs):
        return ModelCompoundSelectQuery(self.model, self, 'UNION', rhs)
    __or__ = union

    def intersect(self, rhs):
        return ModelCompoundSelectQuery(self.model, self, 'INTERSECT', rhs)
    __and__ = intersect

    def except_(self, rhs):
        return ModelCompoundSelectQuery(self.model, self, 'EXCEPT', rhs)
    __sub__ = except_

    def __iter__(self):
        # Execute lazily on first iteration; results are cached on the query.
        if not self._cursor_wrapper:
            self.execute()
        return iter(self._cursor_wrapper)

    def prefetch(self, *subqueries, **kwargs):
        return prefetch(self, *subqueries, **kwargs)

    def get(self, database=None):
        # Fetch exactly one row (LIMIT 1) or raise DoesNotExist.
        clone = self.paginate(1, 1)
        clone._cursor_wrapper = None
        try:
            return clone.execute(database)[0]
        except IndexError:
            sql, params = clone.sql()
            raise self.model.DoesNotExist('%s instance matching query does '
                                          'not exist:\nSQL: %s\nParams: %s' %
                                          (clone.model, sql, params))

    def get_or_none(self, database=None):
        # Like get(), but returns None instead of raising DoesNotExist.
        try:
            return self.get(database=database)
        except self.model.DoesNotExist:
            pass

    @Node.copy
    def group_by(self, *columns):
        grouping = []
        for column in columns:
            if is_model(column):
                # Grouping by a model groups by all of its fields.
                grouping.extend(column._meta.sorted_fields)
            elif isinstance(column, Table):
                if not column._columns:
                    raise ValueError('Cannot pass a table
 to group_by() that '
                                     'does not have columns explicitly '
                                     'declared.')
                grouping.extend([getattr(column, col_name)
                                 for col_name in column._columns])
            else:
                grouping.append(column)
        self._group_by = grouping


class ModelCompoundSelectQuery(BaseModelSelect, CompoundSelectQuery):
    """Compound (UNION/INTERSECT/EXCEPT) query bound to a model."""
    def __init__(self, model, *args, **kwargs):
        self.model = model
        super(ModelCompoundSelectQuery, self).__init__(*args, **kwargs)

    def _get_model_cursor_wrapper(self, cursor):
        # Row construction is delegated to the left-hand side query.
        return self.lhs._get_model_cursor_wrapper(cursor)


def _normalize_model_select(fields_or_models):
    # Expand models, aliases and tables in a selection into concrete fields.
    fields = []
    for fm in fields_or_models:
        if is_model(fm):
            fields.extend(fm._meta.sorted_fields)
        elif isinstance(fm, ModelAlias):
            fields.extend(fm.get_field_aliases())
        elif isinstance(fm, Table) and fm._columns:
            fields.extend([getattr(fm, col) for col in fm._columns])
        else:
            fields.append(fm)
    return fields


class ModelSelect(BaseModelSelect, Select):
    """SELECT query on a model, tracking the join graph for reconstruction."""
    def __init__(self, model, fields_or_models, is_default=False):
        self.model = self._join_ctx = model
        self._joins = {}  # source -> [(dest, attr, constructor, join_type)]
        self._is_default = is_default
        fields = _normalize_model_select(fields_or_models)
        super(ModelSelect, self).__init__([model], fields)

    def clone(self):
        clone = super(ModelSelect, self).clone()
        if clone._joins:
            # Shallow-copy the join map so the clone can diverge safely.
            clone._joins = dict(clone._joins)
        return clone

    def select(self, *fields_or_models):
        if fields_or_models or not self._is_default:
            self._is_default = False
            fields = _normalize_model_select(fields_or_models)
            return super(ModelSelect, self).select(*fields)
        return self

    def select_extend(self, *columns):
        self._is_default = False
        fields = _normalize_model_select(columns)
        return super(ModelSelect, self).select_extend(*fields)

    def switch(self, ctx=None):
        # Change the join context (source of subsequent joins).
        self._join_ctx = self.model if ctx is None else ctx
        return self

    def _get_model(self, src):
        # Resolve a join endpoint to (model, is_model_class).
        if is_model(src):
            return src, True
        elif isinstance(src, Table) and src._model:
            return src._model, False
        elif
 isinstance(src, ModelAlias):
            return src.model, False
        elif isinstance(src, ModelSelect):
            return src.model, False
        return None, False

    def _normalize_join(self, src, dest, on, attr):
        # Resolve a join into (on-expression, attribute-name, constructor).
        # Allow "on" expression to have an alias that determines the
        # destination attribute for the joined data.
        on_alias = isinstance(on, Alias)
        if on_alias:
            attr = attr or on._alias
            on = on.alias()

        # Obtain references to the source and destination models being joined.
        src_model, src_is_model = self._get_model(src)
        dest_model, dest_is_model = self._get_model(dest)

        if src_model and dest_model:
            self._join_ctx = dest
            constructor = dest_model

            # In the case where the "on" clause is a Column or Field, we will
            # convert that field into the appropriate predicate expression.
            if not (src_is_model and dest_is_model) and isinstance(on, Column):
                if on.source is src:
                    to_field = src_model._meta.columns[on.name]
                elif on.source is dest:
                    to_field = dest_model._meta.columns[on.name]
                else:
                    raise AttributeError('"on" clause Column %s does not '
                                         'belong to %s or %s.' %
                                         (on, src_model, dest_model))
                on = None
            elif isinstance(on, Field):
                to_field = on
                on = None
            else:
                to_field = None

            fk_field, is_backref = self._generate_on_clause(
                src_model, dest_model, to_field, on)

            if on is None:
                # Build an equality predicate from the discovered FK pair.
                src_attr = 'name' if src_is_model else 'column_name'
                dest_attr = 'name' if dest_is_model else 'column_name'
                if is_backref:
                    lhs = getattr(dest, getattr(fk_field, dest_attr))
                    rhs = getattr(src, getattr(fk_field.rel_field, src_attr))
                else:
                    lhs = getattr(src, getattr(fk_field, src_attr))
                    rhs = getattr(dest, getattr(fk_field.rel_field, dest_attr))
                on = (lhs == rhs)

            if not attr:
                if fk_field is not None and not is_backref:
                    attr = fk_field.name
                else:
                    attr = dest_model._meta.name
            elif on_alias and fk_field is not None and \
                 attr == fk_field.object_id_name and not is_backref:
                raise ValueError('Cannot assign join alias to "%s", as this '
                                 'attribute is the object_id_name for the '
                                 'foreign-key field "%s"' % (attr, fk_field))

        elif isinstance(dest, Source):
            # Joining to a bare source (e.g. subquery/table): rows land in
            # a dict keyed by the alias or table name.
            constructor = dict
            attr = attr or dest._alias
            if not attr and isinstance(dest, Table):
                attr = attr or dest.__name__

        return (on, attr, constructor)

    def _generate_on_clause(self, src, dest, to_field=None, on=None):
        # Find the FK linking src and dest; returns (fk_field, is_backref).
        meta = src._meta
        is_backref = fk_fields = False

        # Get all the foreign keys between source and dest, and determine if
        # the join is via a back-reference.
        if dest in meta.model_refs:
            fk_fields = meta.model_refs[dest]
        elif dest in meta.model_backrefs:
            fk_fields = meta.model_backrefs[dest]
            is_backref = True

        if not fk_fields:
            if on is not None:
                # No FK, but an explicit join condition was supplied - OK.
                return None, False
            raise ValueError('Unable to find foreign key between %s and %s. '
                             'Please specify an explicit join condition.' %
                             (src, dest))
        elif to_field is not None:
            # If the foreign-key field was specified explicitly, remove all
            # other foreign-key fields from the list.
            target = (to_field.field if isinstance(to_field, FieldAlias)
                      else to_field)
            fk_fields = [f for f in fk_fields if (
                         (f is target) or
                         (is_backref and f.rel_field is to_field))]

        if len(fk_fields) == 1:
            return fk_fields[0], is_backref

        if on is None:
            # If multiple foreign-keys exist, try using the FK whose name
            # matches that of the related model. If not, raise an error as this
            # is ambiguous.
            for fk in fk_fields:
                if fk.name == dest._meta.name:
                    return fk, is_backref

            raise ValueError('More than one foreign key between %s and %s.'
                             ' Please specify which you are joining on.' %
                             (src, dest))

        # If there are multiple foreign-keys to choose from and the join
        # predicate is an expression, we'll try to figure out which
        # foreign-key field we're joining on so that we can assign to the
        # correct attribute when resolving the model graph.
        to_field = None
        if isinstance(on, Expression):
            lhs, rhs = on.lhs, on.rhs
            # Coerce to set() so that we force Python to compare using the
            # object's hash rather than equality test, which returns a
            # false-positive due to overriding __eq__.
+ fk_set = set(fk_fields) + + if isinstance(lhs, Field): + lhs_f = lhs.field if isinstance(lhs, FieldAlias) else lhs + if lhs_f in fk_set: + to_field = lhs_f + elif isinstance(rhs, Field): + rhs_f = rhs.field if isinstance(rhs, FieldAlias) else rhs + if rhs_f in fk_set: + to_field = rhs_f + + return to_field, False + + @Node.copy + def join(self, dest, join_type=JOIN.INNER, on=None, src=None, attr=None): + src = self._join_ctx if src is None else src + + if join_type == JOIN.LATERAL or join_type == JOIN.LEFT_LATERAL: + on = True + elif join_type != JOIN.CROSS: + on, attr, constructor = self._normalize_join(src, dest, on, attr) + if attr: + self._joins.setdefault(src, []) + self._joins[src].append((dest, attr, constructor, join_type)) + elif on is not None: + raise ValueError('Cannot specify on clause with cross join.') + + if not self._from_list: + raise ValueError('No sources to join on.') + + item = self._from_list.pop() + self._from_list.append(Join(item, dest, join_type, on)) + + def left_outer_join(self, dest, on=None, src=None, attr=None): + return self.join(dest, JOIN.LEFT_OUTER, on, src, attr) + + def join_from(self, src, dest, join_type=JOIN.INNER, on=None, attr=None): + return self.join(dest, join_type, on, src, attr) + + def _get_model_cursor_wrapper(self, cursor): + if len(self._from_list) == 1 and not self._joins: + return ModelObjectCursorWrapper(cursor, self.model, + self._returning, self.model) + return ModelCursorWrapper(cursor, self.model, self._returning, + self._from_list, self._joins) + + def ensure_join(self, lm, rm, on=None, **join_kwargs): + join_ctx = self._join_ctx + for dest, _, constructor, _ in self._joins.get(lm, []): + if dest == rm: + return self + return self.switch(lm).join(rm, on=on, **join_kwargs).switch(join_ctx) + + def convert_dict_to_node(self, qdict): + accum = [] + joins = [] + fks = (ForeignKeyField, BackrefAccessor) + for key, value in sorted(qdict.items()): + curr = self.model + if '__' in key and key.rsplit('__', 1)[1] 
in DJANGO_MAP: + key, op = key.rsplit('__', 1) + op = DJANGO_MAP[op] + elif value is None: + op = DJANGO_MAP['is'] + else: + op = DJANGO_MAP['eq'] + + if '__' not in key: + # Handle simplest case. This avoids joining over-eagerly when a + # direct FK lookup is all that is required. + model_attr = getattr(curr, key) + else: + for piece in key.split('__'): + for dest, attr, _, _ in self._joins.get(curr, ()): + try: model_attr = getattr(curr, piece, None) + except: pass + if attr == piece or (isinstance(dest, ModelAlias) and + dest.alias == piece): + curr = dest + break + else: + model_attr = getattr(curr, piece) + if value is not None and isinstance(model_attr, fks): + curr = model_attr.rel_model + joins.append(model_attr) + accum.append(op(model_attr, value)) + return accum, joins + + def filter(self, *args, **kwargs): + # normalize args and kwargs into a new expression + if args and kwargs: + dq_node = (reduce(operator.and_, [a.clone() for a in args]) & + DQ(**kwargs)) + elif args: + dq_node = (reduce(operator.and_, [a.clone() for a in args]) & + ColumnBase()) + elif kwargs: + dq_node = DQ(**kwargs) & ColumnBase() + else: + return self.clone() + + # dq_node should now be an Expression, lhs = Node(), rhs = ... + q = collections.deque([dq_node]) + dq_joins = [] + seen_joins = set() + while q: + curr = q.popleft() + if not isinstance(curr, Expression): + continue + for side, piece in (('lhs', curr.lhs), ('rhs', curr.rhs)): + if isinstance(piece, DQ): + query, joins = self.convert_dict_to_node(piece.query) + for join in joins: + if join not in seen_joins: + dq_joins.append(join) + seen_joins.add(join) + expression = reduce(operator.and_, query) + # Apply values from the DQ object. 
+ if piece._negated: + expression = Negated(expression) + #expression._alias = piece._alias + setattr(curr, side, expression) + else: + q.append(piece) + + if not args or not kwargs: + dq_node = dq_node.lhs + + query = self.clone() + for field in dq_joins: + if isinstance(field, ForeignKeyField): + lm, rm = field.model, field.rel_model + field_obj = field + elif isinstance(field, BackrefAccessor): + lm, rm = field.model, field.rel_model + field_obj = field.field + query = query.ensure_join(lm, rm, field_obj) + return query.where(dq_node) + + def create_table(self, name, safe=True, **meta): + return self.model._schema.create_table_as(name, self, safe, **meta) + + def __sql_selection__(self, ctx, is_subquery=False): + if self._is_default and is_subquery and len(self._returning) > 1 and \ + self.model._meta.primary_key is not False: + return ctx.sql(self.model._meta.primary_key) + + return ctx.sql(CommaNodeList(self._returning)) + + +class NoopModelSelect(ModelSelect): + def __sql__(self, ctx): + return self.model._meta.database.get_noop_select(ctx) + + def _get_cursor_wrapper(self, cursor): + return CursorWrapper(cursor) + + +class _ModelWriteQueryHelper(_ModelQueryHelper): + def __init__(self, model, *args, **kwargs): + self.model = model + super(_ModelWriteQueryHelper, self).__init__(model, *args, **kwargs) + + def returning(self, *returning): + accum = [] + for item in returning: + if is_model(item): + accum.extend(item._meta.sorted_fields) + else: + accum.append(item) + return super(_ModelWriteQueryHelper, self).returning(*accum) + + def _set_table_alias(self, ctx): + table = self.model._meta.table + ctx.alias_manager[table] = table.__name__ + + +class ModelUpdate(_ModelWriteQueryHelper, Update): + pass + + +class ModelInsert(_ModelWriteQueryHelper, Insert): + default_row_type = ROW.TUPLE + + def __init__(self, *args, **kwargs): + super(ModelInsert, self).__init__(*args, **kwargs) + if self._returning is None and self.model._meta.database is not None: + if 
               self.model._meta.database.returning_clause:
                self._returning = self.model._meta.get_primary_keys()

    def returning(self, *returning):
        # By default ModelInsert will yield a `tuple` containing the
        # primary-key of the newly inserted row. But if we are explicitly
        # specifying a returning clause and have not set a row type, we will
        # default to returning model instances instead.
        if returning and self._row_type is None:
            self._row_type = ROW.MODEL
        return super(ModelInsert, self).returning(*returning)

    def get_default_data(self):
        return self.model._meta.defaults

    def get_default_columns(self):
        # Skip the auto-incrementing primary key (always the first field).
        fields = self.model._meta.sorted_fields
        return fields[1:] if self.model._meta.auto_increment else fields


class ModelDelete(_ModelWriteQueryHelper, Delete):
    pass


class ManyToManyQuery(ModelSelect):
    """Select over the far side of a many-to-many accessor, with helpers
    to add/remove/clear rows in the through-model."""
    def __init__(self, instance, accessor, rel, *args, **kwargs):
        self._instance = instance
        self._accessor = accessor
        self._src_attr = accessor.src_fk.rel_field.name
        self._dest_attr = accessor.dest_fk.rel_field.name
        super(ManyToManyQuery, self).__init__(rel, (rel,), *args, **kwargs)

    def _id_list(self, model_or_id_list):
        # Accept either model instances or raw id values.
        if isinstance(model_or_id_list[0], Model):
            return [getattr(obj, self._dest_attr) for obj in model_or_id_list]
        return model_or_id_list

    def add(self, value, clear_existing=False):
        if clear_existing:
            self.clear()

        accessor = self._accessor
        src_id = getattr(self._instance, self._src_attr)
        if isinstance(value, SelectQuery):
            # INSERT INTO through (src, dest) SELECT <src_id>, dest FROM ...
            query = value.columns(
                Value(src_id),
                accessor.dest_fk.rel_field)
            accessor.through_model.insert_from(
                fields=[accessor.src_fk, accessor.dest_fk],
                query=query).execute()
        else:
            value = ensure_tuple(value)
            if not value: return

            inserts = [{
                accessor.src_fk.name: src_id,
                accessor.dest_fk.name: rel_id}
                for rel_id in self._id_list(value)]
            accessor.through_model.insert_many(inserts).execute()

    def remove(self, value):
        src_id = getattr(self._instance, self._src_attr)
        if isinstance(value, SelectQuery):
            # Delete through-rows whose dest is in the subquery.
            column = getattr(value.model, self._dest_attr)
            subquery = value.columns(column)
            return (self._accessor.through_model
                    .delete()
                    .where(
                        (self._accessor.dest_fk << subquery) &
                        (self._accessor.src_fk == src_id))
                    .execute())
        else:
            value = ensure_tuple(value)
            if not value:
                return
            return (self._accessor.through_model
                    .delete()
                    .where(
                        (self._accessor.dest_fk << self._id_list(value)) &
                        (self._accessor.src_fk == src_id))
                    .execute())

    def clear(self):
        # Delete every through-model row belonging to this instance.
        src_id = getattr(self._instance, self._src_attr)
        return (self._accessor.through_model
                .delete()
                .where(self._accessor.src_fk == src_id)
                .execute())


def safe_python_value(conv_func):
    """Wrap conv_func so conversion failures return the raw value."""
    def validate(value):
        try:
            return conv_func(value)
        except (TypeError, ValueError):
            return value
    return validate


class BaseModelCursorWrapper(DictCursorWrapper):
    # Maps raw cursor rows to model-aware values: names, converters, fields.
    def __init__(self, cursor, model, columns):
        super(BaseModelCursorWrapper, self).__init__(cursor)
        self.model = model
        self.select = columns or []

    def _initialize_columns(self):
        # Inspect cursor.description and the SELECT list to determine, for
        # each result column: its name, its value converter, and its field.
        combined = self.model._meta.combined
        table = self.model._meta.table
        description = self.cursor.description

        self.ncols = len(self.cursor.description)
        self.columns = []
        self.converters = converters = [None] * self.ncols
        self.fields = fields = [None] * self.ncols

        for idx, description_item in enumerate(description):
            column = orig_column = description_item[0]

            # Try to clean-up messy column descriptions when people do not
            # provide an alias. The idea is that we take something like:
            # SUM("t1"."price") -> "price") -> price
            dot_index = column.rfind('.')
            if dot_index != -1:
                column = column[dot_index + 1:]
            column = column.strip('()"`')
            self.columns.append(column)

            # Now we'll see what they selected and see if we can improve the
            # column-name being returned - e.g. by mapping it to the selected
            # field's name.
            try:
                raw_node = self.select[idx]
            except IndexError:
                # More cursor columns than selected nodes; try the combined
                # field mapping, otherwise leave the column as-is.
                if column in combined:
                    raw_node = node = combined[column]
                else:
                    continue
            else:
                node = raw_node.unwrap()

            # If this column was given an alias, then we will use whatever
            # alias was returned by the cursor.
            is_alias = raw_node.is_alias()
            if is_alias:
                self.columns[idx] = orig_column

            # Heuristics used to attempt to get the field associated with a
            # given SELECT column, so that we can accurately convert the value
            # returned by the database-cursor into a Python object.
            if isinstance(node, Field):
                if raw_node._coerce:
                    converters[idx] = node.python_value
                fields[idx] = node
                if not is_alias:
                    self.columns[idx] = node.name
            elif isinstance(node, ColumnBase) and raw_node._converter:
                converters[idx] = raw_node._converter
            elif isinstance(node, Function) and node._coerce:
                if node._python_value is not None:
                    converters[idx] = node._python_value
                elif node.arguments and isinstance(node.arguments[0], Node):
                    # If the first argument is a field or references a column
                    # on a Model, try using that field's conversion function.
                    # This usually works, but we use "safe_python_value()" so
                    # that if a TypeError or ValueError occurs during
                    # conversion we can just fall-back to the raw cursor value.
                    first = node.arguments[0].unwrap()
                    if isinstance(first, Entity):
                        path = first._path[-1]  # Try to look-up by name.
                        first = combined.get(path)
                    if isinstance(first, Field):
                        converters[idx] = safe_python_value(first.python_value)
            elif column in combined:
                if node._coerce:
                    converters[idx] = combined[column].python_value
                if isinstance(node, Column) and node.source == table:
                    fields[idx] = combined[column]

    initialize = _initialize_columns

    def process_row(self, row):
        raise NotImplementedError


class ModelDictCursorWrapper(BaseModelCursorWrapper):
    # Yields each row as a plain dict of column-name -> converted value.
    def process_row(self, row):
        result = {}
        columns, converters = self.columns, self.converters
        fields = self.fields

        for i in range(self.ncols):
            attr = columns[i]
            if attr in result: continue  # Don't overwrite if we have dupes.
            if converters[i] is not None:
                result[attr] = converters[i](row[i])
            else:
                result[attr] = row[i]

        return result


class ModelTupleCursorWrapper(ModelDictCursorWrapper):
    # Yields each row as a tuple of converted values, in cursor order.
    constructor = tuple

    def process_row(self, row):
        columns, converters = self.columns, self.converters
        return self.constructor([
            (converters[i](row[i]) if converters[i] is not None else row[i])
            for i in range(self.ncols)])


class ModelNamedTupleCursorWrapper(ModelTupleCursorWrapper):
    # Yields each row as a namedtuple keyed by the resolved column names.
    def initialize(self):
        self._initialize_columns()
        attributes = []
        for i in range(self.ncols):
            attributes.append(self.columns[i])
        self.tuple_class = collections.namedtuple('Row', attributes)
        self.constructor = lambda row: self.tuple_class(*row)


class ModelObjectCursorWrapper(ModelDictCursorWrapper):
    # Yields each row as an instance built by "constructor" (a model class
    # or any callable accepting the row dict as keyword arguments).
    def __init__(self, cursor, model, select, constructor):
        self.constructor = constructor
        self.is_model = is_model(constructor)
        super(ModelObjectCursorWrapper, self).__init__(cursor, model, select)

    def process_row(self, row):
        data = super(ModelObjectCursorWrapper, self).process_row(row)
        if self.is_model:
            # Clear out any dirty fields before returning to the user.
+ obj = self.constructor(__no_default__=1, **data) + obj._dirty.clear() + return obj + else: + return self.constructor(**data) + + +class ModelCursorWrapper(BaseModelCursorWrapper): + def __init__(self, cursor, model, select, from_list, joins): + super(ModelCursorWrapper, self).__init__(cursor, model, select) + self.from_list = from_list + self.joins = joins + + def initialize(self): + self._initialize_columns() + selected_src = set([field.model for field in self.fields + if field is not None]) + select, columns = self.select, self.columns + + self.key_to_constructor = {self.model: self.model} + self.src_is_dest = {} + self.src_to_dest = [] + accum = collections.deque(self.from_list) + dests = set() + + while accum: + curr = accum.popleft() + if isinstance(curr, Join): + accum.append(curr.lhs) + accum.append(curr.rhs) + continue + + if curr not in self.joins: + continue + + is_dict = isinstance(curr, dict) + for key, attr, constructor, join_type in self.joins[curr]: + if key not in self.key_to_constructor: + self.key_to_constructor[key] = constructor + + # (src, attr, dest, is_dict, join_type). + self.src_to_dest.append((curr, attr, key, is_dict, + join_type)) + dests.add(key) + accum.append(key) + + # Ensure that we accommodate everything selected. + for src in selected_src: + if src not in self.key_to_constructor: + if is_model(src): + self.key_to_constructor[src] = src + elif isinstance(src, ModelAlias): + self.key_to_constructor[src] = src.model + + # Indicate which sources are also dests. 
+ for src, _, dest, _, _ in self.src_to_dest: + self.src_is_dest[src] = src in dests and (dest in selected_src + or src in selected_src) + + self.column_keys = [] + for idx, node in enumerate(select): + key = self.model + field = self.fields[idx] + if field is not None: + if isinstance(field, FieldAlias): + key = field.source + else: + key = field.model + elif isinstance(node, BindTo): + if node.dest not in self.key_to_constructor: + raise ValueError('%s specifies bind-to %s, but %s is not ' + 'among the selected sources.' % + (node.unwrap(), node.dest, node.dest)) + key = node.dest + else: + if isinstance(node, Node): + node = node.unwrap() + if isinstance(node, Column): + key = node.source + + self.column_keys.append(key) + + def process_row(self, row): + objects = {} + object_list = [] + for key, constructor in self.key_to_constructor.items(): + objects[key] = constructor(__no_default__=True) + object_list.append(objects[key]) + + default_instance = objects[self.model] + + set_keys = set() + for idx, key in enumerate(self.column_keys): + # Get the instance corresponding to the selected column/value, + # falling back to the "root" model instance. + instance = objects.get(key, default_instance) + column = self.columns[idx] + value = row[idx] + if value is not None: + set_keys.add(key) + if self.converters[idx]: + value = self.converters[idx](value) + + if isinstance(instance, dict): + instance[column] = value + else: + setattr(instance, column, value) + + # Need to do some analysis on the joins before this. + for (src, attr, dest, is_dict, join_type) in self.src_to_dest: + instance = objects[src] + try: + joined_instance = objects[dest] + except KeyError: + continue + + # If no fields were set on the destination instance then do not + # assign an "empty" instance. 
+ if instance is None or dest is None or \ + (dest not in set_keys and not self.src_is_dest.get(dest)): + continue + + # If no fields were set on either the source or the destination, + # then we have nothing to do here. + if instance not in set_keys and dest not in set_keys \ + and join_type.endswith('OUTER JOIN'): + continue + + if is_dict: + instance[attr] = joined_instance + else: + setattr(instance, attr, joined_instance) + + # When instantiating models from a cursor, we clear the dirty fields. + for instance in object_list: + if isinstance(instance, Model): + instance._dirty.clear() + + return objects[self.model] + + +class PrefetchQuery(collections.namedtuple('_PrefetchQuery', ( + 'query', 'fields', 'is_backref', 'rel_models', 'field_to_name', 'model'))): + def __new__(cls, query, fields=None, is_backref=None, rel_models=None, + field_to_name=None, model=None): + if fields: + if is_backref: + if rel_models is None: + rel_models = [field.model for field in fields] + foreign_key_attrs = [field.rel_field.name for field in fields] + else: + if rel_models is None: + rel_models = [field.rel_model for field in fields] + foreign_key_attrs = [field.name for field in fields] + field_to_name = list(zip(fields, foreign_key_attrs)) + model = query.model + return super(PrefetchQuery, cls).__new__( + cls, query, fields, is_backref, rel_models, field_to_name, model) + + def populate_instance(self, instance, id_map): + if self.is_backref: + for field in self.fields: + identifier = instance.__data__[field.name] + key = (field, identifier) + if key in id_map: + setattr(instance, field.name, id_map[key]) + else: + for field, attname in self.field_to_name: + identifier = instance.__data__[field.rel_field.name] + key = (field, identifier) + rel_instances = id_map.get(key, []) + for inst in rel_instances: + setattr(inst, attname, instance) + inst._dirty.clear() + setattr(instance, field.backref, rel_instances) + + def store_instance(self, instance, id_map): + for field, attname in 
self.field_to_name: + identity = field.rel_field.python_value(instance.__data__[attname]) + key = (field, identity) + if self.is_backref: + id_map[key] = instance + else: + id_map.setdefault(key, []) + id_map[key].append(instance) + + +def prefetch_add_subquery(sq, subqueries, prefetch_type): + fixed_queries = [PrefetchQuery(sq)] + for i, subquery in enumerate(subqueries): + if isinstance(subquery, tuple): + subquery, target_model = subquery + else: + target_model = None + if not isinstance(subquery, Query) and is_model(subquery) or \ + isinstance(subquery, ModelAlias): + subquery = subquery.select() + subquery_model = subquery.model + for j in reversed(range(i + 1)): + fks = backrefs = None + fixed = fixed_queries[j] + last_query = fixed.query + last_model = last_obj = fixed.model + if isinstance(last_model, ModelAlias): + last_model = last_model.model + rels = subquery_model._meta.model_refs.get(last_model, []) + if rels: + fks = [getattr(subquery_model, fk.name) for fk in rels] + pks = [getattr(last_obj, fk.rel_field.name) for fk in rels] + else: + backrefs = subquery_model._meta.model_backrefs.get(last_model) + if (fks or backrefs) and ((target_model is last_obj) or + (target_model is None)): + break + + else: + tgt_err = ' using %s' % target_model if target_model else '' + raise AttributeError('Error: unable to find foreign key for ' + 'query: %s%s' % (subquery, tgt_err)) + + dest = (target_model,) if target_model else None + + if fks: + if prefetch_type == PREFETCH_TYPE.WHERE: + expr = reduce(operator.or_, [ + (fk << last_query.select(pk)) + for (fk, pk) in zip(fks, pks)]) + subquery = subquery.where(expr) + elif prefetch_type == PREFETCH_TYPE.JOIN: + expr = [] + select_pks = set() + for fk, pk in zip(fks, pks): + expr.append(getattr(last_query.c, pk.column_name) == fk) + select_pks.add(pk) + subquery = subquery.distinct().join( + last_query.select(*select_pks), + on=reduce(operator.or_, expr)) + fixed_queries.append(PrefetchQuery(subquery, fks, False, dest)) 
+ elif backrefs: + expr = [] + fields = [] + for backref in backrefs: + rel_field = getattr(subquery_model, backref.rel_field.name) + fk_field = getattr(last_obj, backref.name) + fields.append((rel_field, fk_field)) + + if prefetch_type == PREFETCH_TYPE.WHERE: + for rel_field, fk_field in fields: + expr.append(rel_field << last_query.select(fk_field)) + subquery = subquery.where(reduce(operator.or_, expr)) + elif prefetch_type == PREFETCH_TYPE.JOIN: + select_fks = [] + for rel_field, fk_field in fields: + select_fks.append(fk_field) + target = getattr(last_query.c, fk_field.column_name) + expr.append(rel_field == target) + subquery = subquery.distinct().join( + last_query.select(*select_fks), + on=reduce(operator.or_, expr)) + fixed_queries.append(PrefetchQuery(subquery, backrefs, True, dest)) + + return fixed_queries + + +def prefetch(sq, *subqueries, **kwargs): + if not subqueries: + return sq + prefetch_type = kwargs.pop('prefetch_type', PREFETCH_TYPE.WHERE) + if kwargs: + raise ValueError('Unrecognized arguments: %s' % kwargs) + + fixed_queries = prefetch_add_subquery(sq, subqueries, prefetch_type) + deps = {} + rel_map = {} + for pq in reversed(fixed_queries): + query_model = pq.model + if pq.fields: + for rel_model in pq.rel_models: + rel_map.setdefault(rel_model, []) + rel_map[rel_model].append(pq) + + deps.setdefault(query_model, {}) + id_map = deps[query_model] + has_relations = bool(rel_map.get(query_model)) + + for instance in pq.query: + if pq.fields: + pq.store_instance(instance, id_map) + if has_relations: + for rel in rel_map[query_model]: + rel.populate_instance(instance, deps[rel.model]) + + return list(pq.query) diff --git a/python3.10libs/playhouse/__init__.py b/python3.10libs/playhouse/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python3.10libs/playhouse/apsw_ext.py b/python3.10libs/playhouse/apsw_ext.py new file mode 100644 index 0000000..ea32f98 --- /dev/null +++ b/python3.10libs/playhouse/apsw_ext.py @@ -0,0 +1,158 @@ 
+""" +Peewee integration with APSW, "another python sqlite wrapper". + +Project page: https://rogerbinns.github.io/apsw/ + +APSW is a really neat library that provides a thin wrapper on top of SQLite's +C interface. + +Here are just a few reasons to use APSW, taken from the documentation: + +* APSW gives all functionality of SQLite, including virtual tables, virtual + file system, blob i/o, backups and file control. +* Connections can be shared across threads without any additional locking. +* Transactions are managed explicitly by your code. +* APSW can handle nested transactions. +* Unicode is handled correctly. +* APSW is faster. +""" +import apsw +from peewee import * +from peewee import __exception_wrapper__ +from peewee import BooleanField as _BooleanField +from peewee import DateField as _DateField +from peewee import DateTimeField as _DateTimeField +from peewee import DecimalField as _DecimalField +from peewee import Insert +from peewee import TimeField as _TimeField +from peewee import logger + +from playhouse.sqlite_ext import SqliteExtDatabase + + +class APSWDatabase(SqliteExtDatabase): + server_version = tuple(int(i) for i in apsw.sqlitelibversion().split('.')) + + def __init__(self, database, **kwargs): + self._modules = {} + super(APSWDatabase, self).__init__(database, **kwargs) + + def register_module(self, mod_name, mod_inst): + self._modules[mod_name] = mod_inst + if not self.is_closed(): + self.connection().createmodule(mod_name, mod_inst) + + def unregister_module(self, mod_name): + del(self._modules[mod_name]) + + def _connect(self): + conn = apsw.Connection(self.database, **self.connect_params) + if self._timeout is not None: + conn.setbusytimeout(self._timeout * 1000) + try: + self._add_conn_hooks(conn) + except: + conn.close() + raise + return conn + + def _add_conn_hooks(self, conn): + super(APSWDatabase, self)._add_conn_hooks(conn) + self._load_modules(conn) # APSW-only. 
+ + def _load_modules(self, conn): + for mod_name, mod_inst in self._modules.items(): + conn.createmodule(mod_name, mod_inst) + return conn + + def _load_aggregates(self, conn): + for name, (klass, num_params) in self._aggregates.items(): + def make_aggregate(): + return (klass(), klass.step, klass.finalize) + conn.createaggregatefunction(name, make_aggregate) + + def _load_collations(self, conn): + for name, fn in self._collations.items(): + conn.createcollation(name, fn) + + def _load_functions(self, conn): + for name, (fn, num_params, deterministic) in self._functions.items(): + args = (deterministic,) if deterministic else () + conn.createscalarfunction(name, fn, num_params, *args) + + def _load_extensions(self, conn): + conn.enableloadextension(True) + for extension in self._extensions: + conn.loadextension(extension) + + def load_extension(self, extension): + self._extensions.add(extension) + if not self.is_closed(): + conn = self.connection() + conn.enableloadextension(True) + conn.loadextension(extension) + + def last_insert_id(self, cursor, query_type=None): + if not self.returning_clause: + return cursor.getconnection().last_insert_rowid() + elif query_type == Insert.SIMPLE: + try: + return cursor[0][0] + except (AttributeError, IndexError, TypeError): + pass + return cursor + + def rows_affected(self, cursor): + try: + return cursor.getconnection().changes() + except AttributeError: + return cursor.cursor.getconnection().changes() # RETURNING query. 
+ + def begin(self, lock_type='deferred'): + self.cursor().execute('begin %s;' % lock_type) + + def commit(self): + with __exception_wrapper__: + curs = self.cursor() + if curs.getconnection().getautocommit(): + return False + curs.execute('commit;') + return True + + def rollback(self): + with __exception_wrapper__: + curs = self.cursor() + if curs.getconnection().getautocommit(): + return False + curs.execute('rollback;') + return True + + def execute_sql(self, sql, params=None): + logger.debug((sql, params)) + with __exception_wrapper__: + cursor = self.cursor() + cursor.execute(sql, params or ()) + return cursor + + +def nh(s, v): + if v is not None: + return str(v) + +class BooleanField(_BooleanField): + def db_value(self, v): + v = super(BooleanField, self).db_value(v) + if v is not None: + return v and 1 or 0 + +class DateField(_DateField): + db_value = nh + +class TimeField(_TimeField): + db_value = nh + +class DateTimeField(_DateTimeField): + db_value = nh + +class DecimalField(_DecimalField): + db_value = nh diff --git a/python3.10libs/playhouse/cockroachdb.py b/python3.10libs/playhouse/cockroachdb.py new file mode 100644 index 0000000..a07362e --- /dev/null +++ b/python3.10libs/playhouse/cockroachdb.py @@ -0,0 +1,223 @@ +import functools +import re +import sys + +from peewee import * +from peewee import _atomic +from peewee import _manual +from peewee import ColumnMetadata # (name, data_type, null, primary_key, table, default) +from peewee import EnclosedNodeList +from peewee import Entity +from peewee import ForeignKeyMetadata # (column, dest_table, dest_column, table). +from peewee import IndexMetadata +from peewee import NodeList +from playhouse.pool import _PooledPostgresqlDatabase +try: + from playhouse.postgres_ext import ArrayField + from playhouse.postgres_ext import BinaryJSONField + from playhouse.postgres_ext import IntervalField + JSONField = BinaryJSONField +except ImportError: # psycopg2 not installed, ignore. 
+ ArrayField = BinaryJSONField = IntervalField = JSONField = None + +if sys.version_info[0] > 2: + basestring = str + + +NESTED_TX_MIN_VERSION = 200100 + +TXN_ERR_MSG = ('CockroachDB does not support nested transactions. You may ' + 'alternatively use the @transaction context-manager/decorator, ' + 'which only wraps the outer-most block in transactional logic. ' + 'To run a transaction with automatic retries, use the ' + 'run_transaction() helper.') + +class ExceededMaxAttempts(OperationalError): pass + + +class UUIDKeyField(UUIDField): + auto_increment = True + + def __init__(self, *args, **kwargs): + if kwargs.get('constraints'): + raise ValueError('%s cannot specify constraints.' % type(self)) + kwargs['constraints'] = [SQL('DEFAULT gen_random_uuid()')] + kwargs.setdefault('primary_key', True) + super(UUIDKeyField, self).__init__(*args, **kwargs) + + +class RowIDField(AutoField): + field_type = 'INT' + + def __init__(self, *args, **kwargs): + if kwargs.get('constraints'): + raise ValueError('%s cannot specify constraints.' % type(self)) + kwargs['constraints'] = [SQL('DEFAULT unique_rowid()')] + super(RowIDField, self).__init__(*args, **kwargs) + + +class CockroachDatabase(PostgresqlDatabase): + field_types = PostgresqlDatabase.field_types.copy() + field_types.update({ + 'BLOB': 'BYTES', + }) + + release_after_rollback = True + + def __init__(self, database, *args, **kwargs): + # Unless a DSN or database connection-url were specified, provide + # convenient defaults for the user and port. 
+ if 'dsn' not in kwargs and (database and + not database.startswith('postgresql://')): + kwargs.setdefault('user', 'root') + kwargs.setdefault('port', 26257) + super(CockroachDatabase, self).__init__(database, *args, **kwargs) + + def _set_server_version(self, conn): + curs = conn.cursor() + curs.execute('select version()') + raw, = curs.fetchone() + match_obj = re.match(r'^CockroachDB.+?v(\d+)\.(\d+)\.(\d+)', raw) + if match_obj is not None: + clean = '%d%02d%02d' % tuple(int(i) for i in match_obj.groups()) + self.server_version = int(clean) # 19.1.5 -> 190105. + else: + # Fallback to use whatever cockroachdb tells us via protocol. + super(CockroachDatabase, self)._set_server_version(conn) + + def _get_pk_constraint(self, table, schema=None): + query = ('SELECT constraint_name ' + 'FROM information_schema.table_constraints ' + 'WHERE table_name = %s AND table_schema = %s ' + 'AND constraint_type = %s') + cursor = self.execute_sql(query, (table, schema or 'public', + 'PRIMARY KEY')) + row = cursor.fetchone() + return row and row[0] or None + + def get_indexes(self, table, schema=None): + # The primary-key index is returned by default, so we will just strip + # it out here. + indexes = super(CockroachDatabase, self).get_indexes(table, schema) + pkc = self._get_pk_constraint(table, schema) + return [idx for idx in indexes if (not pkc) or (idx.name != pkc)] + + def conflict_statement(self, on_conflict, query): + if not on_conflict._action: return + + action = on_conflict._action.lower() + if action in ('replace', 'upsert'): + return SQL('UPSERT') + elif action not in ('ignore', 'nothing', 'update'): + raise ValueError('Un-supported action for conflict resolution. 
' + 'CockroachDB supports REPLACE (UPSERT), IGNORE ' + 'and UPDATE.') + + def conflict_update(self, oc, query): + action = oc._action.lower() if oc._action else '' + if action in ('ignore', 'nothing'): + parts = [SQL('ON CONFLICT')] + if oc._conflict_target: + parts.append(EnclosedNodeList([ + Entity(col) if isinstance(col, basestring) else col + for col in oc._conflict_target])) + parts.append(SQL('DO NOTHING')) + return NodeList(parts) + elif action in ('replace', 'upsert'): + # No special stuff is necessary, this is just indicated by starting + # the statement with UPSERT instead of INSERT. + return + elif oc._conflict_constraint: + raise ValueError('CockroachDB does not support the usage of a ' + 'constraint name. Use the column(s) instead.') + + return super(CockroachDatabase, self).conflict_update(oc, query) + + def extract_date(self, date_part, date_field): + return fn.extract(date_part, date_field) + + def from_timestamp(self, date_field): + # CRDB does not allow casting a decimal/float to timestamp, so we first + # cast to int, then to timestamptz. + return date_field.cast('int').cast('timestamptz') + + def begin(self, system_time=None, priority=None): + super(CockroachDatabase, self).begin() + if system_time is not None: + self.cursor().execute('SET TRANSACTION AS OF SYSTEM TIME %s', + (system_time,)) + if priority is not None: + priority = priority.lower() + if priority not in ('low', 'normal', 'high'): + raise ValueError('priority must be low, normal or high') + self.cursor().execute('SET TRANSACTION PRIORITY %s' % priority) + + def atomic(self, system_time=None, priority=None): + if self.is_closed(): self.connect() # Side-effect, set server version. + if self.server_version < NESTED_TX_MIN_VERSION: + return _crdb_atomic(self, system_time, priority) + return super(CockroachDatabase, self).atomic(system_time, priority) + + def savepoint(self): + if self.is_closed(): self.connect() # Side-effect, set server version. 
+ if self.server_version < NESTED_TX_MIN_VERSION: + raise NotImplementedError(TXN_ERR_MSG) + return super(CockroachDatabase, self).savepoint() + + def retry_transaction(self, max_attempts=None, system_time=None, + priority=None): + def deco(cb): + @functools.wraps(cb) + def new_fn(): + return run_transaction(self, cb, max_attempts, system_time, + priority) + return new_fn + return deco + + def run_transaction(self, cb, max_attempts=None, system_time=None, + priority=None): + return run_transaction(self, cb, max_attempts, system_time, priority) + + +class _crdb_atomic(_atomic): + def __enter__(self): + if self.db.transaction_depth() > 0: + if not isinstance(self.db.top_transaction(), _manual): + raise NotImplementedError(TXN_ERR_MSG) + return super(_crdb_atomic, self).__enter__() + + +def run_transaction(db, callback, max_attempts=None, system_time=None, + priority=None): + """ + Run transactional SQL in a transaction with automatic retries. + + User-provided `callback`: + * Must accept one parameter, the `db` instance representing the connection + the transaction is running under. + * Must not attempt to commit, rollback or otherwise manage transactions. + * May be called more than once. + * Should ideally only contain SQL operations. + + Additionally, the database must not have any open transaction at the time + this function is called, as CRDB does not support nested transactions. 
+ """ + max_attempts = max_attempts or -1 + with db.atomic(system_time=system_time, priority=priority) as txn: + db.execute_sql('SAVEPOINT cockroach_restart') + while max_attempts != 0: + try: + result = callback(db) + db.execute_sql('RELEASE SAVEPOINT cockroach_restart') + return result + except OperationalError as exc: + if exc.orig.pgcode == '40001': + max_attempts -= 1 + db.execute_sql('ROLLBACK TO SAVEPOINT cockroach_restart') + continue + raise + raise ExceededMaxAttempts(None, 'unable to commit transaction') + + +class PooledCockroachDatabase(_PooledPostgresqlDatabase, CockroachDatabase): + pass diff --git a/python3.10libs/playhouse/dataset.py b/python3.10libs/playhouse/dataset.py new file mode 100644 index 0000000..9ccf662 --- /dev/null +++ b/python3.10libs/playhouse/dataset.py @@ -0,0 +1,462 @@ +import csv +import datetime +from decimal import Decimal +import json +import operator +try: + from urlparse import urlparse +except ImportError: + from urllib.parse import urlparse +import sys +import uuid + +from peewee import * +from playhouse.db_url import connect +from playhouse.migrate import migrate +from playhouse.migrate import SchemaMigrator +from playhouse.reflection import Introspector + +if sys.version_info[0] == 3: + basestring = str + from functools import reduce + def open_file(f, mode, encoding='utf8'): + return open(f, mode, encoding=encoding) +else: + def open_file(f, mode, encoding='utf8'): + return open(f, mode) + + +class DataSet(object): + def __init__(self, url, include_views=False, **kwargs): + if isinstance(url, Database): + self._url = None + self._database = url + self._database_path = self._database.database + else: + self._url = url + parse_result = urlparse(url) + self._database_path = parse_result.path[1:] + + # Connect to the database. + self._database = connect(url) + + # Open a connection if one does not already exist. + self._database.connect(reuse_if_open=True) + + # Introspect the database and generate models. 
+ self._introspector = Introspector.from_database(self._database) + self._include_views = include_views + self._models = self._introspector.generate_models( + skip_invalid=True, + literal_column_names=True, + include_views=self._include_views, + **kwargs) + self._migrator = SchemaMigrator.from_database(self._database) + + class BaseModel(Model): + class Meta: + database = self._database + self._base_model = BaseModel + self._export_formats = self.get_export_formats() + self._import_formats = self.get_import_formats() + + def __repr__(self): + return '' % self._database_path + + def get_export_formats(self): + return { + 'csv': CSVExporter, + 'json': JSONExporter, + 'tsv': TSVExporter} + + def get_import_formats(self): + return { + 'csv': CSVImporter, + 'json': JSONImporter, + 'tsv': TSVImporter} + + def __getitem__(self, table): + if table not in self._models and table in self.tables: + self.update_cache(table) + return Table(self, table, self._models.get(table)) + + @property + def tables(self): + tables = self._database.get_tables() + if self._include_views: + tables += self.views + return tables + + @property + def views(self): + return [v.name for v in self._database.get_views()] + + def __contains__(self, table): + return table in self.tables + + def connect(self, reuse_if_open=False): + self._database.connect(reuse_if_open=reuse_if_open) + + def close(self): + self._database.close() + + def update_cache(self, table=None): + if table: + dependencies = [table] + if table in self._models: + model_class = self._models[table] + dependencies.extend([ + related._meta.table_name for _, related, _ in + model_class._meta.model_graph()]) + else: + dependencies.extend(self.get_table_dependencies(table)) + else: + dependencies = None # Update all tables. 
+ self._models = {} + updated = self._introspector.generate_models( + skip_invalid=True, + table_names=dependencies, + literal_column_names=True, + include_views=self._include_views) + self._models.update(updated) + + def get_table_dependencies(self, table): + stack = [table] + accum = [] + seen = set() + while stack: + table = stack.pop() + for fk_meta in self._database.get_foreign_keys(table): + dest = fk_meta.dest_table + if dest not in seen: + stack.append(dest) + accum.append(dest) + return accum + + def __enter__(self): + self.connect() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if not self._database.is_closed(): + self.close() + + def query(self, sql, params=None): + return self._database.execute_sql(sql, params) + + def transaction(self): + return self._database.atomic() + + def _check_arguments(self, filename, file_obj, format, format_dict): + if filename and file_obj: + raise ValueError('file is over-specified. Please use either ' + 'filename or file_obj, but not both.') + if not filename and not file_obj: + raise ValueError('A filename or file-like object must be ' + 'specified.') + if format not in format_dict: + valid_formats = ', '.join(sorted(format_dict.keys())) + raise ValueError('Unsupported format "%s". Use one of %s.' 
% ( + format, valid_formats)) + + def freeze(self, query, format='csv', filename=None, file_obj=None, + encoding='utf8', **kwargs): + self._check_arguments(filename, file_obj, format, self._export_formats) + if filename: + file_obj = open_file(filename, 'w', encoding) + + exporter = self._export_formats[format](query) + exporter.export(file_obj, **kwargs) + + if filename: + file_obj.close() + + def thaw(self, table, format='csv', filename=None, file_obj=None, + strict=False, encoding='utf8', **kwargs): + self._check_arguments(filename, file_obj, format, self._export_formats) + if filename: + file_obj = open_file(filename, 'r', encoding) + + importer = self._import_formats[format](self[table], strict) + count = importer.load(file_obj, **kwargs) + + if filename: + file_obj.close() + + return count + + +class Table(object): + def __init__(self, dataset, name, model_class): + self.dataset = dataset + self.name = name + if model_class is None: + model_class = self._create_model() + model_class.create_table() + self.dataset._models[name] = model_class + + @property + def model_class(self): + return self.dataset._models[self.name] + + def __repr__(self): + return '' % self.name + + def __len__(self): + return self.find().count() + + def __iter__(self): + return iter(self.find().iterator()) + + def _create_model(self): + class Meta: + table_name = self.name + return type( + str(self.name), + (self.dataset._base_model,), + {'Meta': Meta}) + + def create_index(self, columns, unique=False): + index = ModelIndex(self.model_class, columns, unique=unique) + self.model_class.add_index(index) + self.dataset._database.execute(index) + + def _guess_field_type(self, value): + if isinstance(value, basestring): + return TextField + if isinstance(value, (datetime.date, datetime.datetime)): + return DateTimeField + elif value is True or value is False: + return BooleanField + elif isinstance(value, int): + return IntegerField + elif isinstance(value, float): + return FloatField + elif 
isinstance(value, Decimal): + return DecimalField + return TextField + + @property + def columns(self): + return [f.name for f in self.model_class._meta.sorted_fields] + + def _migrate_new_columns(self, data): + new_keys = set(data) - set(self.model_class._meta.fields) + new_keys -= set(self.model_class._meta.columns) + if new_keys: + operations = [] + for key in new_keys: + field_class = self._guess_field_type(data[key]) + field = field_class(null=True) + operations.append( + self.dataset._migrator.add_column(self.name, key, field)) + field.bind(self.model_class, key) + + migrate(*operations) + + self.dataset.update_cache(self.name) + + def __getitem__(self, item): + try: + return self.model_class[item] + except self.model_class.DoesNotExist: + pass + + def __setitem__(self, item, value): + if not isinstance(value, dict): + raise ValueError('Table.__setitem__() value must be a dict') + + pk = self.model_class._meta.primary_key + value[pk.name] = item + + try: + with self.dataset.transaction() as txn: + self.insert(**value) + except IntegrityError: + self.dataset.update_cache(self.name) + self.update(columns=[pk.name], **value) + + def __delitem__(self, item): + del self.model_class[item] + + def insert(self, **data): + self._migrate_new_columns(data) + return self.model_class.insert(**data).execute() + + def _apply_where(self, query, filters, conjunction=None): + conjunction = conjunction or operator.and_ + if filters: + expressions = [ + (self.model_class._meta.fields[column] == value) + for column, value in filters.items()] + query = query.where(reduce(conjunction, expressions)) + return query + + def update(self, columns=None, conjunction=None, **data): + self._migrate_new_columns(data) + filters = {} + if columns: + for column in columns: + filters[column] = data.pop(column) + + return self._apply_where( + self.model_class.update(**data), + filters, + conjunction).execute() + + def _query(self, **query): + return self._apply_where(self.model_class.select(), 
query) + + def find(self, **query): + return self._query(**query).dicts() + + def find_one(self, **query): + try: + return self.find(**query).get() + except self.model_class.DoesNotExist: + return None + + def all(self): + return self.find() + + def delete(self, **query): + return self._apply_where(self.model_class.delete(), query).execute() + + def freeze(self, *args, **kwargs): + return self.dataset.freeze(self.all(), *args, **kwargs) + + def thaw(self, *args, **kwargs): + return self.dataset.thaw(self.name, *args, **kwargs) + + +class Exporter(object): + def __init__(self, query): + self.query = query + + def export(self, file_obj): + raise NotImplementedError + + +class JSONExporter(Exporter): + def __init__(self, query, iso8601_datetimes=False): + super(JSONExporter, self).__init__(query) + self.iso8601_datetimes = iso8601_datetimes + + def _make_default(self): + datetime_types = (datetime.datetime, datetime.date, datetime.time) + + if self.iso8601_datetimes: + def default(o): + if isinstance(o, datetime_types): + return o.isoformat() + elif isinstance(o, (Decimal, uuid.UUID)): + return str(o) + raise TypeError('Unable to serialize %r as JSON' % o) + else: + def default(o): + if isinstance(o, datetime_types + (Decimal, uuid.UUID)): + return str(o) + raise TypeError('Unable to serialize %r as JSON' % o) + return default + + def export(self, file_obj, **kwargs): + json.dump( + list(self.query), + file_obj, + default=self._make_default(), + **kwargs) + + +class CSVExporter(Exporter): + def export(self, file_obj, header=True, **kwargs): + writer = csv.writer(file_obj, **kwargs) + tuples = self.query.tuples().execute() + tuples.initialize() + if header and getattr(tuples, 'columns', None): + writer.writerow([column for column in tuples.columns]) + for row in tuples: + writer.writerow(row) + + +class TSVExporter(CSVExporter): + def export(self, file_obj, header=True, **kwargs): + kwargs.setdefault('delimiter', '\t') + return super(TSVExporter, self).export(file_obj, 
header, **kwargs) + + +class Importer(object): + def __init__(self, table, strict=False): + self.table = table + self.strict = strict + + model = self.table.model_class + self.columns = model._meta.columns + self.columns.update(model._meta.fields) + + def load(self, file_obj): + raise NotImplementedError + + +class JSONImporter(Importer): + def load(self, file_obj, **kwargs): + data = json.load(file_obj, **kwargs) + count = 0 + + for row in data: + if self.strict: + obj = {} + for key in row: + field = self.columns.get(key) + if field is not None: + obj[field.name] = field.python_value(row[key]) + else: + obj = row + + if obj: + self.table.insert(**obj) + count += 1 + + return count + + +class CSVImporter(Importer): + def load(self, file_obj, header=True, **kwargs): + count = 0 + reader = csv.reader(file_obj, **kwargs) + if header: + try: + header_keys = next(reader) + except StopIteration: + return count + + if self.strict: + header_fields = [] + for idx, key in enumerate(header_keys): + if key in self.columns: + header_fields.append((idx, self.columns[key])) + else: + header_fields = list(enumerate(header_keys)) + else: + header_fields = list(enumerate(self.model._meta.sorted_fields)) + + if not header_fields: + return count + + for row in reader: + obj = {} + for idx, field in header_fields: + if self.strict: + obj[field.name] = field.python_value(row[idx]) + else: + obj[field] = row[idx] + + self.table.insert(**obj) + count += 1 + + return count + + +class TSVImporter(CSVImporter): + def load(self, file_obj, header=True, **kwargs): + kwargs.setdefault('delimiter', '\t') + return super(TSVImporter, self).load(file_obj, header, **kwargs) diff --git a/python3.10libs/playhouse/db_url.py b/python3.10libs/playhouse/db_url.py new file mode 100644 index 0000000..7176c80 --- /dev/null +++ b/python3.10libs/playhouse/db_url.py @@ -0,0 +1,130 @@ +try: + from urlparse import parse_qsl, unquote, urlparse +except ImportError: + from urllib.parse import parse_qsl, unquote, 
urlparse

from peewee import *
from playhouse.cockroachdb import CockroachDatabase
from playhouse.cockroachdb import PooledCockroachDatabase
from playhouse.pool import PooledMySQLDatabase
from playhouse.pool import PooledPostgresqlDatabase
from playhouse.pool import PooledSqliteDatabase
from playhouse.pool import PooledSqliteExtDatabase
from playhouse.sqlite_ext import SqliteExtDatabase


# Map URL scheme -> Database class.  Additional schemes are registered
# conditionally at the bottom of this module.
schemes = {
    'cockroachdb': CockroachDatabase,
    'cockroachdb+pool': PooledCockroachDatabase,
    'crdb': CockroachDatabase,
    'crdb+pool': PooledCockroachDatabase,
    'mysql': MySQLDatabase,
    'mysql+pool': PooledMySQLDatabase,
    'postgres': PostgresqlDatabase,
    'postgresql': PostgresqlDatabase,
    'postgres+pool': PooledPostgresqlDatabase,
    'postgresql+pool': PooledPostgresqlDatabase,
    'sqlite': SqliteDatabase,
    'sqliteext': SqliteExtDatabase,
    'sqlite+pool': PooledSqliteDatabase,
    'sqliteext+pool': PooledSqliteExtDatabase,
}


def register_database(db_class, *names):
    """Register `db_class` under one or more URL scheme names."""
    global schemes
    for name in names:
        schemes[name] = db_class


def parseresult_to_dict(parsed, unquote_password=False):
    """Convert a urlparse() result into peewee connection kwargs."""
    # urlparse in python 2.6 is broken so query will be empty and instead
    # appended to path complete with '?'
    path_parts = parsed.path[1:].split('?')
    try:
        query = path_parts[1]
    except IndexError:
        query = parsed.query

    connect_kwargs = {'database': path_parts[0]}
    if parsed.username:
        connect_kwargs['user'] = parsed.username
    if parsed.password:
        connect_kwargs['password'] = parsed.password
        # Only unquote when a password is present, to avoid a KeyError.
        if unquote_password:
            connect_kwargs['password'] = unquote(connect_kwargs['password'])
    if parsed.hostname:
        connect_kwargs['host'] = parsed.hostname
    if parsed.port:
        connect_kwargs['port'] = parsed.port

    # Adjust parameters for MySQL.
+ if parsed.scheme == 'mysql' and 'password' in connect_kwargs: + connect_kwargs['passwd'] = connect_kwargs.pop('password') + elif 'sqlite' in parsed.scheme and not connect_kwargs['database']: + connect_kwargs['database'] = ':memory:' + + # Get additional connection args from the query string + qs_args = parse_qsl(query, keep_blank_values=True) + for key, value in qs_args: + if value.lower() == 'false': + value = False + elif value.lower() == 'true': + value = True + elif value.isdigit(): + value = int(value) + elif '.' in value and all(p.isdigit() for p in value.split('.', 1)): + try: + value = float(value) + except ValueError: + pass + elif value.lower() in ('null', 'none'): + value = None + + connect_kwargs[key] = value + + return connect_kwargs + +def parse(url, unquote_password=False): + parsed = urlparse(url) + return parseresult_to_dict(parsed, unquote_password) + +def connect(url, unquote_password=False, **connect_params): + parsed = urlparse(url) + connect_kwargs = parseresult_to_dict(parsed, unquote_password) + connect_kwargs.update(connect_params) + database_class = schemes.get(parsed.scheme) + + if database_class is None: + if database_class in schemes: + raise RuntimeError('Attempted to use "%s" but a required library ' + 'could not be imported.' % parsed.scheme) + else: + raise RuntimeError('Unrecognized or unsupported scheme: "%s".' % + parsed.scheme) + + return database_class(**connect_kwargs) + +# Conditionally register additional databases. 
+try: + from playhouse.pool import PooledPostgresqlExtDatabase +except ImportError: + pass +else: + register_database( + PooledPostgresqlExtDatabase, + 'postgresext+pool', + 'postgresqlext+pool') + +try: + from playhouse.apsw_ext import APSWDatabase +except ImportError: + pass +else: + register_database(APSWDatabase, 'apsw') + +try: + from playhouse.postgres_ext import PostgresqlExtDatabase +except ImportError: + pass +else: + register_database(PostgresqlExtDatabase, 'postgresext', 'postgresqlext') diff --git a/python3.10libs/playhouse/fields.py b/python3.10libs/playhouse/fields.py new file mode 100644 index 0000000..d024149 --- /dev/null +++ b/python3.10libs/playhouse/fields.py @@ -0,0 +1,60 @@ +try: + import bz2 +except ImportError: + bz2 = None +try: + import zlib +except ImportError: + zlib = None +try: + import cPickle as pickle +except ImportError: + import pickle + +from peewee import BlobField +from peewee import buffer_type + + +class CompressedField(BlobField): + ZLIB = 'zlib' + BZ2 = 'bz2' + algorithm_to_import = { + ZLIB: zlib, + BZ2: bz2, + } + + def __init__(self, compression_level=6, algorithm=ZLIB, *args, + **kwargs): + self.compression_level = compression_level + if algorithm not in self.algorithm_to_import: + raise ValueError('Unrecognized algorithm %s' % algorithm) + compress_module = self.algorithm_to_import[algorithm] + if compress_module is None: + raise ValueError('Missing library required for %s.' 
% algorithm) + + self.algorithm = algorithm + self.compress = compress_module.compress + self.decompress = compress_module.decompress + super(CompressedField, self).__init__(*args, **kwargs) + + def python_value(self, value): + if value is not None: + return self.decompress(value) + + def db_value(self, value): + if value is not None: + return self._constructor( + self.compress(value, self.compression_level)) + + +class PickleField(BlobField): + def python_value(self, value): + if value is not None: + if isinstance(value, buffer_type): + value = bytes(value) + return pickle.loads(value) + + def db_value(self, value): + if value is not None: + pickled = pickle.dumps(value, pickle.HIGHEST_PROTOCOL) + return self._constructor(pickled) diff --git a/python3.10libs/playhouse/flask_utils.py b/python3.10libs/playhouse/flask_utils.py new file mode 100644 index 0000000..c1023fe --- /dev/null +++ b/python3.10libs/playhouse/flask_utils.py @@ -0,0 +1,241 @@ +import math +import sys + +from flask import abort +from flask import render_template +from flask import request +from peewee import Database +from peewee import DoesNotExist +from peewee import Model +from peewee import Proxy +from peewee import SelectQuery +from playhouse.db_url import connect as db_url_connect + + +class PaginatedQuery(object): + def __init__(self, query_or_model, paginate_by, page_var='page', page=None, + check_bounds=False): + self.paginate_by = paginate_by + self.page_var = page_var + self.page = page or None + self.check_bounds = check_bounds + + if isinstance(query_or_model, SelectQuery): + self.query = query_or_model + self.model = self.query.model + else: + self.model = query_or_model + self.query = self.model.select() + + def get_page(self): + if self.page is not None: + return self.page + + curr_page = request.args.get(self.page_var) + if curr_page and curr_page.isdigit(): + return max(1, int(curr_page)) + return 1 + + def get_page_count(self): + if not hasattr(self, '_page_count'): + 
self._page_count = int(math.ceil( + float(self.query.count()) / self.paginate_by)) + return self._page_count + + def get_object_list(self): + if self.check_bounds and self.get_page() > self.get_page_count(): + abort(404) + return self.query.paginate(self.get_page(), self.paginate_by) + + def get_page_range(self, page, total, show=5): + # Generate page buttons for a subset of pages, e.g. if the current page + # is 4, we have 10 pages, and want to show 5 buttons, this function + # returns us: [2, 3, 4, 5, 6] + start = max((page - (show // 2)), 1) + stop = min(start + show, total) + 1 + start = max(min(start, stop - show), 1) + return list(range(start, stop)[:show]) + + +def get_object_or_404(query_or_model, *query): + if not isinstance(query_or_model, SelectQuery): + query_or_model = query_or_model.select() + try: + return query_or_model.where(*query).get() + except DoesNotExist: + abort(404) + +def object_list(template_name, query, context_variable='object_list', + paginate_by=20, page_var='page', page=None, check_bounds=True, + **kwargs): + paginated_query = PaginatedQuery( + query, + paginate_by=paginate_by, + page_var=page_var, + page=page, + check_bounds=check_bounds) + kwargs[context_variable] = paginated_query.get_object_list() + return render_template( + template_name, + pagination=paginated_query, + page=paginated_query.get_page(), + **kwargs) + +def get_current_url(): + if not request.query_string: + return request.path + return '%s?%s' % (request.path, request.query_string) + +def get_next_url(default='/'): + if request.args.get('next'): + return request.args['next'] + elif request.form.get('next'): + return request.form['next'] + return default + +class FlaskDB(object): + """ + Convenience wrapper for configuring a Peewee database for use with a Flask + application. Provides a base `Model` class and registers handlers to manage + the database connection during the request/response cycle. 
+ + Usage:: + + from flask import Flask + from peewee import * + from playhouse.flask_utils import FlaskDB + + + # The database can be specified using a database URL, or you can pass a + # Peewee database instance directly: + DATABASE = 'postgresql:///my_app' + DATABASE = PostgresqlDatabase('my_app') + + # If we do not want connection-management on any views, we can specify + # the view names using FLASKDB_EXCLUDED_ROUTES. The db connection will + # not be opened/closed automatically when these views are requested: + FLASKDB_EXCLUDED_ROUTES = ('logout',) + + app = Flask(__name__) + app.config.from_object(__name__) + + # Now we can configure our FlaskDB: + flask_db = FlaskDB(app) + + # Or use the "deferred initialization" pattern: + flask_db = FlaskDB() + flask_db.init_app(app) + + # The `flask_db` provides a base Model-class for easily binding models + # to the configured database: + class User(flask_db.Model): + email = CharField() + + """ + def __init__(self, app=None, database=None, model_class=Model, + excluded_routes=None): + self.database = None # Reference to actual Peewee database instance. + self.base_model_class = model_class + self._app = app + self._db = database # dict, url, Database, or None (default). 
+ self._excluded_routes = excluded_routes or () + if app is not None: + self.init_app(app) + + def init_app(self, app): + self._app = app + + if self._db is None: + if 'DATABASE' in app.config: + initial_db = app.config['DATABASE'] + elif 'DATABASE_URL' in app.config: + initial_db = app.config['DATABASE_URL'] + else: + raise ValueError('Missing required configuration data for ' + 'database: DATABASE or DATABASE_URL.') + else: + initial_db = self._db + + if 'FLASKDB_EXCLUDED_ROUTES' in app.config: + self._excluded_routes = app.config['FLASKDB_EXCLUDED_ROUTES'] + + self._load_database(app, initial_db) + self._register_handlers(app) + + def _load_database(self, app, config_value): + if isinstance(config_value, Database): + database = config_value + elif isinstance(config_value, dict): + database = self._load_from_config_dict(dict(config_value)) + else: + # Assume a database connection URL. + database = db_url_connect(config_value) + + if isinstance(self.database, Proxy): + self.database.initialize(database) + else: + self.database = database + + def _load_from_config_dict(self, config_dict): + try: + name = config_dict.pop('name') + engine = config_dict.pop('engine') + except KeyError: + raise RuntimeError('DATABASE configuration must specify a ' + '`name` and `engine`.') + + if '.' 
in engine: + path, class_name = engine.rsplit('.', 1) + else: + path, class_name = 'peewee', engine + + try: + __import__(path) + module = sys.modules[path] + database_class = getattr(module, class_name) + assert issubclass(database_class, Database) + except ImportError: + raise RuntimeError('Unable to import %s' % engine) + except AttributeError: + raise RuntimeError('Database engine not found %s' % engine) + except AssertionError: + raise RuntimeError('Database engine not a subclass of ' + 'peewee.Database: %s' % engine) + + return database_class(name, **config_dict) + + def _register_handlers(self, app): + app.before_request(self.connect_db) + app.teardown_request(self.close_db) + + def get_model_class(self): + if self.database is None: + raise RuntimeError('Database must be initialized.') + + class BaseModel(self.base_model_class): + class Meta: + database = self.database + + return BaseModel + + @property + def Model(self): + if self._app is None: + database = getattr(self, 'database', None) + if database is None: + self.database = Proxy() + + if not hasattr(self, '_model_class'): + self._model_class = self.get_model_class() + return self._model_class + + def connect_db(self): + if self._excluded_routes and request.endpoint in self._excluded_routes: + return + self.database.connect() + + def close_db(self, exc): + if self._excluded_routes and request.endpoint in self._excluded_routes: + return + if not self.database.is_closed(): + self.database.close() diff --git a/python3.10libs/playhouse/hybrid.py b/python3.10libs/playhouse/hybrid.py new file mode 100644 index 0000000..50531cc --- /dev/null +++ b/python3.10libs/playhouse/hybrid.py @@ -0,0 +1,53 @@ +from peewee import ModelDescriptor + + +# Hybrid methods/attributes, based on similar functionality in SQLAlchemy: +# http://docs.sqlalchemy.org/en/improve_toc/orm/extensions/hybrid.html +class hybrid_method(ModelDescriptor): + def __init__(self, func, expr=None): + self.func = func + self.expr = expr or func + + 
def __get__(self, instance, instance_type): + if instance is None: + return self.expr.__get__(instance_type, instance_type.__class__) + return self.func.__get__(instance, instance_type) + + def expression(self, expr): + self.expr = expr + return self + + +class hybrid_property(ModelDescriptor): + def __init__(self, fget, fset=None, fdel=None, expr=None): + self.fget = fget + self.fset = fset + self.fdel = fdel + self.expr = expr or fget + + def __get__(self, instance, instance_type): + if instance is None: + return self.expr(instance_type) + return self.fget(instance) + + def __set__(self, instance, value): + if self.fset is None: + raise AttributeError('Cannot set attribute.') + self.fset(instance, value) + + def __delete__(self, instance): + if self.fdel is None: + raise AttributeError('Cannot delete attribute.') + self.fdel(instance) + + def setter(self, fset): + self.fset = fset + return self + + def deleter(self, fdel): + self.fdel = fdel + return self + + def expression(self, expr): + self.expr = expr + return self diff --git a/python3.10libs/playhouse/kv.py b/python3.10libs/playhouse/kv.py new file mode 100644 index 0000000..3451631 --- /dev/null +++ b/python3.10libs/playhouse/kv.py @@ -0,0 +1,176 @@ +import operator + +from peewee import * +from peewee import sqlite3 +from peewee import Expression +from playhouse.fields import PickleField +try: + from playhouse.sqlite_ext import CSqliteExtDatabase as SqliteExtDatabase +except ImportError: + from playhouse.sqlite_ext import SqliteExtDatabase + + +Sentinel = type('Sentinel', (object,), {}) + + +class KeyValue(object): + """ + Persistent dictionary. + + :param Field key_field: field to use for key. Defaults to CharField. + :param Field value_field: field to use for value. Defaults to PickleField. + :param bool ordered: data should be returned in key-sorted order. + :param Database database: database where key/value data is stored. + :param str table_name: table name for data. 
+ """ + def __init__(self, key_field=None, value_field=None, ordered=False, + database=None, table_name='keyvalue'): + if key_field is None: + key_field = CharField(max_length=255, primary_key=True) + if not key_field.primary_key: + raise ValueError('key_field must have primary_key=True.') + + if value_field is None: + value_field = PickleField() + + self._key_field = key_field + self._value_field = value_field + self._ordered = ordered + self._database = database or SqliteExtDatabase(':memory:') + self._table_name = table_name + support_on_conflict = (isinstance(self._database, PostgresqlDatabase) or + (isinstance(self._database, SqliteDatabase) and + self._database.server_version >= (3, 24))) + if support_on_conflict: + self.upsert = self._postgres_upsert + self.update = self._postgres_update + else: + self.upsert = self._upsert + self.update = self._update + + self.model = self.create_model() + self.key = self.model.key + self.value = self.model.value + + # Ensure table exists. + self.model.create_table() + + def create_model(self): + class KeyValue(Model): + key = self._key_field + value = self._value_field + class Meta: + database = self._database + table_name = self._table_name + return KeyValue + + def query(self, *select): + query = self.model.select(*select).tuples() + if self._ordered: + query = query.order_by(self.key) + return query + + def convert_expression(self, expr): + if not isinstance(expr, Expression): + return (self.key == expr), True + return expr, False + + def __contains__(self, key): + expr, _ = self.convert_expression(key) + return self.model.select().where(expr).exists() + + def __len__(self): + return len(self.model) + + def __getitem__(self, expr): + converted, is_single = self.convert_expression(expr) + query = self.query(self.value).where(converted) + item_getter = operator.itemgetter(0) + result = [item_getter(row) for row in query] + if len(result) == 0 and is_single: + raise KeyError(expr) + elif is_single: + return result[0] + 
return result + + def _upsert(self, key, value): + (self.model + .insert(key=key, value=value) + .on_conflict('replace') + .execute()) + + def _postgres_upsert(self, key, value): + (self.model + .insert(key=key, value=value) + .on_conflict(conflict_target=[self.key], + preserve=[self.value]) + .execute()) + + def __setitem__(self, expr, value): + if isinstance(expr, Expression): + self.model.update(value=value).where(expr).execute() + else: + self.upsert(expr, value) + + def __delitem__(self, expr): + converted, _ = self.convert_expression(expr) + self.model.delete().where(converted).execute() + + def __iter__(self): + return iter(self.query().execute()) + + def keys(self): + return map(operator.itemgetter(0), self.query(self.key)) + + def values(self): + return map(operator.itemgetter(0), self.query(self.value)) + + def items(self): + return iter(self.query().execute()) + + def _update(self, __data=None, **mapping): + if __data is not None: + mapping.update(__data) + return (self.model + .insert_many(list(mapping.items()), + fields=[self.key, self.value]) + .on_conflict('replace') + .execute()) + + def _postgres_update(self, __data=None, **mapping): + if __data is not None: + mapping.update(__data) + return (self.model + .insert_many(list(mapping.items()), + fields=[self.key, self.value]) + .on_conflict(conflict_target=[self.key], + preserve=[self.value]) + .execute()) + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + def setdefault(self, key, default=None): + try: + return self[key] + except KeyError: + self[key] = default + return default + + def pop(self, key, default=Sentinel): + with self._database.atomic(): + try: + result = self[key] + except KeyError: + if default is Sentinel: + raise + return default + del self[key] + + return result + + def clear(self): + self.model.delete().execute() diff --git a/python3.10libs/playhouse/migrate.py b/python3.10libs/playhouse/migrate.py new file mode 100644 index 
0000000..2f877c4 --- /dev/null +++ b/python3.10libs/playhouse/migrate.py @@ -0,0 +1,897 @@ +""" +Lightweight schema migrations. + +Example Usage +------------- + +Instantiate a migrator: + + # Postgres example: + my_db = PostgresqlDatabase(...) + migrator = PostgresqlMigrator(my_db) + + # SQLite example: + my_db = SqliteDatabase('my_database.db') + migrator = SqliteMigrator(my_db) + +Then you will use the `migrate` function to run various `Operation`s which +are generated by the migrator: + + migrate( + migrator.add_column('some_table', 'column_name', CharField(default='')) + ) + +Migrations are not run inside a transaction, so if you wish the migration to +run in a transaction you will need to wrap the call to `migrate` in a +transaction block, e.g.: + + with my_db.transaction(): + migrate(...) + +Supported Operations +-------------------- + +Add new field(s) to an existing model: + + # Create your field instances. For non-null fields you must specify a + # default value. + pubdate_field = DateTimeField(null=True) + comment_field = TextField(default='') + + # Run the migration, specifying the database table, field name and field. + migrate( + migrator.add_column('comment_tbl', 'pub_date', pubdate_field), + migrator.add_column('comment_tbl', 'comment', comment_field), + ) + +Renaming a field: + + # Specify the table, original name of the column, and its new name. + migrate( + migrator.rename_column('story', 'pub_date', 'publish_date'), + migrator.rename_column('story', 'mod_date', 'modified_date'), + ) + +Dropping a field: + + migrate( + migrator.drop_column('story', 'some_old_field'), + ) + +Making a field nullable or not nullable: + + # Note that when making a field not null that field must not have any + # NULL values present. + migrate( + # Make `pub_date` allow NULL values. + migrator.drop_not_null('story', 'pub_date'), + + # Prevent `modified_date` from containing NULL values. 
+ migrator.add_not_null('story', 'modified_date'), + ) + +Renaming a table: + + migrate( + migrator.rename_table('story', 'stories_tbl'), + ) + +Adding an index: + + # Specify the table, column names, and whether the index should be + # UNIQUE or not. + migrate( + # Create an index on the `pub_date` column. + migrator.add_index('story', ('pub_date',), False), + + # Create a multi-column index on the `pub_date` and `status` fields. + migrator.add_index('story', ('pub_date', 'status'), False), + + # Create a unique index on the category and title fields. + migrator.add_index('story', ('category_id', 'title'), True), + ) + +Dropping an index: + + # Specify the index name. + migrate(migrator.drop_index('story', 'story_pub_date_status')) + +Adding or dropping table constraints: + +.. code-block:: python + + # Add a CHECK() constraint to enforce the price cannot be negative. + migrate(migrator.add_constraint( + 'products', + 'price_check', + Check('price >= 0'))) + + # Remove the price check constraint. + migrate(migrator.drop_constraint('products', 'price_check')) + + # Add a UNIQUE constraint on the first and last names. 
+ migrate(migrator.add_unique('person', 'first_name', 'last_name')) +""" +from collections import namedtuple +import functools +import hashlib +import re + +from peewee import * +from peewee import CommaNodeList +from peewee import EnclosedNodeList +from peewee import Entity +from peewee import Expression +from peewee import Node +from peewee import NodeList +from peewee import OP +from peewee import callable_ +from peewee import sort_models +from peewee import sqlite3 +from peewee import _truncate_constraint_name +try: + from playhouse.cockroachdb import CockroachDatabase +except ImportError: + CockroachDatabase = None + + +class Operation(object): + """Encapsulate a single schema altering operation.""" + def __init__(self, migrator, method, *args, **kwargs): + self.migrator = migrator + self.method = method + self.args = args + self.kwargs = kwargs + + def execute(self, node): + self.migrator.database.execute(node) + + def _handle_result(self, result): + if isinstance(result, (Node, Context)): + self.execute(result) + elif isinstance(result, Operation): + result.run() + elif isinstance(result, (list, tuple)): + for item in result: + self._handle_result(item) + + def run(self): + kwargs = self.kwargs.copy() + kwargs['with_context'] = True + method = getattr(self.migrator, self.method) + self._handle_result(method(*self.args, **kwargs)) + + +def operation(fn): + @functools.wraps(fn) + def inner(self, *args, **kwargs): + with_context = kwargs.pop('with_context', False) + if with_context: + return fn(self, *args, **kwargs) + return Operation(self, fn.__name__, *args, **kwargs) + return inner + + +def make_index_name(table_name, columns): + index_name = '_'.join((table_name,) + tuple(columns)) + if len(index_name) > 64: + index_hash = hashlib.md5(index_name.encode('utf-8')).hexdigest() + index_name = '%s_%s' % (index_name[:56], index_hash[:7]) + return index_name + + +class SchemaMigrator(object): + explicit_create_foreign_key = False + explicit_delete_foreign_key = 
False + + def __init__(self, database): + self.database = database + + def make_context(self): + return self.database.get_sql_context() + + @classmethod + def from_database(cls, database): + if CockroachDatabase and isinstance(database, CockroachDatabase): + return CockroachDBMigrator(database) + elif isinstance(database, PostgresqlDatabase): + return PostgresqlMigrator(database) + elif isinstance(database, MySQLDatabase): + return MySQLMigrator(database) + elif isinstance(database, SqliteDatabase): + return SqliteMigrator(database) + raise ValueError('Unsupported database: %s' % database) + + @operation + def apply_default(self, table, column_name, field): + default = field.default + if callable_(default): + default = default() + + return (self.make_context() + .literal('UPDATE ') + .sql(Entity(table)) + .literal(' SET ') + .sql(Expression( + Entity(column_name), + OP.EQ, + field.db_value(default), + flat=True))) + + def _alter_table(self, ctx, table): + return ctx.literal('ALTER TABLE ').sql(Entity(table)) + + def _alter_column(self, ctx, table, column): + return (self + ._alter_table(ctx, table) + .literal(' ALTER COLUMN ') + .sql(Entity(column))) + + @operation + def alter_add_column(self, table, column_name, field): + # Make field null at first. + ctx = self.make_context() + field_null, field.null = field.null, True + + # Set the field's column-name and name, if it is not set or doesn't + # match the new value. 
+ if field.column_name != column_name: + field.name = field.column_name = column_name + + (self + ._alter_table(ctx, table) + .literal(' ADD COLUMN ') + .sql(field.ddl(ctx))) + + field.null = field_null + if isinstance(field, ForeignKeyField): + self.add_inline_fk_sql(ctx, field) + return ctx + + @operation + def add_constraint(self, table, name, constraint): + return (self + ._alter_table(self.make_context(), table) + .literal(' ADD CONSTRAINT ') + .sql(Entity(name)) + .literal(' ') + .sql(constraint)) + + @operation + def add_unique(self, table, *column_names): + constraint_name = 'uniq_%s' % '_'.join(column_names) + constraint = NodeList(( + SQL('UNIQUE'), + EnclosedNodeList([Entity(column) for column in column_names]))) + return self.add_constraint(table, constraint_name, constraint) + + @operation + def drop_constraint(self, table, name): + return (self + ._alter_table(self.make_context(), table) + .literal(' DROP CONSTRAINT ') + .sql(Entity(name))) + + def add_inline_fk_sql(self, ctx, field): + ctx = (ctx + .literal(' REFERENCES ') + .sql(Entity(field.rel_model._meta.table_name)) + .literal(' ') + .sql(EnclosedNodeList((Entity(field.rel_field.column_name),)))) + if field.on_delete is not None: + ctx = ctx.literal(' ON DELETE %s' % field.on_delete) + if field.on_update is not None: + ctx = ctx.literal(' ON UPDATE %s' % field.on_update) + return ctx + + @operation + def add_foreign_key_constraint(self, table, column_name, rel, rel_column, + on_delete=None, on_update=None): + constraint = 'fk_%s_%s_refs_%s' % (table, column_name, rel) + ctx = (self + .make_context() + .literal('ALTER TABLE ') + .sql(Entity(table)) + .literal(' ADD CONSTRAINT ') + .sql(Entity(_truncate_constraint_name(constraint))) + .literal(' FOREIGN KEY ') + .sql(EnclosedNodeList((Entity(column_name),))) + .literal(' REFERENCES ') + .sql(Entity(rel)) + .literal(' (') + .sql(Entity(rel_column)) + .literal(')')) + if on_delete is not None: + ctx = ctx.literal(' ON DELETE %s' % on_delete) + if 
on_update is not None: + ctx = ctx.literal(' ON UPDATE %s' % on_update) + return ctx + + @operation + def add_column(self, table, column_name, field): + # Adding a column is complicated by the fact that if there are rows + # present and the field is non-null, then we need to first add the + # column as a nullable field, then set the value, then add a not null + # constraint. + if not field.null and field.default is None: + raise ValueError('%s is not null but has no default' % column_name) + + is_foreign_key = isinstance(field, ForeignKeyField) + if is_foreign_key and not field.rel_field: + raise ValueError('Foreign keys must specify a `field`.') + + operations = [self.alter_add_column(table, column_name, field)] + + # In the event the field is *not* nullable, update with the default + # value and set not null. + if not field.null: + operations.extend([ + self.apply_default(table, column_name, field), + self.add_not_null(table, column_name)]) + + if is_foreign_key and self.explicit_create_foreign_key: + operations.append( + self.add_foreign_key_constraint( + table, + column_name, + field.rel_model._meta.table_name, + field.rel_field.column_name, + field.on_delete, + field.on_update)) + + if field.index or field.unique: + using = getattr(field, 'index_type', None) + operations.append(self.add_index(table, (column_name,), + field.unique, using)) + + return operations + + @operation + def drop_foreign_key_constraint(self, table, column_name): + raise NotImplementedError + + @operation + def drop_column(self, table, column_name, cascade=True): + ctx = self.make_context() + (self._alter_table(ctx, table) + .literal(' DROP COLUMN ') + .sql(Entity(column_name))) + + if cascade: + ctx.literal(' CASCADE') + + fk_columns = [ + foreign_key.column + for foreign_key in self.database.get_foreign_keys(table)] + if column_name in fk_columns and self.explicit_delete_foreign_key: + return [self.drop_foreign_key_constraint(table, column_name), ctx] + + return ctx + + @operation + def 
rename_column(self, table, old_name, new_name): + return (self + ._alter_table(self.make_context(), table) + .literal(' RENAME COLUMN ') + .sql(Entity(old_name)) + .literal(' TO ') + .sql(Entity(new_name))) + + @operation + def add_not_null(self, table, column): + return (self + ._alter_column(self.make_context(), table, column) + .literal(' SET NOT NULL')) + + @operation + def drop_not_null(self, table, column): + return (self + ._alter_column(self.make_context(), table, column) + .literal(' DROP NOT NULL')) + + @operation + def alter_column_type(self, table, column, field, cast=None): + # ALTER TABLE
ALTER COLUMN + ctx = self.make_context() + ctx = (self + ._alter_column(ctx, table, column) + .literal(' TYPE ') + .sql(field.ddl_datatype(ctx))) + if cast is not None: + if not isinstance(cast, Node): + cast = SQL(cast) + ctx = ctx.literal(' USING ').sql(cast) + return ctx + + @operation + def rename_table(self, old_name, new_name): + return (self + ._alter_table(self.make_context(), old_name) + .literal(' RENAME TO ') + .sql(Entity(new_name))) + + @operation + def add_index(self, table, columns, unique=False, using=None): + ctx = self.make_context() + index_name = make_index_name(table, columns) + table_obj = Table(table) + cols = [getattr(table_obj.c, column) for column in columns] + index = Index(index_name, table_obj, cols, unique=unique, using=using) + return ctx.sql(index) + + @operation + def drop_index(self, table, index_name): + return (self + .make_context() + .literal('DROP INDEX ') + .sql(Entity(index_name))) + + +class PostgresqlMigrator(SchemaMigrator): + def _primary_key_columns(self, tbl): + query = """ + SELECT pg_attribute.attname + FROM pg_index, pg_class, pg_attribute + WHERE + pg_class.oid = '%s'::regclass AND + indrelid = pg_class.oid AND + pg_attribute.attrelid = pg_class.oid AND + pg_attribute.attnum = any(pg_index.indkey) AND + indisprimary; + """ + cursor = self.database.execute_sql(query % tbl) + return [row[0] for row in cursor.fetchall()] + + @operation + def set_search_path(self, schema_name): + return (self + .make_context() + .literal('SET search_path TO %s' % schema_name)) + + @operation + def rename_table(self, old_name, new_name): + pk_names = self._primary_key_columns(old_name) + ParentClass = super(PostgresqlMigrator, self) + + operations = [ + ParentClass.rename_table(old_name, new_name, with_context=True)] + + if len(pk_names) == 1: + # Check for existence of primary key sequence. 
+ seq_name = '%s_%s_seq' % (old_name, pk_names[0]) + query = """ + SELECT 1 + FROM information_schema.sequences + WHERE LOWER(sequence_name) = LOWER(%s) + """ + cursor = self.database.execute_sql(query, (seq_name,)) + if bool(cursor.fetchone()): + new_seq_name = '%s_%s_seq' % (new_name, pk_names[0]) + operations.append(ParentClass.rename_table( + seq_name, new_seq_name)) + + return operations + + +class CockroachDBMigrator(PostgresqlMigrator): + explicit_create_foreign_key = True + + def add_inline_fk_sql(self, ctx, field): + pass + + @operation + def drop_index(self, table, index_name): + return (self + .make_context() + .literal('DROP INDEX ') + .sql(Entity(index_name)) + .literal(' CASCADE')) + + +class MySQLColumn(namedtuple('_Column', ('name', 'definition', 'null', 'pk', + 'default', 'extra'))): + @property + def is_pk(self): + return self.pk == 'PRI' + + @property + def is_unique(self): + return self.pk == 'UNI' + + @property + def is_null(self): + return self.null == 'YES' + + def sql(self, column_name=None, is_null=None): + if is_null is None: + is_null = self.is_null + if column_name is None: + column_name = self.name + parts = [ + Entity(column_name), + SQL(self.definition)] + if self.is_unique: + parts.append(SQL('UNIQUE')) + if is_null: + parts.append(SQL('NULL')) + else: + parts.append(SQL('NOT NULL')) + if self.is_pk: + parts.append(SQL('PRIMARY KEY')) + if self.extra: + parts.append(SQL(self.extra)) + return NodeList(parts) + + +class MySQLMigrator(SchemaMigrator): + explicit_create_foreign_key = True + explicit_delete_foreign_key = True + + def _alter_column(self, ctx, table, column): + return (self + ._alter_table(ctx, table) + .literal(' MODIFY ') + .sql(Entity(column))) + + @operation + def rename_table(self, old_name, new_name): + return (self + .make_context() + .literal('RENAME TABLE ') + .sql(Entity(old_name)) + .literal(' TO ') + .sql(Entity(new_name))) + + def _get_column_definition(self, table, column_name): + cursor = 
self.database.execute_sql('DESCRIBE `%s`;' % table) + rows = cursor.fetchall() + for row in rows: + column = MySQLColumn(*row) + if column.name == column_name: + return column + return False + + def get_foreign_key_constraint(self, table, column_name): + cursor = self.database.execute_sql( + ('SELECT constraint_name ' + 'FROM information_schema.key_column_usage WHERE ' + 'table_schema = DATABASE() AND ' + 'table_name = %s AND ' + 'column_name = %s AND ' + 'referenced_table_name IS NOT NULL AND ' + 'referenced_column_name IS NOT NULL;'), + (table, column_name)) + result = cursor.fetchone() + if not result: + raise AttributeError( + 'Unable to find foreign key constraint for ' + '"%s" on table "%s".' % (table, column_name)) + return result[0] + + @operation + def drop_foreign_key_constraint(self, table, column_name): + fk_constraint = self.get_foreign_key_constraint(table, column_name) + return (self + ._alter_table(self.make_context(), table) + .literal(' DROP FOREIGN KEY ') + .sql(Entity(fk_constraint))) + + def add_inline_fk_sql(self, ctx, field): + pass + + @operation + def add_not_null(self, table, column): + column_def = self._get_column_definition(table, column) + add_not_null = (self + ._alter_table(self.make_context(), table) + .literal(' MODIFY ') + .sql(column_def.sql(is_null=False))) + + fk_objects = dict( + (fk.column, fk) + for fk in self.database.get_foreign_keys(table)) + if column not in fk_objects: + return add_not_null + + fk_metadata = fk_objects[column] + return (self.drop_foreign_key_constraint(table, column), + add_not_null, + self.add_foreign_key_constraint( + table, + column, + fk_metadata.dest_table, + fk_metadata.dest_column)) + + @operation + def drop_not_null(self, table, column): + column = self._get_column_definition(table, column) + if column.is_pk: + raise ValueError('Primary keys can not be null') + return (self + ._alter_table(self.make_context(), table) + .literal(' MODIFY ') + .sql(column.sql(is_null=True))) + + @operation + def 
rename_column(self, table, old_name, new_name): + fk_objects = dict( + (fk.column, fk) + for fk in self.database.get_foreign_keys(table)) + is_foreign_key = old_name in fk_objects + + column = self._get_column_definition(table, old_name) + rename_ctx = (self + ._alter_table(self.make_context(), table) + .literal(' CHANGE ') + .sql(Entity(old_name)) + .literal(' ') + .sql(column.sql(column_name=new_name))) + if is_foreign_key: + fk_metadata = fk_objects[old_name] + return [ + self.drop_foreign_key_constraint(table, old_name), + rename_ctx, + self.add_foreign_key_constraint( + table, + new_name, + fk_metadata.dest_table, + fk_metadata.dest_column), + ] + else: + return rename_ctx + + @operation + def alter_column_type(self, table, column, field, cast=None): + if cast is not None: + raise ValueError('alter_column_type() does not support cast with ' + 'MySQL.') + ctx = self.make_context() + return (self + ._alter_table(ctx, table) + .literal(' MODIFY ') + .sql(Entity(column)) + .literal(' ') + .sql(field.ddl(ctx))) + + @operation + def drop_index(self, table, index_name): + return (self + .make_context() + .literal('DROP INDEX ') + .sql(Entity(index_name)) + .literal(' ON ') + .sql(Entity(table))) + + +class SqliteMigrator(SchemaMigrator): + """ + SQLite supports a subset of ALTER TABLE queries, view the docs for the + full details http://sqlite.org/lang_altertable.html + """ + column_re = re.compile('(.+?)\((.+)\)') + column_split_re = re.compile(r'(?:[^,(]|\([^)]*\))+') + column_name_re = re.compile(r'''["`']?([\w]+)''') + fk_re = re.compile(r'FOREIGN KEY\s+\("?([\w]+)"?\)\s+', re.I) + + def _get_column_names(self, table): + res = self.database.execute_sql('select * from "%s" limit 1' % table) + return [item[0] for item in res.description] + + def _get_create_table(self, table): + res = self.database.execute_sql( + ('select name, sql from sqlite_master ' + 'where type=? 
and LOWER(name)=?'),
            ['table', table.lower()])
        return res.fetchone()

    @operation
    def _update_column(self, table, column_to_update, fn):
        """Rebuild ``table``, applying ``fn`` to one column's definition.

        ``fn(column_name, column_def)`` returns the replacement definition,
        or a falsy value to drop the column entirely. The rebuild follows
        the documented SQLite workaround: create a temp table with the new
        schema, copy the data across, drop the original, rename the temp
        table back, and re-create any user-defined indexes.
        """
        columns = set(column.name.lower()
                      for column in self.database.get_columns(table))
        if column_to_update.lower() not in columns:
            raise ValueError('Column "%s" does not exist on "%s"' %
                             (column_to_update, table))

        # Get the SQL used to create the given table.
        table, create_table = self._get_create_table(table)

        # Get the indexes and SQL to re-create indexes.
        indexes = self.database.get_indexes(table)

        # Find any foreign keys we may need to remove.
        self.database.get_foreign_keys(table)

        # Make sure the create_table does not contain any newlines or tabs,
        # allowing the regex to work correctly.
        create_table = re.sub(r'\s+', ' ', create_table)

        # Parse out the `CREATE TABLE` and column list portions of the query.
        raw_create, raw_columns = self.column_re.search(create_table).groups()

        # Clean up the individual column definitions.
        split_columns = self.column_split_re.findall(raw_columns)
        column_defs = [col.strip() for col in split_columns]

        new_column_defs = []
        new_column_names = []
        original_column_names = []
        constraint_terms = ('foreign ', 'primary ', 'constraint ', 'check ')

        for column_def in column_defs:
            column_name, = self.column_name_re.match(column_def).groups()

            if column_name == column_to_update:
                # Apply the caller's transformation; a falsy result means
                # the column is being dropped.
                new_column_def = fn(column_name, column_def)
                if new_column_def:
                    new_column_defs.append(new_column_def)
                    original_column_names.append(column_name)
                    column_name, = self.column_name_re.match(
                        new_column_def).groups()
                    new_column_names.append(column_name)
            else:
                new_column_defs.append(column_def)

                # Avoid treating constraints as columns.
                if not column_def.lower().startswith(constraint_terms):
                    new_column_names.append(column_name)
                    original_column_names.append(column_name)

        # Create a mapping of original columns to new columns.
        original_to_new = dict(zip(original_column_names, new_column_names))
        new_column = original_to_new.get(column_to_update)

        fk_filter_fn = lambda column_def: column_def
        if not new_column:
            # Remove any foreign keys associated with this column.
            fk_filter_fn = lambda column_def: None
        elif new_column != column_to_update:
            # Update any foreign keys for this column.
            fk_filter_fn = lambda column_def: self.fk_re.sub(
                'FOREIGN KEY ("%s") ' % new_column,
                column_def)

        cleaned_columns = []
        for column_def in new_column_defs:
            match = self.fk_re.match(column_def)
            if match is not None and match.groups()[0] == column_to_update:
                column_def = fk_filter_fn(column_def)
            if column_def:
                cleaned_columns.append(column_def)

        # Update the name of the new CREATE TABLE query.
        temp_table = table + '__tmp__'
        rgx = re.compile('("?)%s("?)' % table, re.I)
        create = rgx.sub(
            '\\1%s\\2' % temp_table,
            raw_create)

        # Create the new table.
        columns = ', '.join(cleaned_columns)
        queries = [
            NodeList([SQL('DROP TABLE IF EXISTS'), Entity(temp_table)]),
            SQL('%s (%s)' % (create.strip(), columns))]

        # Populate new table.
        populate_table = NodeList((
            SQL('INSERT INTO'),
            Entity(temp_table),
            EnclosedNodeList([Entity(col) for col in new_column_names]),
            SQL('SELECT'),
            CommaNodeList([Entity(col) for col in original_column_names]),
            SQL('FROM'),
            Entity(table)))
        drop_original = NodeList([SQL('DROP TABLE'), Entity(table)])

        # Drop existing table and rename temp table.
        queries += [
            populate_table,
            drop_original,
            self.rename_table(temp_table, table)]

        # Re-create user-defined indexes. User-defined indexes will have a
        # non-empty SQL attribute.
        for index in filter(lambda idx: idx.sql, indexes):
            if column_to_update not in index.columns:
                queries.append(SQL(index.sql))
            elif new_column:
                sql = self._fix_index(index.sql, column_to_update, new_column)
                if sql is not None:
                    queries.append(SQL(sql))

        return queries

    def _fix_index(self, sql, column_to_update, new_column):
        """Rewrite an index's CREATE INDEX SQL after a column rename."""
        # Split on the name of the column to update. If it splits into two
        # pieces, then there's no ambiguity and we can simply replace the
        # old with the new.
        parts = sql.split(column_to_update)
        if len(parts) == 2:
            return sql.replace(column_to_update, new_column)

        # Find the list of columns in the index expression.
        lhs, rhs = sql.rsplit('(', 1)

        # Apply the same "split in two" logic to the column list portion of
        # the query.
        if len(rhs.split(column_to_update)) == 2:
            return '%s(%s' % (lhs, rhs.replace(column_to_update, new_column))

        # Strip off the trailing parentheses and go through each column.
        parts = rhs.rsplit(')', 1)[0].split(',')
        columns = [part.strip('"`[]\' ') for part in parts]

        # `columns` looks something like: ['status', 'timestamp" DESC']
        # https://www.sqlite.org/lang_keywords.html
        # Strip out any junk after the column name.
+ clean = [] + for column in columns: + if re.match('%s(?:[\'"`\]]?\s|$)' % column_to_update, column): + column = new_column + column[len(column_to_update):] + clean.append(column) + + return '%s(%s)' % (lhs, ', '.join('"%s"' % c for c in clean)) + + @operation + def drop_column(self, table, column_name, cascade=True, legacy=False): + if sqlite3.sqlite_version_info >= (3, 25, 0) and not legacy: + ctx = self.make_context() + (self._alter_table(ctx, table) + .literal(' DROP COLUMN ') + .sql(Entity(column_name))) + return ctx + return self._update_column(table, column_name, lambda a, b: None) + + @operation + def rename_column(self, table, old_name, new_name, legacy=False): + if sqlite3.sqlite_version_info >= (3, 25, 0) and not legacy: + return (self + ._alter_table(self.make_context(), table) + .literal(' RENAME COLUMN ') + .sql(Entity(old_name)) + .literal(' TO ') + .sql(Entity(new_name))) + def _rename(column_name, column_def): + return column_def.replace(column_name, new_name) + return self._update_column(table, old_name, _rename) + + @operation + def add_not_null(self, table, column): + def _add_not_null(column_name, column_def): + return column_def + ' NOT NULL' + return self._update_column(table, column, _add_not_null) + + @operation + def drop_not_null(self, table, column): + def _drop_not_null(column_name, column_def): + return column_def.replace('NOT NULL', '') + return self._update_column(table, column, _drop_not_null) + + @operation + def alter_column_type(self, table, column, field, cast=None): + if cast is not None: + raise ValueError('alter_column_type() does not support cast with ' + 'Sqlite.') + ctx = self.make_context() + def _alter_column_type(column_name, column_def): + node_list = field.ddl(ctx) + sql, _ = ctx.sql(Entity(column)).sql(node_list).query() + return sql + return self._update_column(table, column, _alter_column_type) + + @operation + def add_constraint(self, table, name, constraint): + raise NotImplementedError + + @operation + def 
drop_constraint(self, table, name): + raise NotImplementedError + + @operation + def add_foreign_key_constraint(self, table, column_name, field, + on_delete=None, on_update=None): + raise NotImplementedError + + +def migrate(*operations, **kwargs): + for operation in operations: + operation.run() diff --git a/python3.10libs/playhouse/mysql_ext.py b/python3.10libs/playhouse/mysql_ext.py new file mode 100644 index 0000000..81b053e --- /dev/null +++ b/python3.10libs/playhouse/mysql_ext.py @@ -0,0 +1,117 @@ +import json + +try: + import mysql.connector as mysql_connector +except ImportError: + mysql_connector = None +try: + import mariadb +except ImportError: + mariadb = None + +from peewee import ImproperlyConfigured +from peewee import Insert +from peewee import MySQLDatabase +from peewee import Node +from peewee import NodeList +from peewee import SQL +from peewee import TextField +from peewee import fn +from peewee import __deprecated__ + + +class MySQLConnectorDatabase(MySQLDatabase): + def _connect(self): + if mysql_connector is None: + raise ImproperlyConfigured('MySQL connector not installed!') + return mysql_connector.connect(db=self.database, autocommit=True, + **self.connect_params) + + def cursor(self, commit=None, named_cursor=None): + if commit is not None: + __deprecated__('"commit" has been deprecated and is a no-op.') + if self.is_closed(): + if self.autoconnect: + self.connect() + else: + raise InterfaceError('Error, database connection not opened.') + return self._state.conn.cursor(buffered=True) + + def get_binary_type(self): + return mysql_connector.Binary + + +class MariaDBConnectorDatabase(MySQLDatabase): + def _connect(self): + if mariadb is None: + raise ImproperlyConfigured('mariadb connector not installed!') + self.connect_params.pop('charset', None) + self.connect_params.pop('sql_mode', None) + self.connect_params.pop('use_unicode', None) + return mariadb.connect(db=self.database, autocommit=True, + **self.connect_params) + + def 
cursor(self, commit=None, named_cursor=None): + if commit is not None: + __deprecated__('"commit" has been deprecated and is a no-op.') + if self.is_closed(): + if self.autoconnect: + self.connect() + else: + raise InterfaceError('Error, database connection not opened.') + return self._state.conn.cursor(buffered=True) + + def _set_server_version(self, conn): + version = conn.server_version + version, point = divmod(version, 100) + version, minor = divmod(version, 100) + self.server_version = (version, minor, point) + if self.server_version >= (10, 5, 0): + self.returning_clause = True + + def last_insert_id(self, cursor, query_type=None): + if not self.returning_clause: + return cursor.lastrowid + elif query_type == Insert.SIMPLE: + try: + return cursor[0][0] + except (AttributeError, IndexError): + return cursor.lastrowid + return cursor + + def get_binary_type(self): + return mariadb.Binary + + +class JSONField(TextField): + field_type = 'JSON' + + def __init__(self, json_dumps=None, json_loads=None, **kwargs): + self._json_dumps = json_dumps or json.dumps + self._json_loads = json_loads or json.loads + super(JSONField, self).__init__(**kwargs) + + def python_value(self, value): + if value is not None: + try: + return self._json_loads(value) + except (TypeError, ValueError): + return value + + def db_value(self, value): + if value is not None: + if not isinstance(value, Node): + value = self._json_dumps(value) + return value + + def extract(self, path): + return fn.json_extract(self, path) + + +def Match(columns, expr, modifier=None): + if isinstance(columns, (list, tuple)): + match = fn.MATCH(*columns) # Tuple of one or more columns / fields. + else: + match = fn.MATCH(columns) # Single column / field. 
+ args = expr if modifier is None else NodeList((expr, SQL(modifier))) + return NodeList((match, fn.AGAINST(args))) diff --git a/python3.10libs/playhouse/pool.py b/python3.10libs/playhouse/pool.py new file mode 100644 index 0000000..2ee3b48 --- /dev/null +++ b/python3.10libs/playhouse/pool.py @@ -0,0 +1,318 @@ +""" +Lightweight connection pooling for peewee. + +In a multi-threaded application, up to `max_connections` will be opened. Each +thread (or, if using gevent, greenlet) will have it's own connection. + +In a single-threaded application, only one connection will be created. It will +be continually recycled until either it exceeds the stale timeout or is closed +explicitly (using `.manual_close()`). + +By default, all your application needs to do is ensure that connections are +closed when you are finished with them, and they will be returned to the pool. +For web applications, this typically means that at the beginning of a request, +you will open a connection, and when you return a response, you will close the +connection. + +Simple Postgres pool example code: + + # Use the special postgresql extensions. + from playhouse.pool import PooledPostgresqlExtDatabase + + db = PooledPostgresqlExtDatabase( + 'my_app', + max_connections=32, + stale_timeout=300, # 5 minutes. + user='postgres') + + class BaseModel(Model): + class Meta: + database = db + +That's it! 
+""" +import heapq +import logging +import random +import time +from collections import namedtuple +from itertools import chain + +try: + from psycopg2.extensions import TRANSACTION_STATUS_IDLE + from psycopg2.extensions import TRANSACTION_STATUS_INERROR + from psycopg2.extensions import TRANSACTION_STATUS_UNKNOWN +except ImportError: + TRANSACTION_STATUS_IDLE = \ + TRANSACTION_STATUS_INERROR = \ + TRANSACTION_STATUS_UNKNOWN = None + +from peewee import MySQLDatabase +from peewee import PostgresqlDatabase +from peewee import SqliteDatabase + +logger = logging.getLogger('peewee.pool') + + +def make_int(val): + if val is not None and not isinstance(val, (int, float)): + return int(val) + return val + + +class MaxConnectionsExceeded(ValueError): pass + + +PoolConnection = namedtuple('PoolConnection', ('timestamp', 'connection', + 'checked_out')) + + +class PooledDatabase(object): + def __init__(self, database, max_connections=20, stale_timeout=None, + timeout=None, **kwargs): + self._max_connections = make_int(max_connections) + self._stale_timeout = make_int(stale_timeout) + self._wait_timeout = make_int(timeout) + if self._wait_timeout == 0: + self._wait_timeout = float('inf') + + # Available / idle connections stored in a heap, sorted oldest first. + self._connections = [] + + # Mapping of connection id to PoolConnection. Ordinarily we would want + # to use something like a WeakKeyDictionary, but Python typically won't + # allow us to create weak references to connection objects. + self._in_use = {} + + # Use the memory address of the connection as the key in the event the + # connection object is not hashable. Connections will not get + # garbage-collected, however, because a reference to them will persist + # in "_in_use" as long as the conn has not been closed. 
+ self.conn_key = id + + super(PooledDatabase, self).__init__(database, **kwargs) + + def init(self, database, max_connections=None, stale_timeout=None, + timeout=None, **connect_kwargs): + super(PooledDatabase, self).init(database, **connect_kwargs) + if max_connections is not None: + self._max_connections = make_int(max_connections) + if stale_timeout is not None: + self._stale_timeout = make_int(stale_timeout) + if timeout is not None: + self._wait_timeout = make_int(timeout) + if self._wait_timeout == 0: + self._wait_timeout = float('inf') + + def connect(self, reuse_if_open=False): + if not self._wait_timeout: + return super(PooledDatabase, self).connect(reuse_if_open) + + expires = time.time() + self._wait_timeout + while expires > time.time(): + try: + ret = super(PooledDatabase, self).connect(reuse_if_open) + except MaxConnectionsExceeded: + time.sleep(0.1) + else: + return ret + raise MaxConnectionsExceeded('Max connections exceeded, timed out ' + 'attempting to connect.') + + def _connect(self): + while True: + try: + # Remove the oldest connection from the heap. + ts, conn = heapq.heappop(self._connections) + key = self.conn_key(conn) + except IndexError: + ts = conn = None + logger.debug('No connection available in pool.') + break + else: + if self._is_closed(conn): + # This connecton was closed, but since it was not stale + # it got added back to the queue of available conns. We + # then closed it and marked it as explicitly closed, so + # it's safe to throw it away now. + # (Because Database.close() calls Database._close()). + logger.debug('Connection %s was closed.', key) + ts = conn = None + elif self._stale_timeout and self._is_stale(ts): + # If we are attempting to check out a stale connection, + # then close it. We don't need to mark it in the "closed" + # set, because it is not in the list of available conns + # anymore. 
logger.debug('Connection %s was stale, closing.', key)
                    self._close(conn, True)
                    ts = conn = None
                else:
                    # Healthy, re-usable connection: check it out.
                    break

        if conn is None:
            # Nothing usable in the pool -- open a brand new connection,
            # unless doing so would exceed max_connections.
            if self._max_connections and (
                    len(self._in_use) >= self._max_connections):
                raise MaxConnectionsExceeded('Exceeded maximum connections.')
            conn = super(PooledDatabase, self)._connect()
            # Subtract a tiny random jitter so connections created within
            # the same clock tick still order deterministically in the heap.
            ts = time.time() - random.random() / 1000
            key = self.conn_key(conn)
            logger.debug('Created new connection %s.', key)

        # Record the check-out: (heap timestamp, connection, checkout time).
        self._in_use[key] = PoolConnection(ts, conn, time.time())
        return conn

    def _is_stale(self, timestamp):
        # Called on check-out and check-in to ensure the connection has
        # not outlived the stale timeout.
        return (time.time() - timestamp) > self._stale_timeout

    def _is_closed(self, conn):
        # Base implementation; subclasses override with a backend-specific
        # liveness check.
        return False

    def _can_reuse(self, conn):
        # Called on check-in to make sure the connection can be re-used.
        return True

    def _close(self, conn, close_conn=False):
        # Check a connection back in (or, with close_conn=True, close the
        # underlying connection outright).
        key = self.conn_key(conn)
        if close_conn:
            super(PooledDatabase, self)._close(conn)
        elif key in self._in_use:
            pool_conn = self._in_use.pop(key)
            if self._stale_timeout and self._is_stale(pool_conn.timestamp):
                logger.debug('Closing stale connection %s.', key)
                super(PooledDatabase, self)._close(conn)
            elif self._can_reuse(conn):
                logger.debug('Returning %s to pool.', key)
                heapq.heappush(self._connections, (pool_conn.timestamp, conn))
            else:
                logger.debug('Closed %s.', key)

    def manual_close(self):
        """
        Close the underlying connection without returning it to the pool.
        """
        if self.is_closed():
            return False

        # Obtain reference to the connection in-use by the calling thread.
        conn = self.connection()

        # A connection will only be re-added to the available list if it is
        # marked as "in use" at the time it is closed. We will explicitly
        # remove it from the "in use" list, call "close()" for the
        # side-effects, and then explicitly close the connection.
+ self._in_use.pop(self.conn_key(conn), None) + self.close() + self._close(conn, close_conn=True) + + def close_idle(self): + # Close any open connections that are not currently in-use. + with self._lock: + for _, conn in self._connections: + self._close(conn, close_conn=True) + self._connections = [] + + def close_stale(self, age=600): + # Close any connections that are in-use but were checked out quite some + # time ago and can be considered stale. + with self._lock: + in_use = {} + cutoff = time.time() - age + n = 0 + for key, pool_conn in self._in_use.items(): + if pool_conn.checked_out < cutoff: + self._close(pool_conn.connection, close_conn=True) + n += 1 + else: + in_use[key] = pool_conn + self._in_use = in_use + return n + + def close_all(self): + # Close all connections -- available and in-use. Warning: may break any + # active connections used by other threads. + self.close() + with self._lock: + for _, conn in self._connections: + self._close(conn, close_conn=True) + for pool_conn in self._in_use.values(): + self._close(pool_conn.connection, close_conn=True) + self._connections = [] + self._in_use = {} + + +class PooledMySQLDatabase(PooledDatabase, MySQLDatabase): + def _is_closed(self, conn): + try: + conn.ping(False) + except: + return True + else: + return False + + +class _PooledPostgresqlDatabase(PooledDatabase): + def _is_closed(self, conn): + if conn.closed: + return True + + txn_status = conn.get_transaction_status() + if txn_status == TRANSACTION_STATUS_UNKNOWN: + return True + elif txn_status != TRANSACTION_STATUS_IDLE: + conn.rollback() + return False + + def _can_reuse(self, conn): + txn_status = conn.get_transaction_status() + # Do not return connection in an error state, as subsequent queries + # will all fail. If the status is unknown then we lost the connection + # to the server and the connection should not be re-used. 
+ if txn_status == TRANSACTION_STATUS_UNKNOWN: + return False + elif txn_status == TRANSACTION_STATUS_INERROR: + conn.reset() + elif txn_status != TRANSACTION_STATUS_IDLE: + conn.rollback() + return True + +class PooledPostgresqlDatabase(_PooledPostgresqlDatabase, PostgresqlDatabase): + pass + +try: + from playhouse.postgres_ext import PostgresqlExtDatabase + + class PooledPostgresqlExtDatabase(_PooledPostgresqlDatabase, PostgresqlExtDatabase): + pass +except ImportError: + PooledPostgresqlExtDatabase = None + + +class _PooledSqliteDatabase(PooledDatabase): + def _is_closed(self, conn): + try: + conn.total_changes + except: + return True + else: + return False + +class PooledSqliteDatabase(_PooledSqliteDatabase, SqliteDatabase): + pass + +try: + from playhouse.sqlite_ext import SqliteExtDatabase + + class PooledSqliteExtDatabase(_PooledSqliteDatabase, SqliteExtDatabase): + pass +except ImportError: + PooledSqliteExtDatabase = None + +try: + from playhouse.sqlite_ext import CSqliteExtDatabase + + class PooledCSqliteExtDatabase(_PooledSqliteDatabase, CSqliteExtDatabase): + pass +except ImportError: + PooledCSqliteExtDatabase = None diff --git a/python3.10libs/playhouse/postgres_ext.py b/python3.10libs/playhouse/postgres_ext.py new file mode 100644 index 0000000..4894acf --- /dev/null +++ b/python3.10libs/playhouse/postgres_ext.py @@ -0,0 +1,498 @@ +""" +Collection of postgres-specific extensions, currently including: + +* Support for hstore, a key/value type storage +""" +import json +import logging +import uuid + +from peewee import * +from peewee import ColumnBase +from peewee import Expression +from peewee import Node +from peewee import NodeList +from peewee import __deprecated__ +from peewee import __exception_wrapper__ + +try: + from psycopg2cffi import compat + compat.register() +except ImportError: + pass + +try: + from psycopg2.extras import register_hstore +except ImportError: + def register_hstore(c, globally): + pass +try: + from psycopg2.extras import 
Json +except: + Json = None + + +logger = logging.getLogger('peewee') + + +HCONTAINS_DICT = '@>' +HCONTAINS_KEYS = '?&' +HCONTAINS_KEY = '?' +HCONTAINS_ANY_KEY = '?|' +HKEY = '->' +HUPDATE = '||' +ACONTAINS = '@>' +ACONTAINED_BY = '<@' +ACONTAINS_ANY = '&&' +TS_MATCH = '@@' +JSONB_CONTAINS = '@>' +JSONB_CONTAINED_BY = '<@' +JSONB_CONTAINS_KEY = '?' +JSONB_CONTAINS_ANY_KEY = '?|' +JSONB_CONTAINS_ALL_KEYS = '?&' +JSONB_EXISTS = '?' +JSONB_REMOVE = '-' + + +class _LookupNode(ColumnBase): + def __init__(self, node, parts): + self.node = node + self.parts = parts + super(_LookupNode, self).__init__() + + def clone(self): + return type(self)(self.node, list(self.parts)) + + def __hash__(self): + return hash((self.__class__.__name__, id(self))) + + +class _JsonLookupBase(_LookupNode): + def __init__(self, node, parts, as_json=False): + super(_JsonLookupBase, self).__init__(node, parts) + self._as_json = as_json + + def clone(self): + return type(self)(self.node, list(self.parts), self._as_json) + + @Node.copy + def as_json(self, as_json=True): + self._as_json = as_json + + def concat(self, rhs): + if not isinstance(rhs, Node): + rhs = Json(rhs) + return Expression(self.as_json(True), OP.CONCAT, rhs) + + def contains(self, other): + clone = self.as_json(True) + if isinstance(other, (list, dict)): + return Expression(clone, JSONB_CONTAINS, Json(other)) + return Expression(clone, JSONB_EXISTS, other) + + def contains_any(self, *keys): + return Expression( + self.as_json(True), + JSONB_CONTAINS_ANY_KEY, + Value(list(keys), unpack=False)) + + def contains_all(self, *keys): + return Expression( + self.as_json(True), + JSONB_CONTAINS_ALL_KEYS, + Value(list(keys), unpack=False)) + + def has_key(self, key): + return Expression(self.as_json(True), JSONB_CONTAINS_KEY, key) + + +class JsonLookup(_JsonLookupBase): + def __getitem__(self, value): + return JsonLookup(self.node, self.parts + [value], self._as_json) + + def __sql__(self, ctx): + ctx.sql(self.node) + for part in 
self.parts[:-1]:
            ctx.literal('->').sql(part)
        if self.parts:
            # Final hop: '->' keeps the result as json, '->>' extracts text.
            (ctx
             .literal('->' if self._as_json else '->>')
             .sql(self.parts[-1]))

        return ctx


class JsonPath(_JsonLookupBase):
    """Lookup of a json value by a full path, e.g. column #> '{a,b,0}'."""
    def __sql__(self, ctx):
        # '#>' keeps the result as json, '#>>' extracts it as text.
        return (ctx
                .sql(self.node)
                .literal('#>' if self._as_json else '#>>')
                .sql(Value('{%s}' % ','.join(map(str, self.parts)))))


class ObjectSlice(_LookupNode):
    """Postgres array index/slice expression, e.g. column[1:3]."""
    @classmethod
    def create(cls, node, value):
        # Normalize the accepted index forms into a parts list (or a Node).
        if isinstance(value, slice):
            parts = [value.start or 0, value.stop or 0]
        elif isinstance(value, int):
            parts = [value]
        elif isinstance(value, Node):
            parts = value
        else:
            # Assumes colon-separated integer indexes.
            parts = [int(i) for i in value.split(':')]
        return cls(node, parts)

    def __sql__(self, ctx):
        ctx.sql(self.node)
        if isinstance(self.parts, Node):
            ctx.literal('[').sql(self.parts).literal(']')
        else:
            # Convert 0-based Python indexes to Postgres' 1-based arrays.
            ctx.literal('[%s]' % ':'.join(str(p + 1) for p in self.parts))
        return ctx

    def __getitem__(self, value):
        # Allow chained slicing, e.g. column[1][2].
        return ObjectSlice.create(self, value)


class IndexedFieldMixin(object):
    # GIN suits the containment operators used by hstore/array/jsonb fields.
    default_index_type = 'GIN'

    def __init__(self, *args, **kwargs):
        kwargs.setdefault('index', True)  # By default, use an index.
+ super(IndexedFieldMixin, self).__init__(*args, **kwargs) + + +class ArrayField(IndexedFieldMixin, Field): + passthrough = True + + def __init__(self, field_class=IntegerField, field_kwargs=None, + dimensions=1, convert_values=False, *args, **kwargs): + self.__field = field_class(**(field_kwargs or {})) + self.dimensions = dimensions + self.convert_values = convert_values + self.field_type = self.__field.field_type + super(ArrayField, self).__init__(*args, **kwargs) + + def bind(self, model, name, set_attribute=True): + ret = super(ArrayField, self).bind(model, name, set_attribute) + self.__field.bind(model, '__array_%s' % name, False) + return ret + + def ddl_datatype(self, ctx): + data_type = self.__field.ddl_datatype(ctx) + return NodeList((data_type, SQL('[]' * self.dimensions)), glue='') + + def db_value(self, value): + if value is None or isinstance(value, Node): + return value + elif self.convert_values: + return self._process(self.__field.db_value, value, self.dimensions) + else: + return value if isinstance(value, list) else list(value) + + def python_value(self, value): + if self.convert_values and value is not None: + conv = self.__field.python_value + if isinstance(value, list): + return self._process(conv, value, self.dimensions) + else: + return conv(value) + else: + return value + + def _process(self, conv, value, dimensions): + dimensions -= 1 + if dimensions == 0: + return [conv(v) for v in value] + else: + return [self._process(conv, v, dimensions) for v in value] + + def __getitem__(self, value): + return ObjectSlice.create(self, value) + + def _e(op): + def inner(self, rhs): + return Expression(self, op, ArrayValue(self, rhs)) + return inner + __eq__ = _e(OP.EQ) + __ne__ = _e(OP.NE) + __gt__ = _e(OP.GT) + __ge__ = _e(OP.GTE) + __lt__ = _e(OP.LT) + __le__ = _e(OP.LTE) + __hash__ = Field.__hash__ + + def contains(self, *items): + return Expression(self, ACONTAINS, ArrayValue(self, items)) + + def contains_any(self, *items): + return 
Expression(self, ACONTAINS_ANY, ArrayValue(self, items)) + + def contained_by(self, *items): + return Expression(self, ACONTAINED_BY, ArrayValue(self, items)) + + +class ArrayValue(Node): + def __init__(self, field, value): + self.field = field + self.value = value + + def __sql__(self, ctx): + return (ctx + .sql(Value(self.value, unpack=False)) + .literal('::') + .sql(self.field.ddl_datatype(ctx))) + + +class DateTimeTZField(DateTimeField): + field_type = 'TIMESTAMPTZ' + + +class HStoreField(IndexedFieldMixin, Field): + field_type = 'HSTORE' + __hash__ = Field.__hash__ + + def __getitem__(self, key): + return Expression(self, HKEY, Value(key)) + + def keys(self): + return fn.akeys(self) + + def values(self): + return fn.avals(self) + + def items(self): + return fn.hstore_to_matrix(self) + + def slice(self, *args): + return fn.slice(self, Value(list(args), unpack=False)) + + def exists(self, key): + return fn.exist(self, key) + + def defined(self, key): + return fn.defined(self, key) + + def update(self, **data): + return Expression(self, HUPDATE, data) + + def delete(self, *keys): + return fn.delete(self, Value(list(keys), unpack=False)) + + def contains(self, value): + if isinstance(value, dict): + rhs = Value(value, unpack=False) + return Expression(self, HCONTAINS_DICT, rhs) + elif isinstance(value, (list, tuple)): + rhs = Value(value, unpack=False) + return Expression(self, HCONTAINS_KEYS, rhs) + return Expression(self, HCONTAINS_KEY, value) + + def contains_any(self, *keys): + return Expression(self, HCONTAINS_ANY_KEY, Value(list(keys), + unpack=False)) + + +class JSONField(Field): + field_type = 'JSON' + _json_datatype = 'json' + + def __init__(self, dumps=None, *args, **kwargs): + if Json is None: + raise Exception('Your version of psycopg2 does not support JSON.') + self.dumps = dumps or json.dumps + super(JSONField, self).__init__(*args, **kwargs) + + def db_value(self, value): + if value is None: + return value + if not isinstance(value, Json): + return 
Cast(self.dumps(value), self._json_datatype) + return value + + def __getitem__(self, value): + return JsonLookup(self, [value]) + + def path(self, *keys): + return JsonPath(self, keys) + + def concat(self, value): + if not isinstance(value, Node): + value = Json(value) + return super(JSONField, self).concat(value) + + +def cast_jsonb(node): + return NodeList((node, SQL('::jsonb')), glue='') + + +class BinaryJSONField(IndexedFieldMixin, JSONField): + field_type = 'JSONB' + _json_datatype = 'jsonb' + __hash__ = Field.__hash__ + + def contains(self, other): + if isinstance(other, (list, dict)): + return Expression(self, JSONB_CONTAINS, Json(other)) + elif isinstance(other, JSONField): + return Expression(self, JSONB_CONTAINS, other) + return Expression(cast_jsonb(self), JSONB_EXISTS, other) + + def contained_by(self, other): + return Expression(cast_jsonb(self), JSONB_CONTAINED_BY, Json(other)) + + def contains_any(self, *items): + return Expression( + cast_jsonb(self), + JSONB_CONTAINS_ANY_KEY, + Value(list(items), unpack=False)) + + def contains_all(self, *items): + return Expression( + cast_jsonb(self), + JSONB_CONTAINS_ALL_KEYS, + Value(list(items), unpack=False)) + + def has_key(self, key): + return Expression(cast_jsonb(self), JSONB_CONTAINS_KEY, key) + + def remove(self, *items): + return Expression( + cast_jsonb(self), + JSONB_REMOVE, + Value(list(items), unpack=False)) + + +class TSVectorField(IndexedFieldMixin, TextField): + field_type = 'TSVECTOR' + __hash__ = Field.__hash__ + + def match(self, query, language=None, plain=False): + params = (language, query) if language is not None else (query,) + func = fn.plainto_tsquery if plain else fn.to_tsquery + return Expression(self, TS_MATCH, func(*params)) + + +def Match(field, query, language=None): + params = (language, query) if language is not None else (query,) + field_params = (language, field) if language is not None else (field,) + return Expression( + fn.to_tsvector(*field_params), + TS_MATCH, + 
fn.to_tsquery(*params)) + + +class IntervalField(Field): + field_type = 'INTERVAL' + + +class FetchManyCursor(object): + __slots__ = ('cursor', 'array_size', 'exhausted', 'iterable') + + def __init__(self, cursor, array_size=None): + self.cursor = cursor + self.array_size = array_size or cursor.itersize + self.exhausted = False + self.iterable = self.row_gen() + + @property + def description(self): + return self.cursor.description + + def close(self): + self.cursor.close() + + def row_gen(self): + while True: + rows = self.cursor.fetchmany(self.array_size) + if not rows: + return + for row in rows: + yield row + + def fetchone(self): + if self.exhausted: + return + try: + return next(self.iterable) + except StopIteration: + self.exhausted = True + + +class ServerSideQuery(Node): + def __init__(self, query, array_size=None): + self.query = query + self.array_size = array_size + self._cursor_wrapper = None + + def __sql__(self, ctx): + return self.query.__sql__(ctx) + + def __iter__(self): + if self._cursor_wrapper is None: + self._execute(self.query._database) + return iter(self._cursor_wrapper.iterator()) + + def _execute(self, database): + if self._cursor_wrapper is None: + cursor = database.execute(self.query, named_cursor=True, + array_size=self.array_size) + self._cursor_wrapper = self.query._get_cursor_wrapper(cursor) + return self._cursor_wrapper + + +def ServerSide(query, database=None, array_size=None): + if database is None: + database = query._database + with database.transaction(): + server_side_query = ServerSideQuery(query, array_size=array_size) + for row in server_side_query: + yield row + + +class _empty_object(object): + __slots__ = () + def __nonzero__(self): + return False + __bool__ = __nonzero__ + + +class PostgresqlExtDatabase(PostgresqlDatabase): + def __init__(self, *args, **kwargs): + self._register_hstore = kwargs.pop('register_hstore', False) + self._server_side_cursors = kwargs.pop('server_side_cursors', False) + 
super(PostgresqlExtDatabase, self).__init__(*args, **kwargs) + + def _connect(self): + conn = super(PostgresqlExtDatabase, self)._connect() + if self._register_hstore: + register_hstore(conn, globally=True) + return conn + + def cursor(self, commit=None, named_cursor=None): + if commit is not None: + __deprecated__('"commit" has been deprecated and is a no-op.') + if self.is_closed(): + if self.autoconnect: + self.connect() + else: + raise InterfaceError('Error, database connection not opened.') + if named_cursor: + curs = self._state.conn.cursor(name=str(uuid.uuid1())) + return curs + return self._state.conn.cursor() + + def execute(self, query, commit=None, named_cursor=False, array_size=None, + **context_options): + if commit is not None: + __deprecated__('"commit" has been deprecated and is a no-op.') + ctx = self.get_sql_context(**context_options) + sql, params = ctx.sql(query).query() + named_cursor = named_cursor or (self._server_side_cursors and + sql[:6].lower() == 'select') + cursor = self.execute_sql(sql, params) + if named_cursor: + cursor = FetchManyCursor(cursor, array_size) + return cursor diff --git a/python3.10libs/playhouse/psycopg3_ext.py b/python3.10libs/playhouse/psycopg3_ext.py new file mode 100644 index 0000000..4eb46b6 --- /dev/null +++ b/python3.10libs/playhouse/psycopg3_ext.py @@ -0,0 +1,36 @@ +from peewee import * + +try: + import psycopg + #from psycopg.types.json import Jsonb +except ImportError: + psycopg = None + + +class Psycopg3Database(PostgresqlDatabase): + def _connect(self): + if psycopg is None: + raise ImproperlyConfigured('psycopg3 is not installed!') + conn = psycopg.connect(dbname=self.database, **self.connect_params) + if self._isolation_level is not None: + conn.isolation_level = self._isolation_level + conn.autocommit = True + return conn + + def get_binary_type(self): + return psycopg.Binary + + def _set_server_version(self, conn): + self.server_version = conn.pgconn.server_version + if self.server_version >= 90600: + 
self.safe_create_index = True + + def is_connection_usable(self): + if self._state.closed: + return False + + # Returns True if we are idle, running a command, or in an active + # connection. If the connection is in an error state or the connection + # is otherwise unusable, return False. + conn = self._state.conn + return conn.pgconn.transaction_status < conn.TransactionStatus.INERROR diff --git a/python3.10libs/playhouse/reflection.py b/python3.10libs/playhouse/reflection.py new file mode 100644 index 0000000..780eabe --- /dev/null +++ b/python3.10libs/playhouse/reflection.py @@ -0,0 +1,858 @@ +try: + from collections import OrderedDict +except ImportError: + OrderedDict = dict +from collections import namedtuple +from inspect import isclass +import re +import warnings + +from peewee import * +from peewee import _StringField +from peewee import _query_val_transform +from peewee import CommaNodeList +from peewee import SCOPE_VALUES +from peewee import make_snake_case +from peewee import text_type +try: + from pymysql.constants import FIELD_TYPE +except ImportError: + try: + from MySQLdb.constants import FIELD_TYPE + except ImportError: + FIELD_TYPE = None +try: + from playhouse import postgres_ext +except ImportError: + postgres_ext = None +try: + from playhouse.cockroachdb import CockroachDatabase +except ImportError: + CockroachDatabase = None + +RESERVED_WORDS = set([ + 'and', 'as', 'assert', 'break', 'class', 'continue', 'def', 'del', 'elif', + 'else', 'except', 'exec', 'finally', 'for', 'from', 'global', 'if', + 'import', 'in', 'is', 'lambda', 'not', 'or', 'pass', 'print', 'raise', + 'return', 'try', 'while', 'with', 'yield', +]) + + +class UnknownField(object): + pass + + +class Column(object): + """ + Store metadata about a database column. 
+ """ + primary_key_types = (IntegerField, AutoField) + + def __init__(self, name, field_class, raw_column_type, nullable, + primary_key=False, column_name=None, index=False, + unique=False, default=None, extra_parameters=None): + self.name = name + self.field_class = field_class + self.raw_column_type = raw_column_type + self.nullable = nullable + self.primary_key = primary_key + self.column_name = column_name + self.index = index + self.unique = unique + self.default = default + self.extra_parameters = extra_parameters + + # Foreign key metadata. + self.rel_model = None + self.related_name = None + self.to_field = None + + def __repr__(self): + attrs = [ + 'field_class', + 'raw_column_type', + 'nullable', + 'primary_key', + 'column_name'] + keyword_args = ', '.join( + '%s=%s' % (attr, getattr(self, attr)) + for attr in attrs) + return 'Column(%s, %s)' % (self.name, keyword_args) + + def get_field_parameters(self): + params = {} + if self.extra_parameters is not None: + params.update(self.extra_parameters) + + # Set up default attributes. + if self.nullable: + params['null'] = True + if self.field_class is ForeignKeyField or self.name != self.column_name: + params['column_name'] = "'%s'" % self.column_name + if self.primary_key and not issubclass(self.field_class, AutoField): + params['primary_key'] = True + if self.default is not None: + params['constraints'] = '[SQL("DEFAULT %s")]' % self.default + + # Handle ForeignKeyField-specific attributes. + if self.is_foreign_key(): + params['model'] = self.rel_model + if self.to_field: + params['field'] = "'%s'" % self.to_field + if self.related_name: + params['backref'] = "'%s'" % self.related_name + + # Handle indexes on column. 
+ if not self.is_primary_key(): + if self.unique: + params['unique'] = 'True' + elif self.index and not self.is_foreign_key(): + params['index'] = 'True' + + return params + + def is_primary_key(self): + return self.field_class is AutoField or self.primary_key + + def is_foreign_key(self): + return self.field_class is ForeignKeyField + + def is_self_referential_fk(self): + return (self.field_class is ForeignKeyField and + self.rel_model == "'self'") + + def set_foreign_key(self, foreign_key, model_names, dest=None, + related_name=None): + self.foreign_key = foreign_key + self.field_class = ForeignKeyField + if foreign_key.dest_table == foreign_key.table: + self.rel_model = "'self'" + else: + self.rel_model = model_names[foreign_key.dest_table] + self.to_field = dest and dest.name or None + self.related_name = related_name or None + + def get_field(self): + # Generate the field definition for this column. + field_params = {} + for key, value in self.get_field_parameters().items(): + if isclass(value) and issubclass(value, Field): + value = value.__name__ + field_params[key] = value + + param_str = ', '.join('%s=%s' % (k, v) + for k, v in sorted(field_params.items())) + field = '%s = %s(%s)' % ( + self.name, + self.field_class.__name__, + param_str) + + if self.field_class is UnknownField: + field = '%s # %s' % (field, self.raw_column_type) + + return field + + +class Metadata(object): + column_map = {} + extension_import = '' + + def __init__(self, database): + self.database = database + self.requires_extension = False + + def execute(self, sql, *params): + return self.database.execute_sql(sql, params) + + def get_columns(self, table, schema=None): + metadata = OrderedDict( + (metadata.name, metadata) + for metadata in self.database.get_columns(table, schema)) + + # Look up the actual column type for each column. + column_types, extra_params = self.get_column_types(table, schema) + + # Look up the primary keys. 
+ pk_names = self.get_primary_keys(table, schema) + if len(pk_names) == 1: + pk = pk_names[0] + if column_types[pk] is IntegerField: + column_types[pk] = AutoField + elif column_types[pk] is BigIntegerField: + column_types[pk] = BigAutoField + + columns = OrderedDict() + for name, column_data in metadata.items(): + field_class = column_types[name] + default = self._clean_default(field_class, column_data.default) + + columns[name] = Column( + name, + field_class=field_class, + raw_column_type=column_data.data_type, + nullable=column_data.null, + primary_key=column_data.primary_key, + column_name=name, + default=default, + extra_parameters=extra_params.get(name)) + + return columns + + def get_column_types(self, table, schema=None): + raise NotImplementedError + + def _clean_default(self, field_class, default): + if default is None or field_class in (AutoField, BigAutoField) or \ + default.lower() == 'null': + return + if issubclass(field_class, _StringField) and \ + isinstance(default, text_type) and not default.startswith("'"): + default = "'%s'" % default + return default or "''" + + def get_foreign_keys(self, table, schema=None): + return self.database.get_foreign_keys(table, schema) + + def get_primary_keys(self, table, schema=None): + return self.database.get_primary_keys(table, schema) + + def get_indexes(self, table, schema=None): + return self.database.get_indexes(table, schema) + + +class PostgresqlMetadata(Metadata): + column_map = { + 16: BooleanField, + 17: BlobField, + 20: BigIntegerField, + 21: SmallIntegerField, + 23: IntegerField, + 25: TextField, + 700: FloatField, + 701: DoubleField, + 1042: CharField, # blank-padded CHAR + 1043: CharField, + 1082: DateField, + 1114: DateTimeField, + 1184: DateTimeField, + 1083: TimeField, + 1266: TimeField, + 1700: DecimalField, + 2950: UUIDField, # UUID + } + array_types = { + 1000: BooleanField, + 1001: BlobField, + 1005: SmallIntegerField, + 1007: IntegerField, + 1009: TextField, + 1014: CharField, + 1015: 
CharField, + 1016: BigIntegerField, + 1115: DateTimeField, + 1182: DateField, + 1183: TimeField, + 2951: UUIDField, + } + extension_import = 'from playhouse.postgres_ext import *' + + def __init__(self, database): + super(PostgresqlMetadata, self).__init__(database) + + if postgres_ext is not None: + # Attempt to add types like HStore and JSON. + cursor = self.execute('select oid, typname, format_type(oid, NULL)' + ' from pg_type;') + results = cursor.fetchall() + + for oid, typname, formatted_type in results: + if typname == 'json': + self.column_map[oid] = postgres_ext.JSONField + elif typname == 'jsonb': + self.column_map[oid] = postgres_ext.BinaryJSONField + elif typname == 'hstore': + self.column_map[oid] = postgres_ext.HStoreField + elif typname == 'tsvector': + self.column_map[oid] = postgres_ext.TSVectorField + + for oid in self.array_types: + self.column_map[oid] = postgres_ext.ArrayField + + def get_column_types(self, table, schema): + column_types = {} + extra_params = {} + extension_types = set(( + postgres_ext.ArrayField, + postgres_ext.BinaryJSONField, + postgres_ext.JSONField, + postgres_ext.TSVectorField, + postgres_ext.HStoreField)) if postgres_ext is not None else set() + + # Look up the actual column type for each column. + identifier = '%s."%s"' % (schema, table) + cursor = self.execute( + 'SELECT attname, atttypid FROM pg_catalog.pg_attribute ' + 'WHERE attrelid = %s::regclass AND attnum > %s', identifier, 0) + + # Store column metadata in dictionary keyed by column name. 
+ for name, oid in cursor.fetchall(): + column_types[name] = self.column_map.get(oid, UnknownField) + if column_types[name] in extension_types: + self.requires_extension = True + if oid in self.array_types: + extra_params[name] = {'field_class': self.array_types[oid]} + + return column_types, extra_params + + def get_columns(self, table, schema=None): + schema = schema or 'public' + return super(PostgresqlMetadata, self).get_columns(table, schema) + + def get_foreign_keys(self, table, schema=None): + schema = schema or 'public' + return super(PostgresqlMetadata, self).get_foreign_keys(table, schema) + + def get_primary_keys(self, table, schema=None): + schema = schema or 'public' + return super(PostgresqlMetadata, self).get_primary_keys(table, schema) + + def get_indexes(self, table, schema=None): + schema = schema or 'public' + return super(PostgresqlMetadata, self).get_indexes(table, schema) + + +class CockroachDBMetadata(PostgresqlMetadata): + # CRDB treats INT the same as BIGINT, so we just map bigint type OIDs to + # regular IntegerField. + column_map = PostgresqlMetadata.column_map.copy() + column_map[20] = IntegerField + array_types = PostgresqlMetadata.array_types.copy() + array_types[1016] = IntegerField + extension_import = 'from playhouse.cockroachdb import *' + + def __init__(self, database): + Metadata.__init__(self, database) + self.requires_extension = True + + if postgres_ext is not None: + # Attempt to add JSON types. 
+ cursor = self.execute('select oid, typname, format_type(oid, NULL)' + ' from pg_type;') + results = cursor.fetchall() + + for oid, typname, formatted_type in results: + if typname == 'jsonb': + self.column_map[oid] = postgres_ext.BinaryJSONField + + for oid in self.array_types: + self.column_map[oid] = postgres_ext.ArrayField + + +class MySQLMetadata(Metadata): + if FIELD_TYPE is None: + column_map = {} + else: + column_map = { + FIELD_TYPE.BLOB: TextField, + FIELD_TYPE.CHAR: CharField, + FIELD_TYPE.DATE: DateField, + FIELD_TYPE.DATETIME: DateTimeField, + FIELD_TYPE.DECIMAL: DecimalField, + FIELD_TYPE.DOUBLE: FloatField, + FIELD_TYPE.FLOAT: FloatField, + FIELD_TYPE.INT24: IntegerField, + FIELD_TYPE.LONG_BLOB: TextField, + FIELD_TYPE.LONG: IntegerField, + FIELD_TYPE.LONGLONG: BigIntegerField, + FIELD_TYPE.MEDIUM_BLOB: TextField, + FIELD_TYPE.NEWDECIMAL: DecimalField, + FIELD_TYPE.SHORT: IntegerField, + FIELD_TYPE.STRING: CharField, + FIELD_TYPE.TIMESTAMP: DateTimeField, + FIELD_TYPE.TIME: TimeField, + FIELD_TYPE.TINY_BLOB: TextField, + FIELD_TYPE.TINY: IntegerField, + FIELD_TYPE.VAR_STRING: CharField, + } + + def __init__(self, database, **kwargs): + if 'password' in kwargs: + kwargs['passwd'] = kwargs.pop('password') + super(MySQLMetadata, self).__init__(database, **kwargs) + + def get_column_types(self, table, schema=None): + column_types = {} + + # Look up the actual column type for each column. + cursor = self.execute('SELECT * FROM `%s` LIMIT 1' % table) + + # Store column metadata in dictionary keyed by column name. 
+ for column_description in cursor.description: + name, type_code = column_description[:2] + column_types[name] = self.column_map.get(type_code, UnknownField) + + return column_types, {} + + +class SqliteMetadata(Metadata): + column_map = { + 'bigint': BigIntegerField, + 'blob': BlobField, + 'bool': BooleanField, + 'boolean': BooleanField, + 'char': CharField, + 'date': DateField, + 'datetime': DateTimeField, + 'decimal': DecimalField, + 'float': FloatField, + 'integer': IntegerField, + 'integer unsigned': IntegerField, + 'int': IntegerField, + 'long': BigIntegerField, + 'numeric': DecimalField, + 'real': FloatField, + 'smallinteger': IntegerField, + 'smallint': IntegerField, + 'smallint unsigned': IntegerField, + 'text': TextField, + 'time': TimeField, + 'varchar': CharField, + } + + begin = '(?:["\[\(]+)?' + end = '(?:["\]\)]+)?' + re_foreign_key = ( + '(?:FOREIGN KEY\s*)?' + '{begin}(.+?){end}\s+(?:.+\s+)?' + 'references\s+{begin}(.+?){end}' + '\s*\(["|\[]?(.+?)["|\]]?\)').format(begin=begin, end=end) + re_varchar = r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$' + + def _map_col(self, column_type): + raw_column_type = column_type.lower() + if raw_column_type in self.column_map: + field_class = self.column_map[raw_column_type] + elif re.search(self.re_varchar, raw_column_type): + field_class = CharField + else: + column_type = re.sub('\(.+\)', '', raw_column_type) + if column_type == '': + field_class = BareField + else: + field_class = self.column_map.get(column_type, UnknownField) + return field_class + + def get_column_types(self, table, schema=None): + column_types = {} + columns = self.database.get_columns(table) + + for column in columns: + column_types[column.name] = self._map_col(column.data_type) + + return column_types, {} + + +_DatabaseMetadata = namedtuple('_DatabaseMetadata', ( + 'columns', + 'primary_keys', + 'foreign_keys', + 'model_names', + 'indexes')) + + +class DatabaseMetadata(_DatabaseMetadata): + def multi_column_indexes(self, table): + accum = 
[] + for index in self.indexes[table]: + if len(index.columns) > 1: + field_names = [self.columns[table][column].name + for column in index.columns + if column in self.columns[table]] + accum.append((field_names, index.unique)) + return accum + + def column_indexes(self, table): + accum = {} + for index in self.indexes[table]: + if len(index.columns) == 1: + accum[index.columns[0]] = index.unique + return accum + + +class Introspector(object): + pk_classes = [AutoField, IntegerField] + + def __init__(self, metadata, schema=None): + self.metadata = metadata + self.schema = schema + + def __repr__(self): + return '' % self.metadata.database + + @classmethod + def from_database(cls, database, schema=None): + if isinstance(database, Proxy): + if database.obj is None: + raise ValueError('Cannot introspect an uninitialized Proxy.') + database = database.obj # Reference the proxied db obj. + if CockroachDatabase and isinstance(database, CockroachDatabase): + metadata = CockroachDBMetadata(database) + elif isinstance(database, PostgresqlDatabase): + metadata = PostgresqlMetadata(database) + elif isinstance(database, MySQLDatabase): + metadata = MySQLMetadata(database) + elif isinstance(database, SqliteDatabase): + metadata = SqliteMetadata(database) + else: + raise ValueError('Introspection not supported for %r' % database) + return cls(metadata, schema=schema) + + def get_database_class(self): + return type(self.metadata.database) + + def get_database_name(self): + return self.metadata.database.database + + def get_database_kwargs(self): + return self.metadata.database.connect_params + + def get_additional_imports(self): + if self.metadata.requires_extension: + return '\n' + self.metadata.extension_import + return '' + + def make_model_name(self, table, snake_case=True): + if snake_case: + table = make_snake_case(table) + model = re.sub(r'[^\w]+', '', table) + model_name = ''.join(sub.title() for sub in model.split('_')) + if not model_name[0].isalpha(): + model_name = 
'T' + model_name + return model_name + + def make_column_name(self, column, is_foreign_key=False, snake_case=True): + column = column.strip() + if snake_case: + column = make_snake_case(column) + column = column.lower() + if is_foreign_key: + # Strip "_id" from foreign keys, unless the foreign-key happens to + # be named "_id", in which case the name is retained. + column = re.sub('_id$', '', column) or column + + # Remove characters that are invalid for Python identifiers. + column = re.sub(r'[^\w]+', '_', column) + if column in RESERVED_WORDS: + column += '_' + if len(column) and column[0].isdigit(): + column = '_' + column + return column + + def introspect(self, table_names=None, literal_column_names=False, + include_views=False, snake_case=True): + # Retrieve all the tables in the database. + tables = self.metadata.database.get_tables(schema=self.schema) + if include_views: + views = self.metadata.database.get_views(schema=self.schema) + tables.extend([view.name for view in views]) + + if table_names is not None: + tables = [table for table in tables if table in table_names] + table_set = set(tables) + + # Store a mapping of table name -> dictionary of columns. + columns = {} + + # Store a mapping of table name -> set of primary key columns. + primary_keys = {} + + # Store a mapping of table -> foreign keys. + foreign_keys = {} + + # Store a mapping of table name -> model name. + model_names = {} + + # Store a mapping of table name -> indexes. + indexes = {} + + # Gather the columns for each table. + for table in tables: + table_indexes = self.metadata.get_indexes(table, self.schema) + table_columns = self.metadata.get_columns(table, self.schema) + try: + foreign_keys[table] = self.metadata.get_foreign_keys( + table, self.schema) + except ValueError as exc: + err(*exc.args) + foreign_keys[table] = [] + else: + # If there is a possibility we could exclude a dependent table, + # ensure that we introspect it so FKs will work. 
+ if table_names is not None: + for foreign_key in foreign_keys[table]: + if foreign_key.dest_table not in table_set: + tables.append(foreign_key.dest_table) + table_set.add(foreign_key.dest_table) + + model_names[table] = self.make_model_name(table, snake_case) + + # Collect sets of all the column names as well as all the + # foreign-key column names. + lower_col_names = set(column_name.lower() + for column_name in table_columns) + fks = set(fk_col.column for fk_col in foreign_keys[table]) + + for col_name, column in table_columns.items(): + if literal_column_names: + new_name = re.sub(r'[^\w]+', '_', col_name) + else: + new_name = self.make_column_name(col_name, col_name in fks, + snake_case) + + # If we have two columns, "parent" and "parent_id", ensure + # that when we don't introduce naming conflicts. + lower_name = col_name.lower() + if lower_name.endswith('_id') and new_name in lower_col_names: + new_name = col_name.lower() + + column.name = new_name + + for index in table_indexes: + if len(index.columns) == 1: + column = index.columns[0] + if column in table_columns: + table_columns[column].unique = index.unique + table_columns[column].index = True + + primary_keys[table] = self.metadata.get_primary_keys( + table, self.schema) + columns[table] = table_columns + indexes[table] = table_indexes + + # Gather all instances where we might have a `related_name` conflict, + # either due to multiple FKs on a table pointing to the same table, + # or a related_name that would conflict with an existing field. 
+ related_names = {} + sort_fn = lambda foreign_key: foreign_key.column + for table in tables: + models_referenced = set() + for foreign_key in sorted(foreign_keys[table], key=sort_fn): + try: + column = columns[table][foreign_key.column] + except KeyError: + continue + + dest_table = foreign_key.dest_table + if dest_table in models_referenced: + related_names[column] = '%s_%s_set' % ( + dest_table, + column.name) + else: + models_referenced.add(dest_table) + + # On the second pass convert all foreign keys. + for table in tables: + for foreign_key in foreign_keys[table]: + src = columns[foreign_key.table][foreign_key.column] + try: + dest = columns[foreign_key.dest_table][ + foreign_key.dest_column] + except KeyError: + dest = None + + src.set_foreign_key( + foreign_key=foreign_key, + model_names=model_names, + dest=dest, + related_name=related_names.get(src)) + + return DatabaseMetadata( + columns, + primary_keys, + foreign_keys, + model_names, + indexes) + + def generate_models(self, skip_invalid=False, table_names=None, + literal_column_names=False, bare_fields=False, + include_views=False): + database = self.introspect(table_names, literal_column_names, + include_views) + models = {} + + class BaseModel(Model): + class Meta: + database = self.metadata.database + schema = self.schema + + pending = set() + + def _create_model(table, models): + pending.add(table) + for foreign_key in database.foreign_keys[table]: + dest = foreign_key.dest_table + + if dest not in models and dest != table: + if dest in pending: + warnings.warn('Possible reference cycle found between ' + '%s and %s' % (table, dest)) + else: + _create_model(dest, models) + + primary_keys = [] + columns = database.columns[table] + for column_name, column in columns.items(): + if column.primary_key: + primary_keys.append(column.name) + + multi_column_indexes = database.multi_column_indexes(table) + column_indexes = database.column_indexes(table) + + class Meta: + indexes = multi_column_indexes + 
table_name = table + + # Fix models with multi-column primary keys. + composite_key = False + if len(primary_keys) == 0: + if 'id' not in columns: + Meta.primary_key = False + else: + primary_keys = columns.keys() + + if len(primary_keys) > 1: + Meta.primary_key = CompositeKey(*[ + field.name for col, field in columns.items() + if col in primary_keys]) + composite_key = True + + attrs = {'Meta': Meta} + for column_name, column in columns.items(): + FieldClass = column.field_class + if FieldClass is not ForeignKeyField and bare_fields: + FieldClass = BareField + elif FieldClass is UnknownField: + FieldClass = BareField + + params = { + 'column_name': column_name, + 'null': column.nullable} + if column.primary_key and composite_key: + if FieldClass is AutoField: + FieldClass = IntegerField + params['primary_key'] = False + elif column.primary_key and FieldClass is not AutoField: + params['primary_key'] = True + if column.is_foreign_key(): + if column.is_self_referential_fk(): + params['model'] = 'self' + else: + dest_table = column.foreign_key.dest_table + if dest_table in models: + params['model'] = models[dest_table] + else: + FieldClass = DeferredForeignKey + params['rel_model_name'] = dest_table + if column.to_field: + params['field'] = column.to_field + + # Generate a unique related name. + params['backref'] = '%s_%s_rel' % (table, column_name) + + if column.default is not None: + constraint = SQL('DEFAULT %s' % column.default) + params['constraints'] = [constraint] + + if not column.is_primary_key(): + if column_name in column_indexes: + if column_indexes[column_name]: + params['unique'] = True + elif not column.is_foreign_key(): + params['index'] = True + else: + params['index'] = False + + attrs[column.name] = FieldClass(**params) + + try: + models[table] = type(str(table), (BaseModel,), attrs) + except ValueError: + if not skip_invalid: + raise + finally: + if table in pending: + pending.remove(table) + + # Actually generate Model classes. 
+ for table, model in sorted(database.model_names.items()): + if table not in models: + _create_model(table, models) + + return models + + +def introspect(database, schema=None): + introspector = Introspector.from_database(database, schema=schema) + return introspector.introspect() + + +def generate_models(database, schema=None, **options): + introspector = Introspector.from_database(database, schema=schema) + return introspector.generate_models(**options) + + +def print_model(model, indexes=True, inline_indexes=False): + print(model._meta.name) + for field in model._meta.sorted_fields: + parts = [' %s %s' % (field.name, field.field_type)] + if field.primary_key: + parts.append(' PK') + elif inline_indexes: + if field.unique: + parts.append(' UNIQUE') + elif field.index: + parts.append(' INDEX') + if isinstance(field, ForeignKeyField): + parts.append(' FK: %s.%s' % (field.rel_model.__name__, + field.rel_field.name)) + print(''.join(parts)) + + if indexes: + index_list = model._meta.fields_to_index() + if not index_list: + return + + print('\nindex(es)') + for index in index_list: + parts = [' '] + ctx = model._meta.database.get_sql_context() + with ctx.scope_values(param='%s', quote='""'): + ctx.sql(CommaNodeList(index._expressions)) + if index._where: + ctx.literal(' WHERE ') + ctx.sql(index._where) + sql, params = ctx.query() + + clean = sql % tuple(map(_query_val_transform, params)) + parts.append(clean.replace('"', '')) + + if index._unique: + parts.append(' UNIQUE') + print(''.join(parts)) + + +def get_table_sql(model): + sql, params = model._schema._create_table().query() + if model._meta.database.param != '%s': + sql = sql.replace(model._meta.database.param, '%s') + + # Format and indent the table declaration, simplest possible approach. 
def model_to_dict(model, recurse=True, backrefs=False, only=None,
                  exclude=None, seen=None, extra_attrs=None,
                  fields_from_query=None, max_depth=None, manytomany=False):
    """
    Convert a model instance (and any related objects) to a dictionary.

    :param bool recurse: Whether foreign-keys should be recursed.
    :param bool backrefs: Whether lists of related objects should be recursed.
    :param only: A list (or set) of field instances indicating which fields
        should be included.
    :param exclude: A list (or set) of field instances that should be
        excluded from the dictionary.
    :param list extra_attrs: Names of model instance attributes or methods
        that should be included.
    :param SelectQuery fields_from_query: Query that was source of model. Take
        fields explicitly selected by the query and serialize them.
    :param int max_depth: Maximum depth to recurse, value <= 0 means no max.
    :param bool manytomany: Process many-to-many fields.
    :returns: dict mapping field names to serialized values.
    """
    # max_depth=None means "no limit" (-1 never hits 0 when decremented);
    # a depth of exactly 0 disables any recursion on this call.
    max_depth = -1 if max_depth is None else max_depth
    if max_depth == 0:
        recurse = False

    only = _clone_set(only)
    extra_attrs = _clone_set(extra_attrs)
    # NOTE: this closure reads the *names* "exclude"/"only" at call time, and
    # "exclude" is re-bound to a fresh set below before the first call, so a
    # None "exclude" argument is safe here.
    should_skip = lambda n: (n in exclude) or (only and (n not in only))

    if fields_from_query is not None:
        # Serialize exactly what the query selected: real fields go into
        # "only", aliased expressions are picked up via attribute access.
        for item in fields_from_query._returning:
            if isinstance(item, Field):
                only.add(item)
            elif isinstance(item, Alias):
                extra_attrs.add(item._alias)

    data = {}
    exclude = _clone_set(exclude)
    seen = _clone_set(seen)
    # Anything already visited is treated as excluded to break cycles.
    exclude |= seen
    model_class = type(model)

    if manytomany:
        for name, m2m in model._meta.manytomany.items():
            if should_skip(name):
                continue

            # Prevent the related side (and the through-model FKs) from
            # recursing back into this instance.
            exclude.update((m2m, m2m.rel_model._meta.manytomany[m2m.backref]))
            for fkf in m2m.through_model._meta.refs:
                exclude.add(fkf)

            accum = []
            for rel_obj in getattr(model, name):
                accum.append(model_to_dict(
                    rel_obj,
                    recurse=recurse,
                    backrefs=backrefs,
                    only=only,
                    exclude=exclude,
                    max_depth=max_depth - 1))
            data[name] = accum

    for field in model._meta.sorted_fields:
        if should_skip(field):
            continue

        field_data = model.__data__.get(field.name)
        if isinstance(field, ForeignKeyField) and recurse:
            if field_data is not None:
                # Mark the FK as seen before descending so the related
                # object cannot serialize us back (reference cycle).
                seen.add(field)
                rel_obj = getattr(model, field.name)
                field_data = model_to_dict(
                    rel_obj,
                    recurse=recurse,
                    backrefs=backrefs,
                    only=only,
                    exclude=exclude,
                    seen=seen,
                    max_depth=max_depth - 1)
            else:
                field_data = None

        data[field.name] = field_data

    if extra_attrs:
        for attr_name in extra_attrs:
            attr = getattr(model, attr_name)
            if callable_(attr):
                data[attr_name] = attr()
            else:
                data[attr_name] = attr

    if backrefs and recurse:
        for foreign_key, rel_model in model._meta.backrefs.items():
            # "+" is peewee's convention for a disabled backref.
            if foreign_key.backref == '+': continue
            descriptor = getattr(model_class, foreign_key.backref)
            if descriptor in exclude or foreign_key in exclude:
                continue
            if only and (descriptor not in only) and (foreign_key not in only):
                continue

            accum = []
            # Exclude the FK on the child side so children do not serialize
            # this parent again.
            exclude.add(foreign_key)
            related_query = getattr(model, foreign_key.backref)

            for rel_obj in related_query:
                accum.append(model_to_dict(
                    rel_obj,
                    recurse=recurse,
                    backrefs=backrefs,
                    only=only,
                    exclude=exclude,
                    max_depth=max_depth - 1))

            data[foreign_key.backref] = accum

    return data
class ReconnectMixin(object):
    """
    Mixin that transparently re-establishes a dropped database connection.

    MySQL servers typically close connections that have been idle longer
    than "wait_timeout" (28800 seconds by default), so long-lived
    connections may find themselves closed after a period of inactivity.
    When one of the recognized errors occurs outside of a transaction, this
    mixin closes the stale connection, reconnects, and retries the failed
    operation once.

    This mixin class probably should not be used with Postgres (unless you
    REALLY know what you are doing) and definitely has no business being used
    with Sqlite. If you wish to use with Postgres, you will need to adapt the
    `reconnect_errors` attribute to something appropriate for Postgres.
    """
    reconnect_errors = (
        # Pairs of (exception class, lowercased message fragment). An empty
        # fragment matches every message of that exception class.
        (OperationalError, '2006'),  # MySQL server has gone away.
        (OperationalError, '2013'),  # Lost connection to MySQL server.
        (OperationalError, '2014'),  # Commands out of sync.
        (OperationalError, '4031'),  # Client interaction timeout.

        # mysql-connector raises a slightly different error when an idle
        # connection is terminated by the server. This is equivalent to 2013.
        (OperationalError, 'MySQL Connection not available.'),

        # Postgres error examples:
        #(OperationalError, 'terminat'),
        #(InterfaceError, 'connection already closed'),
    )

    def __init__(self, *args, **kwargs):
        super(ReconnectMixin, self).__init__(*args, **kwargs)

        # Re-shape the declarative table into {exc_class: [fragments]} for
        # cheap lookups on the error path.
        self._reconnect_errors = {}
        for exc_class, fragment in self.reconnect_errors:
            self._reconnect_errors.setdefault(exc_class, [])
            self._reconnect_errors[exc_class].append(fragment.lower())

    def execute_sql(self, sql, params=None, commit=None):
        if commit is not None:
            __deprecated__('"commit" has been deprecated and is a no-op.')
        return self._reconnect(
            super(ReconnectMixin, self).execute_sql, sql, params)

    def begin(self):
        return self._reconnect(super(ReconnectMixin, self).begin)

    def _reconnect(self, func, *args, **kwargs):
        """Invoke `func`, reconnecting and retrying once on a known error."""
        try:
            return func(*args, **kwargs)
        except Exception as exc:
            # Inside a transaction a silent reconnect could lose changes, so
            # propagate instead.
            if self.in_transaction():
                raise exc

            fragments = self._reconnect_errors.get(type(exc))
            if fragments is None:
                raise exc

            message = str(exc).lower()
            if not any(fragment in message for fragment in fragments):
                raise exc

            if not self.is_closed():
                self.close()
            self.connect()

            return func(*args, **kwargs)
class Signal(object):
    """A named event that dispatches to connected receiver callables.

    Receivers are invoked in connection order as ``receiver(sender,
    instance, *args, **kwargs)``, where ``sender`` is the instance's class.
    A receiver may be restricted to instances of a particular model class
    via the ``sender`` argument to :meth:`connect`.
    """
    def __init__(self):
        self._flush()

    def _flush(self):
        # Set of (name, sender) keys for O(1) duplicate checks; parallel
        # list preserves deterministic dispatch order.
        self._receivers = set()
        self._receiver_list = []

    def connect(self, receiver, name=None, sender=None):
        name = name or receiver.__name__
        key = (name, sender)
        if key in self._receivers:
            raise ValueError('receiver named %s (for sender=%s) already '
                             'connected' % (name, sender or 'any'))
        self._receivers.add(key)
        self._receiver_list.append((name, receiver, sender))

    def disconnect(self, receiver=None, name=None, sender=None):
        if receiver:
            name = name or receiver.__name__
        if not name:
            raise ValueError('a receiver or a name must be provided')

        key = (name, sender)
        if key not in self._receivers:
            raise ValueError('receiver named %s for sender=%s not found.' %
                             (name, sender or 'any'))

        self._receivers.discard(key)
        self._receiver_list = [row for row in self._receiver_list
                               if (row[0], row[2]) != key]

    def __call__(self, name=None, sender=None):
        # Decorator form: @signal(name=..., sender=...).
        def decorator(fn):
            self.connect(fn, name, sender)
            return fn
        return decorator

    def send(self, instance, *args, **kwargs):
        """Dispatch to all matching receivers; return [(receiver, result)]."""
        sender = type(instance)
        responses = []
        for name, receiver, target in self._receiver_list:
            if target is not None and not isinstance(instance, target):
                continue
            responses.append(
                (receiver, receiver(sender, instance, *args, **kwargs)))
        return responses
ret = super(Model, self).delete_instance(*args, **kwargs) + post_delete.send(self) + return ret diff --git a/python3.10libs/playhouse/sqlcipher_ext.py b/python3.10libs/playhouse/sqlcipher_ext.py new file mode 100644 index 0000000..66558d0 --- /dev/null +++ b/python3.10libs/playhouse/sqlcipher_ext.py @@ -0,0 +1,106 @@ +""" +Peewee integration with pysqlcipher. + +Project page: https://github.com/leapcode/pysqlcipher/ + +**WARNING!!! EXPERIMENTAL!!!** + +* Although this extention's code is short, it has not been properly + peer-reviewed yet and may have introduced vulnerabilities. + +Also note that this code relies on pysqlcipher and sqlcipher, and +the code there might have vulnerabilities as well, but since these +are widely used crypto modules, we can expect "short zero days" there. + +Example usage: + + from peewee.playground.ciphersql_ext import SqlCipherDatabase + db = SqlCipherDatabase('/path/to/my.db', passphrase="don'tuseme4real") + +* `passphrase`: should be "long enough". + Note that *length beats vocabulary* (much exponential), and even + a lowercase-only passphrase like easytorememberyethardforotherstoguess + packs more noise than 8 random printable characters and *can* be memorized. + +When opening an existing database, passphrase should be the one used when the +database was created. If the passphrase is incorrect, an exception will only be +raised **when you access the database**. + +If you need to ask for an interactive passphrase, here's example code you can +put after the `db = ...` line: + + try: # Just access the database so that it checks the encryption. + db.get_tables() + # We're looking for a DatabaseError with a specific error message. + except peewee.DatabaseError as e: + # Check whether the message *means* "passphrase is wrong" + if e.args[0] == 'file is encrypted or is not a database': + raise Exception('Developer should Prompt user for passphrase ' + 'again.') + else: + # A different DatabaseError. Raise it. 
class _SqlCipherDatabase(object):
    # Mixin adding SQLCipher passphrase handling to a SqliteDatabase
    # subclass (see SqlCipherDatabase / SqlCipherExtDatabase below).
    server_version = __sqlcipher_version__

    def _connect(self):
        """Open a connection and apply the key PRAGMA before use.

        The 'passphrase' entry is popped from connect_params so it is not
        forwarded to the DB-API connect() call.
        """
        params = dict(self.connect_params)
        # Single quotes are doubled so the passphrase survives embedding in
        # the quoted PRAGMA literal below.
        passphrase = params.pop('passphrase', '').replace("'", "''")

        conn = sqlcipher.connect(self.database, isolation_level=None, **params)
        try:
            if passphrase:
                # SQLCipher requires the key to be supplied via PRAGMA before
                # any other statement touches the database.
                conn.execute("PRAGMA key='%s'" % passphrase)
            self._add_conn_hooks(conn)
        except:
            # Bare except on purpose: never leak a half-initialized
            # connection, regardless of what failed; the error is re-raised.
            conn.close()
            raise
        return conn

    def set_passphrase(self, passphrase):
        """Store the passphrase used for subsequent connections.

        :raises ImproperlyConfigured: if the database is currently open.
        """
        if not self.is_closed():
            raise ImproperlyConfigured('Cannot set passphrase when database '
                                       'is open. To change passphrase of an '
                                       'open database use the rekey() method.')

        self.connect_params['passphrase'] = passphrase

    def rekey(self, passphrase):
        """Re-encrypt an open database with a new passphrase."""
        if self.is_closed():
            self.connect()

        # Same quote-doubling as in _connect(), since the passphrase is
        # embedded in the quoted PRAGMA literal.
        self.execute_sql("PRAGMA rekey='%s'" % passphrase.replace("'", "''"))
        self.connect_params['passphrase'] = passphrase
        return True
+ # table: table name + # action: insert / update / delete + # new_old: NEW or OLD (OLD is for DELETE) + # primary_key: table primary key column name + # column_array: output of build_column_array() + # change_table: changelog table name + template = """CREATE TRIGGER IF NOT EXISTS %(table)s_changes_%(action)s + AFTER %(action)s ON %(table)s + BEGIN + INSERT INTO %(change_table)s + ("action", "table", "primary_key", "changes") + SELECT + '%(action)s', '%(table)s', %(new_old)s."%(primary_key)s", "changes" + FROM ( + SELECT json_group_object( + col, + json_array( + case when json_valid("oldval") then json("oldval") + else "oldval" end, + case when json_valid("newval") then json("newval") + else "newval" end) + ) AS "changes" + FROM ( + SELECT json_extract(value, '$[0]') as "col", + json_extract(value, '$[1]') as "oldval", + json_extract(value, '$[2]') as "newval" + FROM json_each(json_array(%(column_array)s)) + WHERE "oldval" IS NOT "newval" + ) + ); + END;""" + + drop_template = 'DROP TRIGGER IF EXISTS %(table)s_changes_%(action)s' + + _actions = ('INSERT', 'UPDATE', 'DELETE') + + def __init__(self, db, table_name='changelog'): + self.db = db + self.table_name = table_name + + def _build_column_array(self, model, use_old, use_new, skip_fields=None): + # Builds a list of SQL expressions for each field we are tracking. This + # is used as the data source for change tracking in our trigger. + col_array = [] + for field in model._meta.sorted_fields: + if field.primary_key: + continue + + if skip_fields is not None and field.name in skip_fields: + continue + + column = field.column_name + new = 'NULL' if not use_new else 'NEW."%s"' % column + old = 'NULL' if not use_old else 'OLD."%s"' % column + + if isinstance(field, JSONField): + # Ensure that values are cast to JSON so that the serialization + # is preserved when calculating the old / new. 
+ if use_old: old = 'json(%s)' % old + if use_new: new = 'json(%s)' % new + + col_array.append("json_array('%s', %s, %s)" % (column, old, new)) + + return ', '.join(col_array) + + def trigger_sql(self, model, action, skip_fields=None): + assert action in self._actions + use_old = action != 'INSERT' + use_new = action != 'DELETE' + cols = self._build_column_array(model, use_old, use_new, skip_fields) + return self.template % { + 'table': model._meta.table_name, + 'action': action, + 'new_old': 'NEW' if action != 'DELETE' else 'OLD', + 'primary_key': model._meta.primary_key.column_name, + 'column_array': cols, + 'change_table': self.table_name} + + def drop_trigger_sql(self, model, action): + assert action in self._actions + return self.drop_template % { + 'table': model._meta.table_name, + 'action': action} + + @property + def model(self): + if not hasattr(self, '_changelog_model'): + class ChangeLog(self.base_model): + class Meta: + database = self.db + table_name = self.table_name + self._changelog_model = ChangeLog + + return self._changelog_model + + def install(self, model, skip_fields=None, drop=True, insert=True, + update=True, delete=True, create_table=True): + ChangeLog = self.model + if create_table: + ChangeLog.create_table() + + actions = list(zip((insert, update, delete), self._actions)) + if drop: + for _, action in actions: + self.db.execute_sql(self.drop_trigger_sql(model, action)) + + for enabled, action in actions: + if enabled: + sql = self.trigger_sql(model, action, skip_fields) + self.db.execute_sql(sql) diff --git a/python3.10libs/playhouse/sqlite_ext.py b/python3.10libs/playhouse/sqlite_ext.py new file mode 100644 index 0000000..dbaf5ee --- /dev/null +++ b/python3.10libs/playhouse/sqlite_ext.py @@ -0,0 +1,1359 @@ +import json +import math +import re +import struct +import sys + +from peewee import * +from peewee import ColumnBase +from peewee import EnclosedNodeList +from peewee import Entity +from peewee import Expression +from peewee import 
class RowIDField(AutoField):
    """Auto-incrementing field mapped onto SQLite's implicit rowid column.

    The field must be declared with the attribute name "rowid", since that
    is the column SQLite exposes.
    """
    auto_increment = True
    column_name = name = required_name = 'rowid'

    def bind(self, model, name, *args):
        # Enforce the exact attribute name; the implicit column cannot be
        # aliased at declaration time.
        if name == self.required_name:
            super(RowIDField, self).bind(model, name, *args)
        else:
            raise ValueError('%s must be named "%s".' %
                             (type(self), self.required_name))
class JSONField(TextField):
    """TEXT column whose Python value is (de)serialized as JSON.

    Query-building helpers delegate to SQLite's json1 functions via
    JSONPath; custom (de)serializers may be supplied to the constructor.
    """
    field_type = 'JSON'
    unpack = False

    def __init__(self, json_dumps=None, json_loads=None, **kwargs):
        # Allow callers to swap in faster/custom JSON implementations.
        self._json_dumps = json_dumps or json.dumps
        self._json_loads = json_loads or json.loads
        super(JSONField, self).__init__(**kwargs)

    def python_value(self, value):
        """Deserialize on read; non-JSON text is returned unchanged.

        Implicitly returns None when the stored value is NULL.
        """
        if value is not None:
            try:
                return self._json_loads(value)
            except (TypeError, ValueError):
                return value

    def db_value(self, value):
        """Serialize on write; SQL nodes pass through untouched."""
        if value is not None:
            if not isinstance(value, Node):
                value = fn.json(self._json_dumps(value))
            return value

    def _e(op):
        # Factory producing rich-comparison methods. Note: evaluated while
        # the class body executes, so there is no "self" parameter here.
        def inner(self, rhs):
            if isinstance(rhs, (list, dict)):
                rhs = Value(rhs, converter=self.db_value, unpack=False)
            return Expression(self, op, rhs)
        return inner
    __eq__ = _e(OP.EQ)
    __ne__ = _e(OP.NE)
    __gt__ = _e(OP.GT)
    __ge__ = _e(OP.GTE)
    __lt__ = _e(OP.LT)
    __le__ = _e(OP.LTE)
    # Defining __eq__ would otherwise clear __hash__; restore Field's.
    __hash__ = Field.__hash__

    def __getitem__(self, item):
        # field['key'] / field[0] begins a JSONPath expression.
        return JSONPath(self)[item]

    def extract(self, *paths):
        paths = [Value(p, converter=False) for p in paths]
        return fn.json_extract(self, *paths)
    def extract_json(self, path):
        # The "->" operator: extract as JSON (requires SQLite 3.38+).
        return Expression(self, '->', Value(path, converter=False))
    def extract_text(self, path):
        # The "->>" operator: extract as SQL text (requires SQLite 3.38+).
        return Expression(self, '->>', Value(path, converter=False))

    def append(self, value, as_json=None):
        return JSONPath(self).append(value, as_json)

    def insert(self, value, as_json=None):
        return JSONPath(self).insert(value, as_json)

    def set(self, value, as_json=None):
        return JSONPath(self).set(value, as_json)

    def replace(self, value, as_json=None):
        return JSONPath(self).replace(value, as_json)

    def update(self, data):
        # json_patch-style merge of "data" into the stored document.
        return JSONPath(self).update(data)

    def remove(self, *paths):
        if not paths:
            return JSONPath(self).remove()
        return fn.json_remove(self, *paths)

    def json_type(self):
        return fn.json_type(self)

    def length(self, path=None):
        args = (self, path) if path else (self,)
        return fn.json_array_length(*args)

    def children(self):
        """
        Schema of `json_each` and `json_tree`:

        key,
        value,
        type TEXT (object, array, string, etc),
        atom (value for primitive/scalar types, NULL for array and object)
        id INTEGER (unique identifier for element)
        parent INTEGER (unique identifier of parent element or NULL)
        fullkey TEXT (full path describing element)
        path TEXT (path to the container of the current element)
        json JSON hidden (1st input parameter to function)
        root TEXT hidden (2nd input parameter, path at which to start)
        """
        return fn.json_each(self)

    def tree(self):
        # Like children(), but recurses into nested containers.
        return fn.json_tree(self)
class VirtualModel(Model):
    """Base Model for tables backed by a sqlite virtual-table extension."""
    class Meta:
        arguments = None          # extra args appended after the field list
        extension_module = None   # name of the sqlite extension (e.g. fts5)
        prefix_arguments = None   # args emitted before the field list
        primary_key = False
        schema_manager_class = VirtualTableSchemaManager

    @classmethod
    def clean_options(cls, options):
        # Hook for subclasses to validate/normalize the options passed to
        # CREATE VIRTUAL TABLE; the default is a no-op.
        return options
+ options['content'] = "''" + elif isinstance(content, Field): + # Special-case to ensure fields are fully-qualified. + options['content'] = Entity(content.model._meta.table_name, + content.column_name) + + if prefix: + if isinstance(prefix, (list, tuple)): + prefix = ','.join([str(i) for i in prefix]) + options['prefix'] = "'%s'" % prefix.strip("' ") + + if tokenize and cls._meta.extension_module.lower() == 'fts5': + # Tokenizers need to be in quoted string for FTS5, but not for FTS3 + # or FTS4. + options['tokenize'] = '"%s"' % tokenize + + return options + + +class FTSModel(BaseFTSModel): + """ + VirtualModel class for creating tables that use either the FTS3 or FTS4 + search extensions. Peewee automatically determines which version of the + FTS extension is supported and will use FTS4 if possible. + """ + # FTS3/4 uses "docid" in the same way a normal table uses "rowid". + docid = DocIDField() + + class Meta: + extension_module = 'FTS%s' % FTS_VERSION + + @classmethod + def _fts_cmd(cls, cmd): + tbl = cls._meta.table_name + res = cls._meta.database.execute_sql( + "INSERT INTO %s(%s) VALUES('%s');" % (tbl, tbl, cmd)) + return res.fetchone() + + @classmethod + def optimize(cls): + return cls._fts_cmd('optimize') + + @classmethod + def rebuild(cls): + return cls._fts_cmd('rebuild') + + @classmethod + def integrity_check(cls): + return cls._fts_cmd('integrity-check') + + @classmethod + def merge(cls, blocks=200, segments=8): + return cls._fts_cmd('merge=%s,%s' % (blocks, segments)) + + @classmethod + def automerge(cls, state=True): + return cls._fts_cmd('automerge=%s' % (state and '1' or '0')) + + @classmethod + def match(cls, term): + """ + Generate a `MATCH` expression appropriate for searching this table. 
+ """ + return match(cls._meta.entity, term) + + @classmethod + def rank(cls, *weights): + matchinfo = fn.matchinfo(cls._meta.entity, FTS3_MATCHINFO) + return fn.fts_rank(matchinfo, *weights) + + @classmethod + def bm25(cls, *weights): + match_info = fn.matchinfo(cls._meta.entity, FTS4_MATCHINFO) + return fn.fts_bm25(match_info, *weights) + + @classmethod + def bm25f(cls, *weights): + match_info = fn.matchinfo(cls._meta.entity, FTS4_MATCHINFO) + return fn.fts_bm25f(match_info, *weights) + + @classmethod + def lucene(cls, *weights): + match_info = fn.matchinfo(cls._meta.entity, FTS4_MATCHINFO) + return fn.fts_lucene(match_info, *weights) + + @classmethod + def _search(cls, term, weights, with_score, score_alias, score_fn, + explicit_ordering): + if not weights: + rank = score_fn() + elif isinstance(weights, dict): + weight_args = [] + for field in cls._meta.sorted_fields: + # Attempt to get the specified weight of the field by looking + # it up using it's field instance followed by name. 
+ field_weight = weights.get(field, weights.get(field.name, 1.0)) + weight_args.append(field_weight) + rank = score_fn(*weight_args) + else: + rank = score_fn(*weights) + + selection = () + order_by = rank + if with_score: + selection = (cls, rank.alias(score_alias)) + if with_score and not explicit_ordering: + order_by = SQL(score_alias) + + return (cls + .select(*selection) + .where(cls.match(term)) + .order_by(order_by)) + + @classmethod + def search(cls, term, weights=None, with_score=False, score_alias='score', + explicit_ordering=False): + """Full-text search using selected `term`.""" + return cls._search( + term, + weights, + with_score, + score_alias, + cls.rank, + explicit_ordering) + + @classmethod + def search_bm25(cls, term, weights=None, with_score=False, + score_alias='score', explicit_ordering=False): + """Full-text search for selected `term` using BM25 algorithm.""" + return cls._search( + term, + weights, + with_score, + score_alias, + cls.bm25, + explicit_ordering) + + @classmethod + def search_bm25f(cls, term, weights=None, with_score=False, + score_alias='score', explicit_ordering=False): + """Full-text search for selected `term` using BM25 algorithm.""" + return cls._search( + term, + weights, + with_score, + score_alias, + cls.bm25f, + explicit_ordering) + + @classmethod + def search_lucene(cls, term, weights=None, with_score=False, + score_alias='score', explicit_ordering=False): + """Full-text search for selected `term` using BM25 algorithm.""" + return cls._search( + term, + weights, + with_score, + score_alias, + cls.lucene, + explicit_ordering) + + +_alphabet = 'abcdefghijklmnopqrstuvwxyz' +_alphanum = (set('\t ,"(){}*:_+0123456789') | + set(_alphabet) | + set(_alphabet.upper()) | + set((chr(26),))) +_invalid_ascii = set(chr(p) for p in range(128) if chr(p) not in _alphanum) +del _alphabet +del _alphanum +_quote_re = re.compile(r'(?:[^\s"]|"(?:\\.|[^"])*")+') + + +class FTS5Model(BaseFTSModel): + """ + Requires SQLite >= 3.9.0. 
+ + Table options: + + content: table name of external content, or empty string for "contentless" + content_rowid: column name of external content primary key + prefix: integer(s). Ex: '2' or '2 3 4' + tokenize: porter, unicode61, ascii. Ex: 'porter unicode61' + + The unicode tokenizer supports the following parameters: + + * remove_diacritics (1 or 0, default is 1) + * tokenchars (string of characters, e.g. '-_' + * separators (string of characters) + + Parameters are passed as alternating parameter name and value, so: + + {'tokenize': "unicode61 remove_diacritics 0 tokenchars '-_'"} + + Content-less tables: + + If you don't need the full-text content in it's original form, you can + specify a content-less table. Searches and auxiliary functions will work + as usual, but the only values returned when SELECT-ing can be rowid. Also + content-less tables do not support UPDATE or DELETE. + + External content tables: + + You can set up triggers to sync these, e.g. + + -- Create a table. And an external content fts5 table to index it. + CREATE TABLE tbl(a INTEGER PRIMARY KEY, b); + CREATE VIRTUAL TABLE ft USING fts5(b, content='tbl', content_rowid='a'); + + -- Triggers to keep the FTS index up to date. + CREATE TRIGGER tbl_ai AFTER INSERT ON tbl BEGIN + INSERT INTO ft(rowid, b) VALUES (new.a, new.b); + END; + CREATE TRIGGER tbl_ad AFTER DELETE ON tbl BEGIN + INSERT INTO ft(fts_idx, rowid, b) VALUES('delete', old.a, old.b); + END; + CREATE TRIGGER tbl_au AFTER UPDATE ON tbl BEGIN + INSERT INTO ft(fts_idx, rowid, b) VALUES('delete', old.a, old.b); + INSERT INTO ft(rowid, b) VALUES (new.a, new.b); + END; + + Built-in auxiliary functions: + + * bm25(tbl[, weight_0, ... weight_n]) + * highlight(tbl, col_idx, prefix, suffix) + * snippet(tbl, col_idx, prefix, suffix, ?, max_tokens) + """ + # FTS5 does not support declared primary keys, but we can use the + # implicit rowid. 
+ rowid = RowIDField() + + class Meta: + extension_module = 'fts5' + + _error_messages = { + 'field_type': ('Besides the implicit `rowid` column, all columns must ' + 'be instances of SearchField'), + 'index': 'Secondary indexes are not supported for FTS5 models', + 'pk': 'FTS5 models must use the default `rowid` primary key', + } + + @classmethod + def validate_model(cls): + # Perform FTS5-specific validation and options post-processing. + if cls._meta.primary_key.name != 'rowid': + raise ImproperlyConfigured(cls._error_messages['pk']) + for field in cls._meta.fields.values(): + if not isinstance(field, (SearchField, RowIDField)): + raise ImproperlyConfigured(cls._error_messages['field_type']) + if cls._meta.indexes: + raise ImproperlyConfigured(cls._error_messages['index']) + + @classmethod + def fts5_installed(cls): + if sqlite3.sqlite_version_info[:3] < FTS5_MIN_SQLITE_VERSION: + return False + + # Test in-memory DB to determine if the FTS5 extension is installed. + tmp_db = sqlite3.connect(':memory:') + try: + tmp_db.execute('CREATE VIRTUAL TABLE fts5test USING fts5 (data);') + except: + try: + tmp_db.enable_load_extension(True) + tmp_db.load_extension('fts5') + except: + return False + else: + cls._meta.database.load_extension('fts5') + finally: + tmp_db.close() + + return True + + @staticmethod + def validate_query(query): + """ + Simple helper function to indicate whether a search query is a + valid FTS5 query. Note: this simply looks at the characters being + used, and is not guaranteed to catch all problematic queries. + """ + tokens = _quote_re.findall(query) + for token in tokens: + if token.startswith('"') and token.endswith('"'): + continue + if set(token) & _invalid_ascii: + return False + return True + + @staticmethod + def clean_query(query, replace=chr(26)): + """ + Clean a query of invalid tokens. 
+ """ + accum = [] + any_invalid = False + tokens = _quote_re.findall(query) + for token in tokens: + if token.startswith('"') and token.endswith('"'): + accum.append(token) + continue + token_set = set(token) + invalid_for_token = token_set & _invalid_ascii + if invalid_for_token: + any_invalid = True + for c in invalid_for_token: + token = token.replace(c, replace) + accum.append(token) + + if any_invalid: + return ' '.join(accum) + return query + + @classmethod + def match(cls, term): + """ + Generate a `MATCH` expression appropriate for searching this table. + """ + return match(cls._meta.entity, term) + + @classmethod + def rank(cls, *args): + return cls.bm25(*args) if args else SQL('rank') + + @classmethod + def bm25(cls, *weights): + return fn.bm25(cls._meta.entity, *weights) + + @classmethod + def search(cls, term, weights=None, with_score=False, score_alias='score', + explicit_ordering=False): + """Full-text search using selected `term`.""" + return cls.search_bm25( + FTS5Model.clean_query(term), + weights, + with_score, + score_alias, + explicit_ordering) + + @classmethod + def search_bm25(cls, term, weights=None, with_score=False, + score_alias='score', explicit_ordering=False): + """Full-text search using selected `term`.""" + if not weights: + rank = SQL('rank') + elif isinstance(weights, dict): + weight_args = [] + for field in cls._meta.sorted_fields: + if isinstance(field, SearchField) and not field.unindexed: + weight_args.append( + weights.get(field, weights.get(field.name, 1.0))) + rank = fn.bm25(cls._meta.entity, *weight_args) + else: + rank = fn.bm25(cls._meta.entity, *weights) + + selection = () + order_by = rank + if with_score: + selection = (cls, rank.alias(score_alias)) + if with_score and not explicit_ordering: + order_by = SQL(score_alias) + + return (cls + .select(*selection) + .where(cls.match(FTS5Model.clean_query(term))) + .order_by(order_by)) + + @classmethod + def _fts_cmd_sql(cls, cmd, **extra_params): + tbl = cls._meta.entity + 
columns = [tbl] + values = [cmd] + for key, value in extra_params.items(): + columns.append(Entity(key)) + values.append(value) + + return NodeList(( + SQL('INSERT INTO'), + cls._meta.entity, + EnclosedNodeList(columns), + SQL('VALUES'), + EnclosedNodeList(values))) + + @classmethod + def _fts_cmd(cls, cmd, **extra_params): + query = cls._fts_cmd_sql(cmd, **extra_params) + return cls._meta.database.execute(query) + + @classmethod + def automerge(cls, level): + if not (0 <= level <= 16): + raise ValueError('level must be between 0 and 16') + return cls._fts_cmd('automerge', rank=level) + + @classmethod + def merge(cls, npages): + return cls._fts_cmd('merge', rank=npages) + + @classmethod + def optimize(cls): + return cls._fts_cmd('optimize') + + @classmethod + def rebuild(cls): + return cls._fts_cmd('rebuild') + + @classmethod + def set_pgsz(cls, pgsz): + return cls._fts_cmd('pgsz', rank=pgsz) + + @classmethod + def set_rank(cls, rank_expression): + return cls._fts_cmd('rank', rank=rank_expression) + + @classmethod + def delete_all(cls): + return cls._fts_cmd('delete-all') + + @classmethod + def integrity_check(cls, rank=0): + return cls._fts_cmd('integrity-check', rank=rank) + + @classmethod + def VocabModel(cls, table_type='row', table=None): + if table_type not in ('row', 'col', 'instance'): + raise ValueError('table_type must be either "row", "col" or ' + '"instance".') + + attr = '_vocab_model_%s' % table_type + + if not hasattr(cls, attr): + class Meta: + database = cls._meta.database + table_name = table or cls._meta.table_name + '_v' + extension_module = fn.fts5vocab( + cls._meta.entity, + SQL(table_type)) + + attrs = { + 'term': VirtualField(TextField), + 'doc': IntegerField(), + 'cnt': IntegerField(), + 'rowid': RowIDField(), + 'Meta': Meta, + } + if table_type == 'col': + attrs['col'] = VirtualField(TextField) + elif table_type == 'instance': + attrs['offset'] = VirtualField(IntegerField) + + class_name = '%sVocab' % cls.__name__ + setattr(cls, attr, 
type(class_name, (VirtualModel,), attrs)) + + return getattr(cls, attr) + + +def ClosureTable(model_class, foreign_key=None, referencing_class=None, + referencing_key=None): + """Model factory for the transitive closure extension.""" + if referencing_class is None: + referencing_class = model_class + + if foreign_key is None: + for field_obj in model_class._meta.refs: + if field_obj.rel_model is model_class: + foreign_key = field_obj + break + else: + raise ValueError('Unable to find self-referential foreign key.') + + source_key = model_class._meta.primary_key + if referencing_key is None: + referencing_key = source_key + + class BaseClosureTable(VirtualModel): + depth = VirtualField(IntegerField) + id = VirtualField(IntegerField) + idcolumn = VirtualField(TextField) + parentcolumn = VirtualField(TextField) + root = VirtualField(IntegerField) + tablename = VirtualField(TextField) + + class Meta: + extension_module = 'transitive_closure' + + @classmethod + def descendants(cls, node, depth=None, include_node=False): + query = (model_class + .select(model_class, cls.depth.alias('depth')) + .join(cls, on=(source_key == cls.id)) + .where(cls.root == node) + .objects()) + if depth is not None: + query = query.where(cls.depth == depth) + elif not include_node: + query = query.where(cls.depth > 0) + return query + + @classmethod + def ancestors(cls, node, depth=None, include_node=False): + query = (model_class + .select(model_class, cls.depth.alias('depth')) + .join(cls, on=(source_key == cls.root)) + .where(cls.id == node) + .objects()) + if depth: + query = query.where(cls.depth == depth) + elif not include_node: + query = query.where(cls.depth > 0) + return query + + @classmethod + def siblings(cls, node, include_node=False): + if referencing_class is model_class: + # self-join + fk_value = node.__data__.get(foreign_key.name) + query = model_class.select().where(foreign_key == fk_value) + else: + # siblings as given in reference_class + siblings = (referencing_class + 
.select(referencing_key) + .join(cls, on=(foreign_key == cls.root)) + .where((cls.id == node) & (cls.depth == 1))) + + # the according models + query = (model_class + .select() + .where(source_key << siblings) + .objects()) + + if not include_node: + query = query.where(source_key != node) + + return query + + class Meta: + database = referencing_class._meta.database + options = { + 'tablename': referencing_class._meta.table_name, + 'idcolumn': referencing_key.column_name, + 'parentcolumn': foreign_key.column_name} + primary_key = False + + name = '%sClosure' % model_class.__name__ + return type(name, (BaseClosureTable,), {'Meta': Meta}) + + +class LSMTable(VirtualModel): + class Meta: + extension_module = 'lsm1' + filename = None + + @classmethod + def clean_options(cls, options): + filename = cls._meta.filename + if not filename: + raise ValueError('LSM1 extension requires that you specify a ' + 'filename for the LSM database.') + else: + if len(filename) >= 2 and filename[0] != '"': + filename = '"%s"' % filename + if not cls._meta.primary_key: + raise ValueError('LSM1 models must specify a primary-key field.') + + key = cls._meta.primary_key + if isinstance(key, AutoField): + raise ValueError('LSM1 models must explicitly declare a primary ' + 'key field.') + if not isinstance(key, (TextField, BlobField, IntegerField)): + raise ValueError('LSM1 key must be a TextField, BlobField, or ' + 'IntegerField.') + key._hidden = True + if isinstance(key, IntegerField): + data_type = 'UINT' + elif isinstance(key, BlobField): + data_type = 'BLOB' + else: + data_type = 'TEXT' + cls._meta.prefix_arguments = [filename, '"%s"' % key.name, data_type] + + # Does the key map to a scalar value, or a tuple of values? 
+ if len(cls._meta.sorted_fields) == 2: + cls._meta._value_field = cls._meta.sorted_fields[1] + else: + cls._meta._value_field = None + + return options + + @classmethod + def load_extension(cls, path='lsm.so'): + cls._meta.database.load_extension(path) + + @staticmethod + def slice_to_expr(key, idx): + if idx.start is not None and idx.stop is not None: + return key.between(idx.start, idx.stop) + elif idx.start is not None: + return key >= idx.start + elif idx.stop is not None: + return key <= idx.stop + + @staticmethod + def _apply_lookup_to_query(query, key, lookup): + if isinstance(lookup, slice): + expr = LSMTable.slice_to_expr(key, lookup) + if expr is not None: + query = query.where(expr) + return query, False + elif isinstance(lookup, Expression): + return query.where(lookup), False + else: + return query.where(key == lookup), True + + @classmethod + def get_by_id(cls, pk): + query, is_single = cls._apply_lookup_to_query( + cls.select().namedtuples(), + cls._meta.primary_key, + pk) + + if is_single: + row = query.get() + return row[1] if cls._meta._value_field is not None else row + else: + return query + + @classmethod + def set_by_id(cls, key, value): + if cls._meta._value_field is not None: + data = {cls._meta._value_field: value} + elif isinstance(value, tuple): + data = {} + for field, fval in zip(cls._meta.sorted_fields[1:], value): + data[field] = fval + elif isinstance(value, dict): + data = value + elif isinstance(value, cls): + data = value.__dict__ + data[cls._meta.primary_key] = key + cls.replace(data).execute() + + @classmethod + def delete_by_id(cls, pk): + query, is_single = cls._apply_lookup_to_query( + cls.delete(), + cls._meta.primary_key, + pk) + return query.execute() + + +OP.MATCH = 'MATCH' + +def _sqlite_regexp(regex, value): + return re.search(regex, value) is not None + + +class SqliteExtDatabase(SqliteDatabase): + def __init__(self, database, c_extensions=None, rank_functions=True, + hash_functions=False, regexp_function=False, + 
bloomfilter=False, json_contains=False, *args, **kwargs): + super(SqliteExtDatabase, self).__init__(database, *args, **kwargs) + self._row_factory = None + + if c_extensions and not CYTHON_SQLITE_EXTENSIONS: + raise ImproperlyConfigured('SqliteExtDatabase initialized with ' + 'C extensions, but shared library was ' + 'not found!') + prefer_c = CYTHON_SQLITE_EXTENSIONS and (c_extensions is not False) + if rank_functions: + if prefer_c: + register_rank_functions(self) + else: + self.register_function(bm25, 'fts_bm25') + self.register_function(rank, 'fts_rank') + self.register_function(bm25, 'fts_bm25f') # Fall back to bm25. + self.register_function(bm25, 'fts_lucene') + if hash_functions: + if not prefer_c: + raise ValueError('C extension required to register hash ' + 'functions.') + register_hash_functions(self) + if regexp_function: + self.register_function(_sqlite_regexp, 'regexp', 2) + if bloomfilter: + if not prefer_c: + raise ValueError('C extension required to use bloomfilter.') + register_bloomfilter(self) + if json_contains: + self.register_function(_json_contains, 'json_contains') + + self._c_extensions = prefer_c + + def _add_conn_hooks(self, conn): + super(SqliteExtDatabase, self)._add_conn_hooks(conn) + if self._row_factory: + conn.row_factory = self._row_factory + + def row_factory(self, fn): + self._row_factory = fn + + +if CYTHON_SQLITE_EXTENSIONS: + SQLITE_STATUS_MEMORY_USED = 0 + SQLITE_STATUS_PAGECACHE_USED = 1 + SQLITE_STATUS_PAGECACHE_OVERFLOW = 2 + SQLITE_STATUS_SCRATCH_USED = 3 + SQLITE_STATUS_SCRATCH_OVERFLOW = 4 + SQLITE_STATUS_MALLOC_SIZE = 5 + SQLITE_STATUS_PARSER_STACK = 6 + SQLITE_STATUS_PAGECACHE_SIZE = 7 + SQLITE_STATUS_SCRATCH_SIZE = 8 + SQLITE_STATUS_MALLOC_COUNT = 9 + SQLITE_DBSTATUS_LOOKASIDE_USED = 0 + SQLITE_DBSTATUS_CACHE_USED = 1 + SQLITE_DBSTATUS_SCHEMA_USED = 2 + SQLITE_DBSTATUS_STMT_USED = 3 + SQLITE_DBSTATUS_LOOKASIDE_HIT = 4 + SQLITE_DBSTATUS_LOOKASIDE_MISS_SIZE = 5 + SQLITE_DBSTATUS_LOOKASIDE_MISS_FULL = 6 + 
SQLITE_DBSTATUS_CACHE_HIT = 7 + SQLITE_DBSTATUS_CACHE_MISS = 8 + SQLITE_DBSTATUS_CACHE_WRITE = 9 + SQLITE_DBSTATUS_DEFERRED_FKS = 10 + #SQLITE_DBSTATUS_CACHE_USED_SHARED = 11 + + def __status__(flag, return_highwater=False): + """ + Expose a sqlite3_status() call for a particular flag as a property of + the Database object. + """ + def getter(self): + result = sqlite_get_status(flag) + return result[1] if return_highwater else result + return property(getter) + + def __dbstatus__(flag, return_highwater=False, return_current=False): + """ + Expose a sqlite3_dbstatus() call for a particular flag as a property of + the Database instance. Unlike sqlite3_status(), the dbstatus properties + pertain to the current connection. + """ + def getter(self): + if self._state.conn is None: + raise ImproperlyConfigured('database connection not opened.') + result = sqlite_get_db_status(self._state.conn, flag) + if return_current: + return result[0] + return result[1] if return_highwater else result + return property(getter) + + class CSqliteExtDatabase(SqliteExtDatabase): + def __init__(self, *args, **kwargs): + self._conn_helper = None + self._commit_hook = self._rollback_hook = self._update_hook = None + self._replace_busy_handler = False + super(CSqliteExtDatabase, self).__init__(*args, **kwargs) + + def init(self, database, replace_busy_handler=False, **kwargs): + super(CSqliteExtDatabase, self).init(database, **kwargs) + self._replace_busy_handler = replace_busy_handler + + def _close(self, conn): + if self._commit_hook: + self._conn_helper.set_commit_hook(None) + if self._rollback_hook: + self._conn_helper.set_rollback_hook(None) + if self._update_hook: + self._conn_helper.set_update_hook(None) + return super(CSqliteExtDatabase, self)._close(conn) + + def _add_conn_hooks(self, conn): + super(CSqliteExtDatabase, self)._add_conn_hooks(conn) + self._conn_helper = ConnectionHelper(conn) + if self._commit_hook is not None: + self._conn_helper.set_commit_hook(self._commit_hook) + 
if self._rollback_hook is not None: + self._conn_helper.set_rollback_hook(self._rollback_hook) + if self._update_hook is not None: + self._conn_helper.set_update_hook(self._update_hook) + if self._replace_busy_handler: + timeout = self._timeout or 5 + self._conn_helper.set_busy_handler(timeout * 1000) + + def on_commit(self, fn): + self._commit_hook = fn + if not self.is_closed(): + self._conn_helper.set_commit_hook(fn) + return fn + + def on_rollback(self, fn): + self._rollback_hook = fn + if not self.is_closed(): + self._conn_helper.set_rollback_hook(fn) + return fn + + def on_update(self, fn): + self._update_hook = fn + if not self.is_closed(): + self._conn_helper.set_update_hook(fn) + return fn + + def changes(self): + return self._conn_helper.changes() + + @property + def last_insert_rowid(self): + return self._conn_helper.last_insert_rowid() + + @property + def autocommit(self): + return self._conn_helper.autocommit() + + def backup(self, destination, pages=None, name=None, progress=None): + return backup(self.connection(), destination.connection(), + pages=pages, name=name, progress=progress) + + def backup_to_file(self, filename, pages=None, name=None, + progress=None): + return backup_to_file(self.connection(), filename, pages=pages, + name=name, progress=progress) + + def blob_open(self, table, column, rowid, read_only=False): + return Blob(self, table, column, rowid, read_only) + + # Status properties. 
+ memory_used = __status__(SQLITE_STATUS_MEMORY_USED) + malloc_size = __status__(SQLITE_STATUS_MALLOC_SIZE, True) + malloc_count = __status__(SQLITE_STATUS_MALLOC_COUNT) + pagecache_used = __status__(SQLITE_STATUS_PAGECACHE_USED) + pagecache_overflow = __status__(SQLITE_STATUS_PAGECACHE_OVERFLOW) + pagecache_size = __status__(SQLITE_STATUS_PAGECACHE_SIZE, True) + scratch_used = __status__(SQLITE_STATUS_SCRATCH_USED) + scratch_overflow = __status__(SQLITE_STATUS_SCRATCH_OVERFLOW) + scratch_size = __status__(SQLITE_STATUS_SCRATCH_SIZE, True) + + # Connection status properties. + lookaside_used = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_USED) + lookaside_hit = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_HIT, True) + lookaside_miss = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_MISS_SIZE, + True) + lookaside_miss_full = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_MISS_FULL, + True) + cache_used = __dbstatus__(SQLITE_DBSTATUS_CACHE_USED, False, True) + #cache_used_shared = __dbstatus__(SQLITE_DBSTATUS_CACHE_USED_SHARED, + # False, True) + schema_used = __dbstatus__(SQLITE_DBSTATUS_SCHEMA_USED, False, True) + statement_used = __dbstatus__(SQLITE_DBSTATUS_STMT_USED, False, True) + cache_hit = __dbstatus__(SQLITE_DBSTATUS_CACHE_HIT, False, True) + cache_miss = __dbstatus__(SQLITE_DBSTATUS_CACHE_MISS, False, True) + cache_write = __dbstatus__(SQLITE_DBSTATUS_CACHE_WRITE, False, True) + + +def match(lhs, rhs): + return Expression(lhs, OP.MATCH, rhs) + +def _parse_match_info(buf): + # See http://sqlite.org/fts3.html#matchinfo + bufsize = len(buf) # Length in bytes. + return [struct.unpack('@I', buf[i:i+4])[0] for i in range(0, bufsize, 4)] + +def get_weights(ncol, raw_weights): + if not raw_weights: + return [1] * ncol + else: + weights = [0] * ncol + for i, weight in enumerate(raw_weights): + weights[i] = weight + return weights + +# Ranking implementation, which parse matchinfo. 
+def rank(raw_match_info, *raw_weights): + # Handle match_info called w/default args 'pcx' - based on the example rank + # function http://sqlite.org/fts3.html#appendix_a + match_info = _parse_match_info(raw_match_info) + score = 0.0 + + p, c = match_info[:2] + weights = get_weights(c, raw_weights) + + # matchinfo X value corresponds to, for each phrase in the search query, a + # list of 3 values for each column in the search table. + # So if we have a two-phrase search query and three columns of data, the + # following would be the layout: + # p0 : c0=[0, 1, 2], c1=[3, 4, 5], c2=[6, 7, 8] + # p1 : c0=[9, 10, 11], c1=[12, 13, 14], c2=[15, 16, 17] + for phrase_num in range(p): + phrase_info_idx = 2 + (phrase_num * c * 3) + for col_num in range(c): + weight = weights[col_num] + if not weight: + continue + + col_idx = phrase_info_idx + (col_num * 3) + + # The idea is that we count the number of times the phrase appears + # in this column of the current row, compared to how many times it + # appears in this column across all rows. The ratio of these values + # provides a rough way to score based on "high value" terms. + row_hits = match_info[col_idx] + all_rows_hits = match_info[col_idx + 1] + if row_hits > 0: + score += weight * (float(row_hits) / all_rows_hits) + + return -score + +# Okapi BM25 ranking implementation (FTS4 only). +def bm25(raw_match_info, *args): + """ + Usage: + + # Format string *must* be pcnalx + # Second parameter to bm25 specifies the index of the column, on + # the table being queries. + bm25(matchinfo(document_tbl, 'pcnalx'), 1) AS rank + """ + match_info = _parse_match_info(raw_match_info) + K = 1.2 + B = 0.75 + score = 0.0 + + P_O, C_O, N_O, A_O = range(4) # Offsets into the matchinfo buffer. + term_count = match_info[P_O] # n + col_count = match_info[C_O] + total_docs = match_info[N_O] # N + L_O = A_O + col_count + X_O = L_O + col_count + + # Worked example of pcnalx for two columns and two phrases, 100 docs total. 
+ # { + # p = 2 + # c = 2 + # n = 100 + # a0 = 4 -- avg number of tokens for col0, e.g. title + # a1 = 40 -- avg number of tokens for col1, e.g. body + # l0 = 5 -- curr doc has 5 tokens in col0 + # l1 = 30 -- curr doc has 30 tokens in col1 + # + # x000 -- hits this row for phrase0, col0 + # x001 -- hits all rows for phrase0, col0 + # x002 -- rows with phrase0 in col0 at least once + # + # x010 -- hits this row for phrase0, col1 + # x011 -- hits all rows for phrase0, col1 + # x012 -- rows with phrase0 in col1 at least once + # + # x100 -- hits this row for phrase1, col0 + # x101 -- hits all rows for phrase1, col0 + # x102 -- rows with phrase1 in col0 at least once + # + # x110 -- hits this row for phrase1, col1 + # x111 -- hits all rows for phrase1, col1 + # x112 -- rows with phrase1 in col1 at least once + # } + + weights = get_weights(col_count, args) + + for i in range(term_count): + for j in range(col_count): + weight = weights[j] + if weight == 0: + continue + + x = X_O + (3 * (j + i * col_count)) + term_frequency = float(match_info[x]) # f(qi, D) + docs_with_term = float(match_info[x + 2]) # n(qi) + + # log( (N - n(qi) + 0.5) / (n(qi) + 0.5) ) + idf = math.log( + (total_docs - docs_with_term + 0.5) / + (docs_with_term + 0.5)) + if idf <= 0.0: + idf = 1e-6 + + doc_length = float(match_info[L_O + j]) # |D| + avg_length = float(match_info[A_O + j]) or 1. # avgdl + ratio = doc_length / avg_length + + num = term_frequency * (K + 1.0) + b_part = 1.0 - B + (B * ratio) + denom = term_frequency + (K * b_part) + + pc_score = idf * (num / denom) + score += (pc_score * weight) + + return -score + + +def _json_contains(src_json, obj_json): + stack = [] + try: + stack.append((json.loads(obj_json), json.loads(src_json))) + except: + # Invalid JSON! 
+ return False + + while stack: + obj, src = stack.pop() + if isinstance(src, dict): + if isinstance(obj, dict): + for key in obj: + if key not in src: + return False + stack.append((obj[key], src[key])) + elif isinstance(obj, list): + for item in obj: + if item not in src: + return False + elif obj not in src: + return False + elif isinstance(src, list): + if isinstance(obj, dict): + return False + elif isinstance(obj, list): + try: + for i in range(len(obj)): + stack.append((obj[i], src[i])) + except IndexError: + return False + elif obj not in src: + return False + elif obj != src: + return False + return True diff --git a/python3.10libs/playhouse/sqlite_udf.py b/python3.10libs/playhouse/sqlite_udf.py new file mode 100644 index 0000000..050dc9b --- /dev/null +++ b/python3.10libs/playhouse/sqlite_udf.py @@ -0,0 +1,536 @@ +import datetime +import hashlib +import heapq +import math +import os +import random +import re +import sys +import threading +import zlib +try: + from collections import Counter +except ImportError: + Counter = None +try: + from urlparse import urlparse +except ImportError: + from urllib.parse import urlparse + +try: + from playhouse._sqlite_ext import TableFunction +except ImportError: + TableFunction = None + + +SQLITE_DATETIME_FORMATS = ( + '%Y-%m-%d %H:%M:%S', + '%Y-%m-%d %H:%M:%S.%f', + '%Y-%m-%d', + '%H:%M:%S', + '%H:%M:%S.%f', + '%H:%M') + +from peewee import format_date_time + +def format_date_time_sqlite(date_value): + return format_date_time(date_value, SQLITE_DATETIME_FORMATS) + +try: + from playhouse import _sqlite_udf as cython_udf +except ImportError: + cython_udf = None + + +# Group udf by function. 
# UDF grouping categories. Functions and aggregates are recorded under one
# or more of these group names so they can later be registered with a
# database connection en masse (see the register_* helpers below).
CONTROL_FLOW = 'control_flow'
DATE = 'date'
FILE = 'file'
HELPER = 'helpers'
MATH = 'math'
STRING = 'string'

# Registries mapping group name -> list of classes / functions in the group.
AGGREGATE_COLLECTION = {}
TABLE_FUNCTION_COLLECTION = {}
UDF_COLLECTION = {}


class synchronized_dict(dict):
    """Dict subclass whose item get / set / delete are guarded by a lock.

    Note: only ``__getitem__``, ``__setitem__`` and ``__delitem__`` are
    synchronized; other dict methods (e.g. ``get``) go straight through.
    """
    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        self._lock = threading.Lock()

    def __getitem__(self, key):
        with self._lock:
            return dict.__getitem__(self, key)

    def __setitem__(self, key, value):
        with self._lock:
            return dict.__setitem__(self, key, value)

    def __delitem__(self, key):
        with self._lock:
            return dict.__delitem__(self, key)


# Shared, thread-safe state used by the toggle()/setting() scalar functions.
STATE = synchronized_dict()
SETTINGS = synchronized_dict()

# Class and function decorators.
def aggregate(*groups):
    """Class decorator: record an aggregate under the given group name(s)."""
    def decorator(klass):
        for group_name in groups:
            AGGREGATE_COLLECTION.setdefault(group_name, []).append(klass)
        return klass
    return decorator

def table_function(*groups):
    """Class decorator: record a table function under the given group(s)."""
    def decorator(klass):
        for group_name in groups:
            TABLE_FUNCTION_COLLECTION.setdefault(group_name, []).append(klass)
        return klass
    return decorator

def udf(*groups):
    """Function decorator: record a scalar UDF under the given group(s)."""
    def decorator(fn):
        for group_name in groups:
            UDF_COLLECTION.setdefault(group_name, []).append(fn)
        return fn
    return decorator

# Register aggregates / functions with connection.
def register_aggregate_groups(db, *groups):
    """Register every aggregate in *groups* with ``db``, deduplicated by name."""
    registered = set()
    for group_name in groups:
        for agg_class in AGGREGATE_COLLECTION.get(group_name, ()):
            # Prefer an explicit ``name`` attribute, fall back to class name.
            agg_name = getattr(agg_class, 'name', agg_class.__name__)
            if agg_name in registered:
                continue
            registered.add(agg_name)
            db.register_aggregate(agg_class, agg_name)

def register_table_function_groups(db, *groups):
    """Register every table function in *groups* with ``db``, deduped by name."""
    registered = set()
    for group_name in groups:
        for tf_class in TABLE_FUNCTION_COLLECTION.get(group_name, ()):
            # Table functions are required to declare a ``name`` attribute.
            if tf_class.name in registered:
                continue
            registered.add(tf_class.name)
            db.register_table_function(tf_class)

def register_udf_groups(db, *groups):
    """Register every scalar UDF in *groups* with ``db``, deduped by name."""
    registered = set()
    for group_name in groups:
        for fn in UDF_COLLECTION.get(group_name, ()):
            fn_name = fn.__name__
            if fn_name in registered:
                continue
            registered.add(fn_name)
            db.register_function(fn, fn_name)

def register_groups(db, *groups):
    """Register aggregates, table functions and UDFs for the given group(s)."""
    register_aggregate_groups(db, *groups)
    register_table_function_groups(db, *groups)
    register_udf_groups(db, *groups)

def register_all(db):
    """Register everything that has been declared, regardless of group."""
    register_aggregate_groups(db, *AGGREGATE_COLLECTION)
    register_table_function_groups(db, *TABLE_FUNCTION_COLLECTION)
    register_udf_groups(db, *UDF_COLLECTION)


# Begin actual user-defined functions and aggregates.

# Scalar functions.
+@udf(CONTROL_FLOW) +def if_then_else(cond, truthy, falsey=None): + if cond: + return truthy + return falsey + +@udf(DATE) +def strip_tz(date_str): + date_str = date_str.replace('T', ' ') + tz_idx1 = date_str.find('+') + if tz_idx1 != -1: + return date_str[:tz_idx1] + tz_idx2 = date_str.find('-') + if tz_idx2 > 13: + return date_str[:tz_idx2] + return date_str + +@udf(DATE) +def human_delta(nseconds, glue=', '): + parts = ( + (86400 * 365, 'year'), + (86400 * 30, 'month'), + (86400 * 7, 'week'), + (86400, 'day'), + (3600, 'hour'), + (60, 'minute'), + (1, 'second'), + ) + accum = [] + for offset, name in parts: + val, nseconds = divmod(nseconds, offset) + if val: + suffix = val != 1 and 's' or '' + accum.append('%s %s%s' % (val, name, suffix)) + if not accum: + return '0 seconds' + return glue.join(accum) + +@udf(FILE) +def file_ext(filename): + try: + res = os.path.splitext(filename) + except ValueError: + return None + return res[1] + +@udf(FILE) +def file_read(filename): + try: + with open(filename) as fh: + return fh.read() + except: + pass + +if sys.version_info[0] == 2: + @udf(HELPER) + def gzip(data, compression=9): + return buffer(zlib.compress(data, compression)) + + @udf(HELPER) + def gunzip(data): + return zlib.decompress(data) +else: + @udf(HELPER) + def gzip(data, compression=9): + if isinstance(data, str): + data = bytes(data.encode('raw_unicode_escape')) + return zlib.compress(data, compression) + + @udf(HELPER) + def gunzip(data): + return zlib.decompress(data) + +@udf(HELPER) +def hostname(url): + parse_result = urlparse(url) + if parse_result: + return parse_result.netloc + +@udf(HELPER) +def toggle(key): + key = key.lower() + STATE[key] = ret = not STATE.get(key) + return ret + +@udf(HELPER) +def setting(key, value=None): + if value is None: + return SETTINGS.get(key) + else: + SETTINGS[key] = value + return value + +@udf(HELPER) +def clear_settings(): + SETTINGS.clear() + +@udf(HELPER) +def clear_toggles(): + STATE.clear() + +@udf(MATH) +def 
@udf(MATH)
def randomrange(start, end=None, step=None):
    """Return a random integer like random.randrange().

    With one argument, picks from [0, start).

    Fix: `step` is now defaulted to 1 unconditionally. Previously it was
    only set in the `elif` branch, so the single-argument form left
    step=None and random.randrange(0, n, None) raised TypeError.
    """
    if end is None:
        start, end = 0, start
    if step is None:
        step = 1
    return random.randrange(start, end, step)

@udf(MATH)
def gauss_distribution(mean, sigma):
    """Sample a normal distribution; NULL on invalid parameters."""
    try:
        return random.gauss(mean, sigma)
    except ValueError:
        return None

@udf(MATH)
def sqrt(n):
    """Square root of `n`, or NULL for negative input."""
    try:
        return math.sqrt(n)
    except ValueError:
        return None

@udf(MATH)
def tonumber(s):
    """Coerce a string to int, then float; NULL when neither parses."""
    try:
        return int(s)
    except ValueError:
        try:
            return float(s)
        except:
            return None

@udf(STRING)
def substr_count(haystack, needle):
    """Count non-overlapping occurrences of `needle` in `haystack`."""
    if not haystack or not needle:
        return 0
    return haystack.count(needle)

@udf(STRING)
def strip_chars(haystack, chars):
    """Strip any of `chars` from both ends of `haystack`."""
    return haystack.strip(chars)

def _hash(constructor, *args):
    # Shared helper: feed every argument into a freshly constructed hash
    # object and return the hex digest.
    hash_obj = constructor()
    for arg in args:
        hash_obj.update(arg)
    return hash_obj.hexdigest()

# Aggregates.
class _heap_agg(object):
    """Base aggregate that accumulates values into a min-heap so finalize()
    can pop them back out in sorted order."""
    def __init__(self):
        self.heap = []
        self.ct = 0

    def process(self, value):
        # Hook for subclasses to transform each value before heaping.
        return value

    def step(self, value):
        self.ct += 1
        heapq.heappush(self.heap, self.process(value))

class _datetime_heap_agg(_heap_agg):
    """Heap aggregate that parses each value as a SQLite datetime first."""
    def process(self, value):
        return format_date_time_sqlite(value)

if sys.version_info[:2] == (2, 6):
    # timedelta.total_seconds() was only added in Python 2.7.
    def total_seconds(td):
        return (td.seconds +
                (td.days * 86400) +
                (td.microseconds / (10.**6)))
else:
    total_seconds = lambda td: td.total_seconds()

@aggregate(DATE)
class mintdiff(_datetime_heap_agg):
    """Smallest gap, in seconds, between consecutive datetimes (sorted)."""
    def finalize(self):
        dtp = min_diff = None
        while self.heap:
            if min_diff is None:
                if dtp is None:
                    # Prime the previous value with the smallest datetime.
                    dtp = heapq.heappop(self.heap)
                    continue
            dt = heapq.heappop(self.heap)
            diff = dt - dtp
            if min_diff is None or min_diff > diff:
                min_diff = diff
            dtp = dt
        if min_diff is not None:
            return total_seconds(min_diff)
@aggregate(DATE)
class duration(object):
    """Span in seconds between the earliest and latest datetime seen."""
    def __init__(self):
        self._min = self._max = None

    def step(self, value):
        dt = format_date_time_sqlite(value)
        if self._min is None or dt < self._min:
            self._min = dt
        if self._max is None or dt > self._max:
            self._max = dt

    def finalize(self):
        # Both endpoints must have been set (at least one parseable row).
        if self._min and self._max:
            td = (self._max - self._min)
            return total_seconds(td)
        return None

@aggregate(MATH)
class mode(object):
    """Most frequently occurring value in the group."""
    # Use collections.Counter when available (C-accelerated counting);
    # otherwise fall back to a plain list with max(key=list.count).
    if Counter:
        def __init__(self):
            self.items = Counter()

        def step(self, *args):
            self.items.update(args)

        def finalize(self):
            if self.items:
                return self.items.most_common(1)[0][0]
    else:
        def __init__(self):
            self.items = []

        def step(self, item):
            self.items.append(item)

        def finalize(self):
            if self.items:
                return max(set(self.items), key=self.items.count)

@aggregate(MATH)
class minrange(_heap_agg):
    """Smallest difference between consecutive values (sorted order)."""
    def finalize(self):
        if self.ct == 0:
            return
        elif self.ct == 1:
            # A single value has no pairwise difference.
            return 0

        prev = min_diff = None

        while self.heap:
            if min_diff is None:
                if prev is None:
                    # Prime the previous value with the smallest element.
                    prev = heapq.heappop(self.heap)
                    continue
            curr = heapq.heappop(self.heap)
            diff = curr - prev
            if min_diff is None or min_diff > diff:
                min_diff = diff
            prev = curr
        return min_diff

@aggregate(MATH)
class avgrange(_heap_agg):
    """Average difference between consecutive values (sorted order)."""
    def finalize(self):
        if self.ct == 0:
            return
        elif self.ct == 1:
            return 0

        total = ct = 0
        prev = None
        while self.heap:
            if total == 0:
                if prev is None:
                    # Prime the previous value with the smallest element.
                    prev = heapq.heappop(self.heap)
                    continue

            curr = heapq.heappop(self.heap)
            diff = curr - prev
            ct += 1
            total += diff
            prev = curr

        return float(total) / ct
@aggregate(MATH)
class stddev(object):
    """Sample standard deviation aggregate.

    Improvement: uses Welford's online algorithm instead of buffering
    every value, reducing memory from O(n) to O(1) and improving
    numerical stability; results are mathematically identical.
    """
    def __init__(self):
        self.n = 0        # count of values seen
        self.mean = 0.0   # running mean
        self.m2 = 0.0     # running sum of squared deviations from the mean

    def step(self, v):
        self.n += 1
        delta = v - self.mean
        self.mean += delta / self.n
        self.m2 += delta * (v - self.mean)

    def finalize(self):
        # Match the original contract: 0 (not NULL) for 0 or 1 values.
        if self.n <= 1:
            return 0
        return math.sqrt(self.m2 / (self.n - 1))
class AsyncCursor(object):
    """Cursor-like handle returned immediately for a queued statement.

    The writer thread executes the SQL and publishes the outcome via
    set_result(); the submitting thread blocks in _wait() on the shared
    event until results (or an exception) are available.
    """
    __slots__ = ('sql', 'params', 'timeout',
                 '_event', '_cursor', '_exc', '_idx', '_rows', '_ready')

    def __init__(self, event, sql, params, timeout):
        self._event = event      # Set by the writer thread when done.
        self.sql = sql
        self.params = params
        self.timeout = timeout   # Seconds to wait for results, or None.
        self._cursor = self._exc = self._idx = self._rows = None
        self._ready = False

    def set_result(self, cursor, exc=None):
        # Called from the writer thread. All rows are materialized up
        # front so the consumer never touches the live cursor while the
        # writer moves on to the next statement.
        self._cursor = cursor
        self._exc = exc
        self._idx = 0
        self._rows = cursor.fetchall() if exc is None else []
        self._event.set()
        return self

    def _wait(self, timeout=None):
        timeout = timeout if timeout is not None else self.timeout
        # Only raise on timeout when a truthy timeout was requested;
        # timeout=None (or 0) blocks/returns without raising.
        if not self._event.wait(timeout=timeout) and timeout:
            raise ResultTimeout('results not ready, timed out.')
        if self._exc is not None:
            # Re-raise the writer-side exception in the caller's thread.
            raise self._exc
        self._ready = True

    def __iter__(self):
        if not self._ready:
            self._wait()
        if self._exc is not None:
            raise self._exc
        return self

    def next(self):
        if not self._ready:
            self._wait()
        try:
            obj = self._rows[self._idx]
        except IndexError:
            raise StopIteration
        else:
            self._idx += 1
            return obj
    __next__ = next  # Python 3 iterator protocol.

    @property
    def lastrowid(self):
        if not self._ready:
            self._wait()
        return self._cursor.lastrowid

    @property
    def rowcount(self):
        if not self._ready:
            self._wait()
        return self._cursor.rowcount

    @property
    def description(self):
        # NOTE(review): unlike lastrowid/rowcount this does not _wait()
        # first; presumably only read after execution completes - confirm.
        return self._cursor.description

    def close(self):
        self._cursor.close()

    def fetchall(self):
        return list(self)  # Iterating implies waiting until populated.

    def fetchone(self):
        if not self._ready:
            self._wait()
        try:
            return next(self)
        except StopIteration:
            return None
    def wait_unpause(self):
        """Block while paused; return True once an UNPAUSE message arrives.

        Raises ShutdownException on SHUTDOWN. Queries submitted while
        paused are rejected immediately with a WriterPaused error.
        """
        obj = self.queue.get()
        if obj is UNPAUSE:
            logger.info('writer unpaused - reconnecting to database.')
            return True
        elif obj is SHUTDOWN:
            raise ShutdownException()
        elif obj is PAUSE:
            logger.error('writer received pause, but is already paused.')
        else:
            # Fail the queued cursor so its submitter is not left waiting.
            obj.set_result(None, WriterPaused())
            logger.warning('writer paused, not handling %s', obj)

    def loop(self, conn):
        """Handle one queue message and return the connection.

        Returns None when a PAUSE message closed the connection, which
        signals run() to switch into the paused state.
        """
        obj = self.queue.get()
        if isinstance(obj, AsyncCursor):
            self.execute(obj)
        elif obj is PAUSE:
            logger.info('writer paused - closing database connection.')
            self.database._close(conn)
            self.database._state.reset()
            return
        elif obj is UNPAUSE:
            logger.error('writer received unpause, but is already running.')
        elif obj is SHUTDOWN:
            raise ShutdownException()
        else:
            logger.error('writer received unsupported object: %s', obj)
        return conn

    def execute(self, obj):
        """Run one queued statement and publish its result (or exception)."""
        logger.debug('received query %s', obj.sql)
        try:
            cursor = self.database._execute(obj.sql, obj.params)
        except Exception as execute_err:
            cursor = None
            exc = execute_err  # Preserve to hand back to the waiting caller.
        else:
            exc = None
        return obj.set_result(cursor, exc)
This value is passed to the parent + # class constructor below. + pragmas = self._validate_journal_mode(kwargs.pop('pragmas', None)) + + # Reference to execute_sql on the parent class. Since we've overridden + # execute_sql(), this is just a handy way to reference the real + # implementation. + Parent = super(SqliteQueueDatabase, self) + self._execute = Parent.execute_sql + + # Call the parent class constructor with our modified pragmas. + Parent.__init__(database, pragmas=pragmas, *args, **kwargs) + + self._autostart = autostart + self._results_timeout = results_timeout + self._is_stopped = True + + # Get different objects depending on the threading implementation. + self._thread_helper = self.get_thread_impl(use_gevent)(queue_max_size) + + # Create the writer thread, optionally starting it. + self._create_write_queue() + if self._autostart: + self.start() + + def get_thread_impl(self, use_gevent): + return GreenletHelper if use_gevent else ThreadHelper + + def _validate_journal_mode(self, pragmas=None): + if not pragmas: + return {'journal_mode': 'wal'} + + if not isinstance(pragmas, dict): + pragmas = dict((k.lower(), v) for (k, v) in pragmas) + if pragmas.get('journal_mode', 'wal').lower() != 'wal': + raise ValueError(self.WAL_MODE_ERROR_MESSAGE) + + pragmas['journal_mode'] = 'wal' + return pragmas + + def _create_write_queue(self): + self._write_queue = self._thread_helper.queue() + + def queue_size(self): + return self._write_queue.qsize() + + def execute_sql(self, sql, params=None, commit=None, timeout=None): + if commit is not None: + __deprecated__('"commit" has been deprecated and is a no-op.') + if sql.lower().startswith('select'): + return self._execute(sql, params) + + cursor = AsyncCursor( + event=self._thread_helper.event(), + sql=sql, + params=params, + timeout=self._results_timeout if timeout is None else timeout) + self._write_queue.put(cursor) + return cursor + + def start(self): + with self._lock: + if not self._is_stopped: + return False + def 
class ThreadHelper(object):
    """Factory for the threading-based primitives (event, queue, worker
    thread) used by SqliteQueueDatabase."""

    __slots__ = ('queue_max_size',)

    def __init__(self, queue_max_size=None):
        self.queue_max_size = queue_max_size

    def event(self):
        """Return a fresh threading.Event."""
        return Event()

    def queue(self, max_size=None):
        """Return a Queue; an explicit max_size overrides the default,
        and a falsy size means unbounded."""
        if max_size is None:
            max_size = self.queue_max_size
        return Queue(maxsize=max_size or 0)

    def thread(self, fn, *args, **kwargs):
        """Return a daemon Thread (not started) targeting fn(*args, **kwargs)."""
        worker = Thread(target=fn, args=args, kwargs=kwargs)
        worker.daemon = True
        return worker
class count_queries(object):
    """Context manager that counts queries logged to the 'peewee' logger
    while the block is active; optionally SELECT statements only."""

    def __init__(self, only_select=False):
        self.only_select = only_select
        self.count = 0

    def get_queries(self):
        """Return the raw log records captured so far."""
        return self._handler.queries

    def __enter__(self):
        self._handler = _QueryLogHandler()
        logger.setLevel(logging.DEBUG)
        logger.addHandler(self._handler)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        logger.removeHandler(self._handler)
        captured = self._handler.queries
        if self.only_select:
            captured = [record for record in captured
                        if record.msg[0].startswith('SELECT ')]
        self.count = len(captured)


class assert_query_count(count_queries):
    """count_queries variant that asserts exactly `expected` queries ran.
    Usable as a context manager or as a function decorator."""

    def __init__(self, expected, only_select=False):
        super(assert_query_count, self).__init__(only_select=only_select)
        self.expected = expected

    def __call__(self, f):
        @wraps(f)
        def decorated(*args, **kwds):
            with self:
                result = f(*args, **kwds)

            self._assert_count()
            return result

        return decorated

    def _assert_count(self):
        assert self.count == self.expected, (
            '%s != %s' % (self.count, self.expected))

    def __exit__(self, exc_type, exc_val, exc_tb):
        super(assert_query_count, self).__exit__(exc_type, exc_val, exc_tb)
        self._assert_count()
"""Provides Qsci classes and functions."""

from . import (
    PYQT5,
    PYQT6,
    PYSIDE2,
    PYSIDE6,
    QtBindingMissingModuleError,
    QtModuleNotInstalledError,
)

# QScintilla ships as a separate package from the base Qt bindings, so a
# ModuleNotFoundError here means that extra package is not installed.
if PYQT5:
    try:
        from PyQt5.Qsci import *
    except ModuleNotFoundError as error:
        raise QtModuleNotInstalledError(
            name="Qsci",
            missing_package="QScintilla",
        ) from error
elif PYQT6:
    try:
        from PyQt6.Qsci import *
    except ModuleNotFoundError as error:
        raise QtModuleNotInstalledError(
            name="Qsci",
            missing_package="PyQt6-QScintilla",
        ) from error
elif PYSIDE2 or PYSIDE6:
    # Qsci is not provided by the PySide bindings at all.
    raise QtBindingMissingModuleError(name="Qsci")
import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtModuleNotInstalledError, +) + +if PYQT5: + try: + from PyQt5.Qt3DAnimation import * + except ModuleNotFoundError as error: + raise QtModuleNotInstalledError( + name="Qt3DAnimation", + missing_package="PyQt3D", + ) from error +elif PYQT6: + try: + from PyQt6.Qt3DAnimation import * + except ModuleNotFoundError as error: + raise QtModuleNotInstalledError( + name="Qt3DAnimation", + missing_package="PyQt6-3D", + ) from error +elif PYSIDE2: + # https://bugreports.qt.io/projects/PYSIDE/issues/PYSIDE-1026 + import inspect + + import PySide2.Qt3DAnimation as __temp + + for __name in inspect.getmembers(__temp.Qt3DAnimation): + globals()[__name[0]] = __name[1] +elif PYSIDE6: + # https://bugreports.qt.io/projects/PYSIDE/issues/PYSIDE-1026 + import inspect + + import PySide6.Qt3DAnimation as __temp + + for __name in inspect.getmembers(__temp.Qt3DAnimation): + globals()[__name[0]] = __name[1] diff --git a/python3.10libs/qtpy/Qt3DCore.py b/python3.10libs/qtpy/Qt3DCore.py new file mode 100644 index 0000000..5c15df1 --- /dev/null +++ b/python3.10libs/qtpy/Qt3DCore.py @@ -0,0 +1,49 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides Qt3DCore classes and functions.""" + +from . 
import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtModuleNotInstalledError, +) + +if PYQT5: + try: + from PyQt5.Qt3DCore import * + except ModuleNotFoundError as error: + raise QtModuleNotInstalledError( + name="Qt3DCore", + missing_package="PyQt3D", + ) from error +elif PYQT6: + try: + from PyQt6.Qt3DCore import * + except ModuleNotFoundError as error: + raise QtModuleNotInstalledError( + name="Qt3DCore", + missing_package="PyQt6-3D", + ) from error +elif PYSIDE2: + # https://bugreports.qt.io/projects/PYSIDE/issues/PYSIDE-1026 + import inspect + + import PySide2.Qt3DCore as __temp + + for __name in inspect.getmembers(__temp.Qt3DCore): + globals()[__name[0]] = __name[1] +elif PYSIDE6: + # https://bugreports.qt.io/projects/PYSIDE/issues/PYSIDE-1026 + import inspect + + import PySide6.Qt3DCore as __temp + + for __name in inspect.getmembers(__temp.Qt3DCore): + globals()[__name[0]] = __name[1] diff --git a/python3.10libs/qtpy/Qt3DExtras.py b/python3.10libs/qtpy/Qt3DExtras.py new file mode 100644 index 0000000..d146065 --- /dev/null +++ b/python3.10libs/qtpy/Qt3DExtras.py @@ -0,0 +1,49 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides Qt3DExtras classes and functions.""" + +from . 
import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtModuleNotInstalledError, +) + +if PYQT5: + try: + from PyQt5.Qt3DExtras import * + except ModuleNotFoundError as error: + raise QtModuleNotInstalledError( + name="Qt3DExtras", + missing_package="PyQt3D", + ) from error +elif PYQT6: + try: + from PyQt6.Qt3DExtras import * + except ModuleNotFoundError as error: + raise QtModuleNotInstalledError( + name="Qt3DExtras", + missing_package="PyQt6-3D", + ) from error +elif PYSIDE2: + # https://bugreports.qt.io/projects/PYSIDE/issues/PYSIDE-1026 + import inspect + + import PySide2.Qt3DExtras as __temp + + for __name in inspect.getmembers(__temp.Qt3DExtras): + globals()[__name[0]] = __name[1] +elif PYSIDE6: + # https://bugreports.qt.io/projects/PYSIDE/issues/PYSIDE-1026 + import inspect + + import PySide6.Qt3DExtras as __temp + + for __name in inspect.getmembers(__temp.Qt3DExtras): + globals()[__name[0]] = __name[1] diff --git a/python3.10libs/qtpy/Qt3DInput.py b/python3.10libs/qtpy/Qt3DInput.py new file mode 100644 index 0000000..1dbc3b4 --- /dev/null +++ b/python3.10libs/qtpy/Qt3DInput.py @@ -0,0 +1,49 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides Qt3DInput classes and functions.""" + +from . 
import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtModuleNotInstalledError, +) + +if PYQT5: + try: + from PyQt5.Qt3DInput import * + except ModuleNotFoundError as error: + raise QtModuleNotInstalledError( + name="Qt3DInput", + missing_package="PyQt3D", + ) from error +elif PYQT6: + try: + from PyQt6.Qt3DInput import * + except ModuleNotFoundError as error: + raise QtModuleNotInstalledError( + name="Qt3DInput", + missing_package="PyQt6-3D", + ) from error +elif PYSIDE2: + # https://bugreports.qt.io/projects/PYSIDE/issues/PYSIDE-1026 + import inspect + + import PySide2.Qt3DInput as __temp + + for __name in inspect.getmembers(__temp.Qt3DInput): + globals()[__name[0]] = __name[1] +elif PYSIDE6: + # https://bugreports.qt.io/projects/PYSIDE/issues/PYSIDE-1026 + import inspect + + import PySide6.Qt3DInput as __temp + + for __name in inspect.getmembers(__temp.Qt3DInput): + globals()[__name[0]] = __name[1] diff --git a/python3.10libs/qtpy/Qt3DLogic.py b/python3.10libs/qtpy/Qt3DLogic.py new file mode 100644 index 0000000..3fea3f5 --- /dev/null +++ b/python3.10libs/qtpy/Qt3DLogic.py @@ -0,0 +1,49 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides Qt3DLogic classes and functions.""" + +from . 
import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtModuleNotInstalledError, +) + +if PYQT5: + try: + from PyQt5.Qt3DLogic import * + except ModuleNotFoundError as error: + raise QtModuleNotInstalledError( + name="Qt3DLogic", + missing_package="PyQt3D", + ) from error +elif PYQT6: + try: + from PyQt6.Qt3DLogic import * + except ModuleNotFoundError as error: + raise QtModuleNotInstalledError( + name="Qt3DLogic", + missing_package="PyQt6-3D", + ) from error +elif PYSIDE2: + # https://bugreports.qt.io/projects/PYSIDE/issues/PYSIDE-1026 + import inspect + + import PySide2.Qt3DLogic as __temp + + for __name in inspect.getmembers(__temp.Qt3DLogic): + globals()[__name[0]] = __name[1] +elif PYSIDE6: + # https://bugreports.qt.io/projects/PYSIDE/issues/PYSIDE-1026 + import inspect + + import PySide6.Qt3DLogic as __temp + + for __name in inspect.getmembers(__temp.Qt3DLogic): + globals()[__name[0]] = __name[1] diff --git a/python3.10libs/qtpy/Qt3DRender.py b/python3.10libs/qtpy/Qt3DRender.py new file mode 100644 index 0000000..72f44f2 --- /dev/null +++ b/python3.10libs/qtpy/Qt3DRender.py @@ -0,0 +1,49 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides Qt3DRender classes and functions.""" + +from . 
import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtModuleNotInstalledError, +) + +if PYQT5: + try: + from PyQt5.Qt3DRender import * + except ModuleNotFoundError as error: + raise QtModuleNotInstalledError( + name="Qt3DRender", + missing_package="PyQt3D", + ) from error +elif PYQT6: + try: + from PyQt6.Qt3DRender import * + except ModuleNotFoundError as error: + raise QtModuleNotInstalledError( + name="Qt3DRender", + missing_package="PyQt6-3D", + ) from error +elif PYSIDE2: + # https://bugreports.qt.io/projects/PYSIDE/issues/PYSIDE-1026 + import inspect + + import PySide2.Qt3DRender as __temp + + for __name in inspect.getmembers(__temp.Qt3DRender): + globals()[__name[0]] = __name[1] +elif PYSIDE6: + # https://bugreports.qt.io/projects/PYSIDE/issues/PYSIDE-1026 + import inspect + + import PySide6.Qt3DRender as __temp + + for __name in inspect.getmembers(__temp.Qt3DRender): + globals()[__name[0]] = __name[1] diff --git a/python3.10libs/qtpy/QtAxContainer.py b/python3.10libs/qtpy/QtAxContainer.py new file mode 100644 index 0000000..fa2ba4e --- /dev/null +++ b/python3.10libs/qtpy/QtAxContainer.py @@ -0,0 +1,23 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtAxContainer classes and functions.""" + +from . 
"""Provides QtBluetooth classes and functions."""

from . import (
    PYQT5,
    PYQT6,
    PYSIDE2,
    PYSIDE6,
    QtBindingMissingModuleError,
)

# Re-export QtBluetooth from whichever binding is active; it is treated
# as unavailable under PySide2.
if PYQT5:
    from PyQt5.QtBluetooth import *
elif PYQT6:
    from PyQt6.QtBluetooth import *
elif PYSIDE2:
    raise QtBindingMissingModuleError(name="QtBluetooth")
elif PYSIDE6:
    from PySide6.QtBluetooth import *
"""Provides QtChart classes and functions."""

from . import (
    PYQT5,
    PYQT6,
    PYSIDE2,
    PYSIDE6,
    QtModuleNotInstalledError,
)

# The charts module is an optional add-on for PyQt; a missing import is
# reported with the pip package name to install.
if PYQT5:
    try:
        # Also expose the module under the Qt6-style name `QtCharts`.
        from PyQt5 import QtChart as QtCharts
        from PyQt5.QtChart import *
    except ModuleNotFoundError as error:
        raise QtModuleNotInstalledError(
            name="QtCharts",
            missing_package="PyQtChart",
        ) from error
elif PYQT6:
    try:
        from PyQt6 import QtCharts
        from PyQt6.QtCharts import *
    except ModuleNotFoundError as error:
        raise QtModuleNotInstalledError(
            name="QtCharts",
            missing_package="PyQt6-Charts",
        ) from error
elif PYSIDE2:
    import inspect

    # https://bugreports.qt.io/projects/PYSIDE/issues/PYSIDE-1026
    # Copy the members of the nested namespace to module level.
    import PySide2.QtCharts as __temp
    from PySide2.QtCharts import *

    for __name in inspect.getmembers(__temp.QtCharts):
        globals()[__name[0]] = __name[1]
elif PYSIDE6:
    from PySide6 import QtCharts
    from PySide6.QtCharts import *
import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtBindingMissingModuleError, +) + +if PYQT5 or PYQT6: + raise QtBindingMissingModuleError(name="QtConcurrent") +elif PYSIDE2: + from PySide2.QtConcurrent import * +elif PYSIDE6: + from PySide6.QtConcurrent import * diff --git a/python3.10libs/qtpy/QtCore.py b/python3.10libs/qtpy/QtCore.py new file mode 100644 index 0000000..09a42b1 --- /dev/null +++ b/python3.10libs/qtpy/QtCore.py @@ -0,0 +1,197 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2014-2015 Colin Duquesnoy +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtCore classes and functions.""" +import contextlib +from typing import TYPE_CHECKING + +from packaging.version import parse + +from . import PYQT5, PYQT6, PYSIDE2, PYSIDE6 +from . import QT_VERSION as _qt_version +from ._utils import possibly_static_exec, possibly_static_exec_ + +if PYQT5: + from PyQt5.QtCore import * + from PyQt5.QtCore import pyqtBoundSignal as SignalInstance + from PyQt5.QtCore import pyqtProperty as Property + from PyQt5.QtCore import pyqtSignal as Signal + from PyQt5.QtCore import pyqtSlot as Slot + + try: + from PyQt5.QtCore import Q_ENUM as QEnum + + del Q_ENUM + except ImportError: # fallback for Qt5.9 + from PyQt5.QtCore import Q_ENUMS as QEnum + + del Q_ENUMS + from PyQt5.QtCore import QT_VERSION_STR as __version__ + + # Those are imported from `import *` + del pyqtSignal, pyqtBoundSignal, pyqtSlot, pyqtProperty, QT_VERSION_STR + +elif PYQT6: + from PyQt6 import QtCore + from PyQt6.QtCore import * + from PyQt6.QtCore import QT_VERSION_STR as __version__ + from PyQt6.QtCore import pyqtBoundSignal as SignalInstance + from PyQt6.QtCore import pyqtEnum as QEnum + from PyQt6.QtCore import pyqtProperty as Property + from PyQt6.QtCore import pyqtSignal as 
Signal + from PyQt6.QtCore import pyqtSlot as Slot + + # For issue #311 + # Seems like there is an error with sip. Without first + # trying to import `PyQt6.QtGui.Qt`, some functions like + # `PyQt6.QtCore.Qt.mightBeRichText` are missing. + if not TYPE_CHECKING: + with contextlib.suppress(ImportError): + from PyQt6.QtGui import Qt + + # Map missing methods + QCoreApplication.exec_ = lambda *args, **kwargs: possibly_static_exec( + QCoreApplication, + *args, + **kwargs, + ) + QEventLoop.exec_ = lambda self, *args, **kwargs: self.exec(*args, **kwargs) + QThread.exec_ = lambda self, *args, **kwargs: self.exec(*args, **kwargs) + + # Those are imported from `import *` + del ( + pyqtSignal, + pyqtBoundSignal, + pyqtSlot, + pyqtProperty, + pyqtEnum, + QT_VERSION_STR, + ) + + # Allow unscoped access for enums inside the QtCore module + from .enums_compat import promote_enums + + promote_enums(QtCore) + del QtCore + + # Alias deprecated ItemDataRole enum values removed in Qt6 + Qt.BackgroundColorRole = ( + Qt.ItemDataRole.BackgroundColorRole + ) = Qt.BackgroundRole + Qt.TextColorRole = Qt.ItemDataRole.TextColorRole = Qt.ForegroundRole + + # Alias for MiddleButton removed in PyQt6 but available in PyQt5, PySide2 and PySide6 + Qt.MidButton = Qt.MiddleButton + + # Add removed definition for `Qt.ItemFlags` as an alias of `Qt.ItemFlag` + # passing as default value 0 in the same way PySide6 6.5+ does. 
+ # Note that for PyQt5 and PySide2 those definitions are two different classes + # (one is the flag definition and the other the enum definition) + Qt.ItemFlags = lambda value=0: Qt.ItemFlag(value) + +elif PYSIDE2: + import PySide2.QtCore + from PySide2.QtCore import * + + __version__ = PySide2.QtCore.__version__ + + # Missing QtGui utility functions on Qt + if getattr(Qt, "mightBeRichText", None) is None: + try: + from PySide2.QtGui import Qt as guiQt + + Qt.mightBeRichText = guiQt.mightBeRichText + del guiQt + except ImportError: + # Fails with PySide2 5.12.0 + pass + + QCoreApplication.exec = lambda *args, **kwargs: possibly_static_exec_( + QCoreApplication, + *args, + **kwargs, + ) + QEventLoop.exec = lambda self, *args, **kwargs: self.exec_(*args, **kwargs) + QThread.exec = lambda self, *args, **kwargs: self.exec_(*args, **kwargs) + QTextStreamManipulator.exec = lambda self, *args, **kwargs: self.exec_( + *args, + **kwargs, + ) + +elif PYSIDE6: + import PySide6.QtCore + from PySide6.QtCore import * + + __version__ = PySide6.QtCore.__version__ + + # Missing QtGui utility functions on Qt + if getattr(Qt, "mightBeRichText", None) is None: + from PySide6.QtGui import Qt as guiQt + + Qt.mightBeRichText = guiQt.mightBeRichText + del guiQt + + # Alias deprecated ItemDataRole enum values removed in Qt6 + Qt.BackgroundColorRole = ( + Qt.ItemDataRole.BackgroundColorRole + ) = Qt.BackgroundRole + Qt.TextColorRole = Qt.ItemDataRole.TextColorRole = Qt.ForegroundRole + Qt.MidButton = Qt.MiddleButton + + # Map DeprecationWarning methods + QCoreApplication.exec_ = lambda *args, **kwargs: possibly_static_exec( + QCoreApplication, + *args, + **kwargs, + ) + QEventLoop.exec_ = lambda self, *args, **kwargs: self.exec(*args, **kwargs) + QThread.exec_ = lambda self, *args, **kwargs: self.exec(*args, **kwargs) + QTextStreamManipulator.exec_ = lambda self, *args, **kwargs: self.exec( + *args, + **kwargs, + ) + + # Passing as default value 0 in the same way PySide6 6.3.2 does for the 
`Qt.ItemFlags` definition. + if parse(_qt_version) > parse("6.3"): + Qt.ItemFlags = lambda value=0: Qt.ItemFlag(value) + +# For issue #153 and updated for issue #305 +if PYQT5 or PYQT6: + QDate.toPython = lambda self, *args, **kwargs: self.toPyDate( + *args, + **kwargs, + ) + QDateTime.toPython = lambda self, *args, **kwargs: self.toPyDateTime( + *args, + **kwargs, + ) + QTime.toPython = lambda self, *args, **kwargs: self.toPyTime( + *args, + **kwargs, + ) +if PYSIDE2 or PYSIDE6: + QDate.toPyDate = lambda self, *args, **kwargs: self.toPython( + *args, + **kwargs, + ) + QDateTime.toPyDateTime = lambda self, *args, **kwargs: self.toPython( + *args, + **kwargs, + ) + QTime.toPyTime = lambda self, *args, **kwargs: self.toPython( + *args, + **kwargs, + ) + +# Mirror https://github.com/spyder-ide/qtpy/pull/393 +if PYQT5 or PYSIDE2: + QLibraryInfo.path = QLibraryInfo.location + QLibraryInfo.LibraryPath = QLibraryInfo.LibraryLocation +if PYQT6 or PYSIDE6: + QLibraryInfo.location = QLibraryInfo.path + QLibraryInfo.LibraryLocation = QLibraryInfo.LibraryPath diff --git a/python3.10libs/qtpy/QtDBus.py b/python3.10libs/qtpy/QtDBus.py new file mode 100644 index 0000000..1d41aab --- /dev/null +++ b/python3.10libs/qtpy/QtDBus.py @@ -0,0 +1,31 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtDBus classes and functions.""" + +import sys + +from . 
import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtBindingMissingModuleError, + QtModuleNotInOSError, +) + +if PYQT5: + from PyQt5.QtDBus import * +elif PYQT6: + from PyQt6.QtDBus import * +elif PYSIDE2: + raise QtBindingMissingModuleError(name="QtDBus") +elif PYSIDE6: + if sys.platform != "win32": + from PySide6.QtDBus import * + else: + raise QtModuleNotInOSError(name="QtDBus") diff --git a/python3.10libs/qtpy/QtDataVisualization.py b/python3.10libs/qtpy/QtDataVisualization.py new file mode 100644 index 0000000..0a4facf --- /dev/null +++ b/python3.10libs/qtpy/QtDataVisualization.py @@ -0,0 +1,43 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtDataVisualization classes and functions.""" + +from . import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtModuleNotInstalledError, +) + +if PYQT5: + try: + from PyQt5.QtDataVisualization import * + except ModuleNotFoundError as error: + raise QtModuleNotInstalledError( + name="QtDataVisualization", + missing_package="PyQtDataVisualization", + ) from error +elif PYQT6: + try: + from PyQt6.QtDataVisualization import * + except ModuleNotFoundError as error: + raise QtModuleNotInstalledError( + name="QtDataVisualization", + missing_package="PyQt6-DataVisualization", + ) from error +elif PYSIDE2: + # https://bugreports.qt.io/projects/PYSIDE/issues/PYSIDE-1026 + import inspect + + import PySide2.QtDataVisualization as __temp + + for __name in inspect.getmembers(__temp.QtDataVisualization): + globals()[__name[0]] = __name[1] +elif PYSIDE6: + from PySide6.QtDataVisualization import * diff --git a/python3.10libs/qtpy/QtDesigner.py b/python3.10libs/qtpy/QtDesigner.py new file mode 100644 index 0000000..acf61d5 --- /dev/null +++ b/python3.10libs/qtpy/QtDesigner.py @@ -0,0 
+1,25 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2014-2015 Colin Duquesnoy +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtDesigner classes and functions.""" + +from . import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtBindingMissingModuleError, +) + +if PYQT5: + from PyQt5.QtDesigner import * +elif PYQT6: + from PyQt6.QtDesigner import * +elif PYSIDE2: + raise QtBindingMissingModuleError(name="QtDesigner") +elif PYSIDE6: + from PySide6.QtDesigner import * diff --git a/python3.10libs/qtpy/QtGui.py b/python3.10libs/qtpy/QtGui.py new file mode 100644 index 0000000..8b3a0ba --- /dev/null +++ b/python3.10libs/qtpy/QtGui.py @@ -0,0 +1,254 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2014-2015 Colin Duquesnoy +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtGui classes and functions.""" + +from . 
import PYQT5, PYQT6, PYSIDE2, PYSIDE6, QtModuleNotInstalledError +from ._utils import getattr_missing_optional_dep, possibly_static_exec + +_missing_optional_names = {} + +_QTOPENGL_NAMES = { + "QOpenGLBuffer", + "QOpenGLContext", + "QOpenGLContextGroup", + "QOpenGLDebugLogger", + "QOpenGLDebugMessage", + "QOpenGLFramebufferObject", + "QOpenGLFramebufferObjectFormat", + "QOpenGLPixelTransferOptions", + "QOpenGLShader", + "QOpenGLShaderProgram", + "QOpenGLTexture", + "QOpenGLTextureBlitter", + "QOpenGLVersionProfile", + "QOpenGLVertexArrayObject", + "QOpenGLWindow", +} + + +def __getattr__(name): + """Custom getattr to chain and wrap errors due to missing optional deps.""" + raise getattr_missing_optional_dep( + name, + module_name=__name__, + optional_names=_missing_optional_names, + ) + + +if PYQT5: + from PyQt5.QtGui import * + + # Backport items moved to QtGui in Qt6 + from PyQt5.QtWidgets import ( + QAction, + QActionGroup, + QFileSystemModel, + QShortcut, + QUndoCommand, + ) + +elif PYQT6: + from PyQt6 import QtGui + from PyQt6.QtGui import * + + # Attempt to import QOpenGL* classes, but if that fails, + # don't raise an exception until the name is explicitly accessed. 
+ # See https://github.com/spyder-ide/qtpy/pull/387/ + try: + from PyQt6.QtOpenGL import * + except ImportError as error: + for name in _QTOPENGL_NAMES: + _missing_optional_names[name] = { + "name": "PyQt6.QtOpenGL", + "missing_package": "pyopengl", + "import_error": error, + } + + QFontMetrics.width = lambda self, *args, **kwargs: self.horizontalAdvance( + *args, + **kwargs, + ) + QFontMetricsF.width = lambda self, *args, **kwargs: self.horizontalAdvance( + *args, + **kwargs, + ) + + # Map missing/renamed methods + QDrag.exec_ = lambda self, *args, **kwargs: self.exec(*args, **kwargs) + QGuiApplication.exec_ = lambda *args, **kwargs: possibly_static_exec( + QGuiApplication, + *args, + **kwargs, + ) + QTextDocument.print_ = lambda self, *args, **kwargs: self.print( + *args, + **kwargs, + ) + + # Allow unscoped access for enums inside the QtGui module + from .enums_compat import promote_enums + + promote_enums(QtGui) + del QtGui +elif PYSIDE2: + from PySide2.QtGui import * + + # Backport items moved to QtGui in Qt6 + from PySide2.QtWidgets import ( + QAction, + QActionGroup, + QFileSystemModel, + QShortcut, + QUndoCommand, + ) + + if hasattr(QFontMetrics, "horizontalAdvance"): + # Needed to prevent raising a DeprecationWarning when using QFontMetrics.width + QFontMetrics.width = ( + lambda self, *args, **kwargs: self.horizontalAdvance( + *args, + **kwargs, + ) + ) +elif PYSIDE6: + from PySide6.QtGui import * + + # Attempt to import QOpenGL* classes, but if that fails, + # don't raise an exception until the name is explicitly accessed. 
+ # See https://github.com/spyder-ide/qtpy/pull/387/ + try: + from PySide6.QtOpenGL import * + except ImportError as error: + for name in _QTOPENGL_NAMES: + _missing_optional_names[name] = { + "name": "PySide6.QtOpenGL", + "missing_package": "pyopengl", + "import_error": error, + } + + # Backport `QFileSystemModel` moved to QtGui in Qt6 + from PySide6.QtWidgets import QFileSystemModel + + QFontMetrics.width = lambda self, *args, **kwargs: self.horizontalAdvance( + *args, + **kwargs, + ) + QFontMetricsF.width = lambda self, *args, **kwargs: self.horizontalAdvance( + *args, + **kwargs, + ) + + # Map DeprecationWarning methods + QDrag.exec_ = lambda self, *args, **kwargs: self.exec(*args, **kwargs) + QGuiApplication.exec_ = lambda *args, **kwargs: possibly_static_exec( + QGuiApplication, + *args, + **kwargs, + ) + +if PYSIDE2 or PYSIDE6: + # PySide{2,6} do not accept the `mode` keyword argument in + # QTextCursor.movePosition() even though it is a valid optional argument + # as per C++ API. Fix this by monkeypatching. + # + # Notes: + # + # * The `mode` argument is called `arg__2` in PySide{2,6} as per + # QTextCursor.movePosition.__doc__ and __signature__. Using `arg__2` as + # keyword argument works as intended, so does using a positional + # argument. Tested with PySide2 5.15.0, 5.15.2.1 and 5.15.3 and PySide6 + # 6.3.0; older version, down to PySide 1, are probably affected as well [1]. + # + # * PySide2 5.15.0 and 5.15.2.1 silently ignore invalid keyword arguments, + # i.e. passing the `mode` keyword argument has no effect and doesn`t + # raise an exception. Older versions, down to PySide 1, are probably + # affected as well [1]. At least PySide2 5.15.3 and PySide6 6.3.0 raise an + # exception when `mode` or any other invalid keyword argument is passed. 
+ # + # [1] https://bugreports.qt.io/browse/PYSIDE-185 + movePosition = QTextCursor.movePosition + + def movePositionPatched( + self, + operation: QTextCursor.MoveOperation, + mode: QTextCursor.MoveMode = QTextCursor.MoveAnchor, + n: int = 1, + ) -> bool: + return movePosition(self, operation, mode, n) + + QTextCursor.movePosition = movePositionPatched + +if PYQT5 or PYSIDE2: + # Part of the fix for https://github.com/spyder-ide/qtpy/issues/394 + from qtpy.QtCore import QPointF as __QPointF + + QNativeGestureEvent.x = lambda self: self.localPos().toPoint().x() + QNativeGestureEvent.y = lambda self: self.localPos().toPoint().y() + QNativeGestureEvent.position = lambda self: self.localPos() + QNativeGestureEvent.globalX = lambda self: self.globalPos().x() + QNativeGestureEvent.globalY = lambda self: self.globalPos().y() + QNativeGestureEvent.globalPosition = lambda self: __QPointF( + float(self.globalPos().x()), + float(self.globalPos().y()), + ) + QEnterEvent.position = lambda self: self.localPos() + QEnterEvent.globalPosition = lambda self: __QPointF( + float(self.globalX()), + float(self.globalY()), + ) + QTabletEvent.position = lambda self: self.posF() + QTabletEvent.globalPosition = lambda self: self.globalPosF() + QHoverEvent.x = lambda self: self.pos().x() + QHoverEvent.y = lambda self: self.pos().y() + QHoverEvent.position = lambda self: self.posF() + # No `QHoverEvent.globalPosition`, `QHoverEvent.globalX`, + # nor `QHoverEvent.globalY` in the Qt5 docs. 
+ QMouseEvent.position = lambda self: self.localPos() + QMouseEvent.globalPosition = lambda self: __QPointF( + float(self.globalX()), + float(self.globalY()), + ) + + # Follow similar approach for `QDropEvent` and child classes + QDropEvent.position = lambda self: self.posF() +if PYQT6 or PYSIDE6: + # Part of the fix for https://github.com/spyder-ide/qtpy/issues/394 + for _class in ( + QNativeGestureEvent, + QEnterEvent, + QTabletEvent, + QHoverEvent, + QMouseEvent, + ): + for _obsolete_function in ( + "pos", + "x", + "y", + "globalPos", + "globalX", + "globalY", + ): + if hasattr(_class, _obsolete_function): + delattr(_class, _obsolete_function) + QSinglePointEvent.pos = lambda self: self.position().toPoint() + QSinglePointEvent.posF = lambda self: self.position() + QSinglePointEvent.localPos = lambda self: self.position() + QSinglePointEvent.x = lambda self: self.position().toPoint().x() + QSinglePointEvent.y = lambda self: self.position().toPoint().y() + QSinglePointEvent.globalPos = lambda self: self.globalPosition().toPoint() + QSinglePointEvent.globalX = ( + lambda self: self.globalPosition().toPoint().x() + ) + QSinglePointEvent.globalY = ( + lambda self: self.globalPosition().toPoint().y() + ) + + # Follow similar approach for `QDropEvent` and child classes + QDropEvent.pos = lambda self: self.position().toPoint() + QDropEvent.posF = lambda self: self.position() diff --git a/python3.10libs/qtpy/QtHelp.py b/python3.10libs/qtpy/QtHelp.py new file mode 100644 index 0000000..d0e2bab --- /dev/null +++ b/python3.10libs/qtpy/QtHelp.py @@ -0,0 +1,19 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""QtHelp Wrapper.""" + +from . 
import PYQT5, PYQT6, PYSIDE2, PYSIDE6 + +if PYQT5: + from PyQt5.QtHelp import * +elif PYQT6: + from PyQt6.QtHelp import * +elif PYSIDE2: + from PySide2.QtHelp import * +elif PYSIDE6: + from PySide6.QtHelp import * diff --git a/python3.10libs/qtpy/QtLocation.py b/python3.10libs/qtpy/QtLocation.py new file mode 100644 index 0000000..160c947 --- /dev/null +++ b/python3.10libs/qtpy/QtLocation.py @@ -0,0 +1,25 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtLocation classes and functions.""" + +from . import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtBindingMissingModuleError, +) + +if PYQT5: + from PyQt5.QtLocation import * +elif PYQT6: + raise QtBindingMissingModuleError(name="QtLocation") +elif PYSIDE2: + from PySide2.QtLocation import * +elif PYSIDE6: + raise QtBindingMissingModuleError(name="QtLocation") diff --git a/python3.10libs/qtpy/QtMacExtras.py b/python3.10libs/qtpy/QtMacExtras.py new file mode 100644 index 0000000..58d23b8 --- /dev/null +++ b/python3.10libs/qtpy/QtMacExtras.py @@ -0,0 +1,31 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides classes and functions specific to macOS and iOS operating systems""" + +import sys + +from . 
import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtModuleNotInOSError, + QtModuleNotInQtVersionError, +) + +if sys.platform == "darwin": + if PYQT5: + from PyQt5.QtMacExtras import * + elif PYQT6: + raise QtModuleNotInQtVersionError(name="QtMacExtras") + elif PYSIDE2: + from PySide2.QtMacExtras import * + elif PYSIDE6: + raise QtModuleNotInQtVersionError(name="QtMacExtras") +else: + raise QtModuleNotInOSError(name="QtMacExtras") diff --git a/python3.10libs/qtpy/QtMultimedia.py b/python3.10libs/qtpy/QtMultimedia.py new file mode 100644 index 0000000..7403e64 --- /dev/null +++ b/python3.10libs/qtpy/QtMultimedia.py @@ -0,0 +1,19 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides low-level multimedia functionality.""" + +from . import PYQT5, PYQT6, PYSIDE2, PYSIDE6 + +if PYQT5: + from PyQt5.QtMultimedia import * +elif PYQT6: + from PyQt6.QtMultimedia import * +elif PYSIDE2: + from PySide2.QtMultimedia import * +elif PYSIDE6: + from PySide6.QtMultimedia import * diff --git a/python3.10libs/qtpy/QtMultimediaWidgets.py b/python3.10libs/qtpy/QtMultimediaWidgets.py new file mode 100644 index 0000000..69af111 --- /dev/null +++ b/python3.10libs/qtpy/QtMultimediaWidgets.py @@ -0,0 +1,19 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtMultimediaWidgets classes and functions.""" + +from . 
import PYQT5, PYQT6, PYSIDE2, PYSIDE6 + +if PYQT5: + from PyQt5.QtMultimediaWidgets import * +elif PYQT6: + from PyQt6.QtMultimediaWidgets import * +elif PYSIDE2: + from PySide2.QtMultimediaWidgets import * +elif PYSIDE6: + from PySide6.QtMultimediaWidgets import * diff --git a/python3.10libs/qtpy/QtNetwork.py b/python3.10libs/qtpy/QtNetwork.py new file mode 100644 index 0000000..2c4e547 --- /dev/null +++ b/python3.10libs/qtpy/QtNetwork.py @@ -0,0 +1,20 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2014-2015 Colin Duquesnoy +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtNetwork classes and functions.""" + +from . import PYQT5, PYQT6, PYSIDE2, PYSIDE6 + +if PYQT5: + from PyQt5.QtNetwork import * +elif PYQT6: + from PyQt6.QtNetwork import * +elif PYSIDE2: + from PySide2.QtNetwork import * +elif PYSIDE6: + from PySide6.QtNetwork import * diff --git a/python3.10libs/qtpy/QtNetworkAuth.py b/python3.10libs/qtpy/QtNetworkAuth.py new file mode 100644 index 0000000..f49adce --- /dev/null +++ b/python3.10libs/qtpy/QtNetworkAuth.py @@ -0,0 +1,38 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtNetworkAuth classes and functions.""" + +from . 
import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtBindingMissingModuleError, + QtModuleNotInstalledError, +) + +if PYQT5: + try: + from PyQt5.QtNetworkAuth import * + except ModuleNotFoundError as error: + raise QtModuleNotInstalledError( + name="QtNetworkAuth", + missing_package="PyQtNetworkAuth", + ) from error +elif PYQT6: + try: + from PyQt6.QtNetworkAuth import * + except ModuleNotFoundError as error: + raise QtModuleNotInstalledError( + name="QtNetworkAuth", + missing_package="PyQt6-NetworkAuth", + ) from error +elif PYSIDE2: + raise QtBindingMissingModuleError(name="QtNetworkAuth") +elif PYSIDE6: + from PySide6.QtNetworkAuth import * diff --git a/python3.10libs/qtpy/QtNfc.py b/python3.10libs/qtpy/QtNfc.py new file mode 100644 index 0000000..8dafd7c --- /dev/null +++ b/python3.10libs/qtpy/QtNfc.py @@ -0,0 +1,25 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtNfc classes and functions.""" + +from . 
import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtBindingMissingModuleError, +) + +if PYQT5: + from PyQt5.QtNfc import * +elif PYQT6: + from PyQt6.QtNfc import * +elif PYSIDE2: + raise QtBindingMissingModuleError(name="QtNfc") +elif PYSIDE6: + from PySide6.QtNfc import * diff --git a/python3.10libs/qtpy/QtOpenGL.py b/python3.10libs/qtpy/QtOpenGL.py new file mode 100644 index 0000000..af881ed --- /dev/null +++ b/python3.10libs/qtpy/QtOpenGL.py @@ -0,0 +1,66 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtOpenGL classes and functions.""" + +import contextlib + +from . import PYQT5, PYQT6, PYSIDE2, PYSIDE6 + +if PYQT5: + from PyQt5.QtGui import ( + QOpenGLBuffer, + QOpenGLContext, + QOpenGLContextGroup, + QOpenGLDebugLogger, + QOpenGLDebugMessage, + QOpenGLFramebufferObject, + QOpenGLFramebufferObjectFormat, + QOpenGLPixelTransferOptions, + QOpenGLShader, + QOpenGLShaderProgram, + QOpenGLTexture, + QOpenGLTextureBlitter, + QOpenGLVersionProfile, + QOpenGLVertexArrayObject, + QOpenGLWindow, + ) + from PyQt5.QtOpenGL import * + + # These are not present on some architectures such as armhf + with contextlib.suppress(ImportError): + from PyQt5.QtGui import QOpenGLTimeMonitor, QOpenGLTimerQuery + +elif PYQT6: + from PyQt6.QtGui import QOpenGLContext, QOpenGLContextGroup + from PyQt6.QtOpenGL import * +elif PYSIDE6: + from PySide6.QtGui import QOpenGLContext, QOpenGLContextGroup + from PySide6.QtOpenGL import * +elif PYSIDE2: + from PySide2.QtGui import ( + QOpenGLBuffer, + QOpenGLContext, + QOpenGLContextGroup, + QOpenGLDebugLogger, + QOpenGLDebugMessage, + QOpenGLFramebufferObject, + QOpenGLFramebufferObjectFormat, + QOpenGLPixelTransferOptions, + QOpenGLShader, + QOpenGLShaderProgram, + QOpenGLTexture, + 
QOpenGLTextureBlitter, + QOpenGLVersionProfile, + QOpenGLVertexArrayObject, + QOpenGLWindow, + ) + from PySide2.QtOpenGL import * + + # These are not present on some architectures such as armhf + with contextlib.suppress(ImportError): + from PySide2.QtGui import QOpenGLTimeMonitor, QOpenGLTimerQuery diff --git a/python3.10libs/qtpy/QtOpenGLWidgets.py b/python3.10libs/qtpy/QtOpenGLWidgets.py new file mode 100644 index 0000000..c000e0f --- /dev/null +++ b/python3.10libs/qtpy/QtOpenGLWidgets.py @@ -0,0 +1,25 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtOpenGLWidgets classes and functions.""" + +from . import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtBindingMissingModuleError, +) + +if PYQT5: + raise QtBindingMissingModuleError(name="QtOpenGLWidgets") +elif PYQT6: + from PyQt6.QtOpenGLWidgets import * +elif PYSIDE2: + raise QtBindingMissingModuleError(name="QtOpenGLWidgets") +elif PYSIDE6: + from PySide6.QtOpenGLWidgets import * diff --git a/python3.10libs/qtpy/QtPdf.py b/python3.10libs/qtpy/QtPdf.py new file mode 100644 index 0000000..f98cbd0 --- /dev/null +++ b/python3.10libs/qtpy/QtPdf.py @@ -0,0 +1,27 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtPdf classes and functions.""" + +from . 
import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtBindingMissingModuleError, +) + +if PYQT5: + raise QtBindingMissingModuleError(name="QtPdf") +elif PYQT6: + # Available with version >=6.4.0 + from PyQt6.QtPdf import * +elif PYSIDE2: + raise QtBindingMissingModuleError(name="QtPdf") +elif PYSIDE6: + # Available with version >=6.4.0 + from PySide6.QtPdf import * diff --git a/python3.10libs/qtpy/QtPdfWidgets.py b/python3.10libs/qtpy/QtPdfWidgets.py new file mode 100644 index 0000000..a437db6 --- /dev/null +++ b/python3.10libs/qtpy/QtPdfWidgets.py @@ -0,0 +1,27 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtPdfWidgets classes and functions.""" + +from . import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtBindingMissingModuleError, +) + +if PYQT5: + raise QtBindingMissingModuleError(name="QtPdfWidgets") +elif PYQT6: + # Available with version >=6.4.0 + from PyQt6.QtPdfWidgets import * +elif PYSIDE2: + raise QtBindingMissingModuleError(name="QtPdfWidgets") +elif PYSIDE6: + # Available with version >=6.4.0 + from PySide6.QtPdfWidgets import * diff --git a/python3.10libs/qtpy/QtPositioning.py b/python3.10libs/qtpy/QtPositioning.py new file mode 100644 index 0000000..5083cee --- /dev/null +++ b/python3.10libs/qtpy/QtPositioning.py @@ -0,0 +1,19 @@ +# ----------------------------------------------------------------------------- +# Copyright 2020 Antonio Valentino +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtPositioning classes and functions.""" + +from . 
import PYQT5, PYQT6, PYSIDE2, PYSIDE6 + +if PYQT5: + from PyQt5.QtPositioning import * +elif PYQT6: + from PyQt6.QtPositioning import * +elif PYSIDE2: + from PySide2.QtPositioning import * +elif PYSIDE6: + from PySide6.QtPositioning import * diff --git a/python3.10libs/qtpy/QtPrintSupport.py b/python3.10libs/qtpy/QtPrintSupport.py new file mode 100644 index 0000000..d78ff4c --- /dev/null +++ b/python3.10libs/qtpy/QtPrintSupport.py @@ -0,0 +1,42 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtPrintSupport classes and functions.""" + +from . import PYQT5, PYQT6, PYSIDE2, PYSIDE6 + +if PYQT5: + from PyQt5.QtPrintSupport import * +elif PYQT6: + from PyQt6.QtPrintSupport import * + + QPageSetupDialog.exec_ = lambda self, *args, **kwargs: self.exec( + *args, + **kwargs, + ) + QPrintDialog.exec_ = lambda self, *args, **kwargs: self.exec( + *args, + **kwargs, + ) + QPrintPreviewWidget.print_ = lambda self, *args, **kwargs: self.print( + *args, + **kwargs, + ) +elif PYSIDE6: + from PySide6.QtPrintSupport import * + + # Map DeprecationWarning methods + QPageSetupDialog.exec_ = lambda self, *args, **kwargs: self.exec( + *args, + **kwargs, + ) + QPrintDialog.exec_ = lambda self, *args, **kwargs: self.exec( + *args, + **kwargs, + ) +elif PYSIDE2: + from PySide2.QtPrintSupport import * diff --git a/python3.10libs/qtpy/QtPurchasing.py b/python3.10libs/qtpy/QtPurchasing.py new file mode 100644 index 0000000..fd69448 --- /dev/null +++ b/python3.10libs/qtpy/QtPurchasing.py @@ -0,0 +1,28 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# 
----------------------------------------------------------------------------- + +"""Provides QtPurchasing classes and functions.""" + +from . import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtBindingMissingModuleError, + QtModuleNotInstalledError, +) + +if PYQT5: + try: + from PyQt5.QtPurchasing import * + except ModuleNotFoundError as error: + raise QtModuleNotInstalledError( + name="QtPurchasing", + missing_package="PyQtPurchasing", + ) from error +elif PYQT6 or PYSIDE2 or PYSIDE6: + raise QtBindingMissingModuleError(name="QtPurchasing") diff --git a/python3.10libs/qtpy/QtQml.py b/python3.10libs/qtpy/QtQml.py new file mode 100644 index 0000000..9d07f0e --- /dev/null +++ b/python3.10libs/qtpy/QtQml.py @@ -0,0 +1,19 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtQml classes and functions.""" + +from . import PYQT5, PYQT6, PYSIDE2, PYSIDE6 + +if PYQT5: + from PyQt5.QtQml import * +elif PYQT6: + from PyQt6.QtQml import * +elif PYSIDE6: + from PySide6.QtQml import * +elif PYSIDE2: + from PySide2.QtQml import * diff --git a/python3.10libs/qtpy/QtQuick.py b/python3.10libs/qtpy/QtQuick.py new file mode 100644 index 0000000..1c2b7da --- /dev/null +++ b/python3.10libs/qtpy/QtQuick.py @@ -0,0 +1,19 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtQuick classes and functions.""" + +from . 
import PYQT5, PYQT6, PYSIDE2, PYSIDE6 + +if PYQT5: + from PyQt5.QtQuick import * +elif PYQT6: + from PyQt6.QtQuick import * +elif PYSIDE6: + from PySide6.QtQuick import * +elif PYSIDE2: + from PySide2.QtQuick import * diff --git a/python3.10libs/qtpy/QtQuick3D.py b/python3.10libs/qtpy/QtQuick3D.py new file mode 100644 index 0000000..a8138f9 --- /dev/null +++ b/python3.10libs/qtpy/QtQuick3D.py @@ -0,0 +1,25 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtQuick3D classes and functions.""" + +from . import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtBindingMissingModuleError, +) + +if PYQT5: + from PyQt5.QtQuick3D import * +elif PYQT6: + from PyQt6.QtQuick3D import * +elif PYSIDE2: + raise QtBindingMissingModuleError(name="QtQuick3D") +elif PYSIDE6: + from PySide6.QtQuick3D import * diff --git a/python3.10libs/qtpy/QtQuickControls2.py b/python3.10libs/qtpy/QtQuickControls2.py new file mode 100644 index 0000000..634d544 --- /dev/null +++ b/python3.10libs/qtpy/QtQuickControls2.py @@ -0,0 +1,23 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtQuickControls2 classes and functions.""" + +from . 
import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtBindingMissingModuleError, +) + +if PYQT5 or PYQT6: + raise QtBindingMissingModuleError(name="QtQuickControls2") +elif PYSIDE2: + from PySide2.QtQuickControls2 import * +elif PYSIDE6: + from PySide6.QtQuickControls2 import * diff --git a/python3.10libs/qtpy/QtQuickWidgets.py b/python3.10libs/qtpy/QtQuickWidgets.py new file mode 100644 index 0000000..136b411 --- /dev/null +++ b/python3.10libs/qtpy/QtQuickWidgets.py @@ -0,0 +1,19 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtQuickWidgets classes and functions.""" + +from . import PYQT5, PYQT6, PYSIDE2, PYSIDE6 + +if PYQT5: + from PyQt5.QtQuickWidgets import * +elif PYQT6: + from PyQt6.QtQuickWidgets import * +elif PYSIDE6: + from PySide6.QtQuickWidgets import * +elif PYSIDE2: + from PySide2.QtQuickWidgets import * diff --git a/python3.10libs/qtpy/QtRemoteObjects.py b/python3.10libs/qtpy/QtRemoteObjects.py new file mode 100644 index 0000000..1035586 --- /dev/null +++ b/python3.10libs/qtpy/QtRemoteObjects.py @@ -0,0 +1,19 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtRemoteObjects classes and functions.""" + +from . 
import PYQT5, PYQT6, PYSIDE2, PYSIDE6 + +if PYQT5: + from PyQt5.QtRemoteObjects import * +elif PYQT6: + from PyQt6.QtRemoteObjects import * +elif PYSIDE6: + from PySide6.QtRemoteObjects import * +elif PYSIDE2: + from PySide2.QtRemoteObjects import * diff --git a/python3.10libs/qtpy/QtScxml.py b/python3.10libs/qtpy/QtScxml.py new file mode 100644 index 0000000..40da5ef --- /dev/null +++ b/python3.10libs/qtpy/QtScxml.py @@ -0,0 +1,23 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtScxml classes and functions.""" + +from . import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtBindingMissingModuleError, +) + +if PYQT5 or PYQT6: + raise QtBindingMissingModuleError(name="QtScxml") +elif PYSIDE2: + from PySide2.QtScxml import * +elif PYSIDE6: + from PySide6.QtScxml import * diff --git a/python3.10libs/qtpy/QtSensors.py b/python3.10libs/qtpy/QtSensors.py new file mode 100644 index 0000000..25b2919 --- /dev/null +++ b/python3.10libs/qtpy/QtSensors.py @@ -0,0 +1,19 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtSensors classes and functions.""" + +from . 
import PYQT5, PYQT6, PYSIDE2, PYSIDE6 + +if PYQT5: + from PyQt5.QtSensors import * +elif PYQT6: + from PyQt6.QtSensors import * +elif PYSIDE6: + from PySide6.QtSensors import * +elif PYSIDE2: + from PySide2.QtSensors import * diff --git a/python3.10libs/qtpy/QtSerialPort.py b/python3.10libs/qtpy/QtSerialPort.py new file mode 100644 index 0000000..878c35b --- /dev/null +++ b/python3.10libs/qtpy/QtSerialPort.py @@ -0,0 +1,20 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2020 Marcin Stano +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtSerialPort classes and functions.""" + +from . import PYQT5, PYQT6, PYSIDE2, PYSIDE6 + +if PYQT5: + from PyQt5.QtSerialPort import * +elif PYQT6: + from PyQt6.QtSerialPort import * +elif PYSIDE6: + from PySide6.QtSerialPort import * +elif PYSIDE2: + from PySide2.QtSerialPort import * diff --git a/python3.10libs/qtpy/QtSql.py b/python3.10libs/qtpy/QtSql.py new file mode 100644 index 0000000..76a6376 --- /dev/null +++ b/python3.10libs/qtpy/QtSql.py @@ -0,0 +1,34 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtSql classes and functions.""" + +from . 
import PYQT5, PYQT6, PYSIDE2, PYSIDE6 + +if PYQT5: + from PyQt5.QtSql import * +elif PYQT6: + from PyQt6.QtSql import * + + QSqlDatabase.exec_ = lambda self, *args, **kwargs: self.exec( + *args, + **kwargs, + ) + QSqlQuery.exec_ = lambda self, *args, **kwargs: self.exec(*args, **kwargs) + QSqlResult.exec_ = lambda self, *args, **kwargs: self.exec(*args, **kwargs) +elif PYSIDE6: + from PySide6.QtSql import * + + # Map DeprecationWarning methods + QSqlDatabase.exec_ = lambda self, *args, **kwargs: self.exec( + *args, + **kwargs, + ) + QSqlQuery.exec_ = lambda self, *args, **kwargs: self.exec(*args, **kwargs) + QSqlResult.exec_ = lambda self, *args, **kwargs: self.exec(*args, **kwargs) +elif PYSIDE2: + from PySide2.QtSql import * diff --git a/python3.10libs/qtpy/QtStateMachine.py b/python3.10libs/qtpy/QtStateMachine.py new file mode 100644 index 0000000..343ce4a --- /dev/null +++ b/python3.10libs/qtpy/QtStateMachine.py @@ -0,0 +1,21 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtStateMachine classes and functions.""" + +from . 
import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtBindingMissingModuleError, +) + +if PYQT5 or PYQT6 or PYSIDE2: + raise QtBindingMissingModuleError(name="QtStateMachine") +elif PYSIDE6: + from PySide6.QtStateMachine import * diff --git a/python3.10libs/qtpy/QtSvg.py b/python3.10libs/qtpy/QtSvg.py new file mode 100644 index 0000000..0ee4f9e --- /dev/null +++ b/python3.10libs/qtpy/QtSvg.py @@ -0,0 +1,19 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtSvg classes and functions.""" + +from . import PYQT5, PYQT6, PYSIDE2, PYSIDE6 + +if PYQT5: + from PyQt5.QtSvg import * +elif PYQT6: + from PyQt6.QtSvg import * +elif PYSIDE2: + from PySide2.QtSvg import * +elif PYSIDE6: + from PySide6.QtSvg import * diff --git a/python3.10libs/qtpy/QtSvgWidgets.py b/python3.10libs/qtpy/QtSvgWidgets.py new file mode 100644 index 0000000..7e91dfc --- /dev/null +++ b/python3.10libs/qtpy/QtSvgWidgets.py @@ -0,0 +1,25 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtSvgWidgets classes and functions.""" + +from . 
import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtBindingMissingModuleError, +) + +if PYQT5: + raise QtBindingMissingModuleError(name="QtSvgWidgets") +elif PYQT6: + from PyQt6.QtSvgWidgets import * +elif PYSIDE2: + raise QtBindingMissingModuleError(name="QtSvgWidgets") +elif PYSIDE6: + from PySide6.QtSvgWidgets import * diff --git a/python3.10libs/qtpy/QtTest.py b/python3.10libs/qtpy/QtTest.py new file mode 100644 index 0000000..b14418f --- /dev/null +++ b/python3.10libs/qtpy/QtTest.py @@ -0,0 +1,27 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2014-2015 Colin Duquesnoy +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtTest and functions""" + +from . import PYQT5, PYQT6, PYSIDE2, PYSIDE6 + +if PYQT5: + from PyQt5.QtTest import * +elif PYQT6: + from PyQt6 import QtTest + from PyQt6.QtTest import * + + # Allow unscoped access for enums inside the QtTest module + from .enums_compat import promote_enums + + promote_enums(QtTest) + del QtTest +elif PYSIDE2: + from PySide2.QtTest import * +elif PYSIDE6: + from PySide6.QtTest import * diff --git a/python3.10libs/qtpy/QtTextToSpeech.py b/python3.10libs/qtpy/QtTextToSpeech.py new file mode 100644 index 0000000..c37192a --- /dev/null +++ b/python3.10libs/qtpy/QtTextToSpeech.py @@ -0,0 +1,25 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtTextToSpeech classes and functions.""" + +from . 
import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtBindingMissingModuleError, +) + +if PYQT5: + from PyQt5.QtTextToSpeech import * +elif PYQT6: + raise QtBindingMissingModuleError(name="QtTextToSpeech") +elif PYSIDE2: + from PySide2.QtTextToSpeech import * +elif PYSIDE6: + raise QtBindingMissingModuleError(name="QtTextToSpeech") diff --git a/python3.10libs/qtpy/QtUiTools.py b/python3.10libs/qtpy/QtUiTools.py new file mode 100644 index 0000000..ceca1ef --- /dev/null +++ b/python3.10libs/qtpy/QtUiTools.py @@ -0,0 +1,23 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtUiTools classes and functions.""" + +from . import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtBindingMissingModuleError, +) + +if PYQT5 or PYQT6: + raise QtBindingMissingModuleError(name="QtUiTools") +elif PYSIDE2: + from PySide2.QtUiTools import * +elif PYSIDE6: + from PySide6.QtUiTools import * diff --git a/python3.10libs/qtpy/QtWebChannel.py b/python3.10libs/qtpy/QtWebChannel.py new file mode 100644 index 0000000..b2c35ff --- /dev/null +++ b/python3.10libs/qtpy/QtWebChannel.py @@ -0,0 +1,19 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtWebChannel classes and functions.""" + +from . 
import PYQT5, PYQT6, PYSIDE2, PYSIDE6 + +if PYQT5: + from PyQt5.QtWebChannel import * +elif PYQT6: + from PyQt6.QtWebChannel import * +elif PYSIDE2: + from PySide2.QtWebChannel import * +elif PYSIDE6: + from PySide6.QtWebChannel import * diff --git a/python3.10libs/qtpy/QtWebEngine.py b/python3.10libs/qtpy/QtWebEngine.py new file mode 100644 index 0000000..7cc08ff --- /dev/null +++ b/python3.10libs/qtpy/QtWebEngine.py @@ -0,0 +1,33 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2014-2015 Colin Duquesnoy +# Copyright © 2009- The Spyder development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtWebEngine classes and functions.""" + +from . import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtModuleNotInQtVersionError, + QtModuleNotInstalledError, +) + +if PYQT5: + try: + from PyQt5.QtWebEngine import * + except ModuleNotFoundError as error: + raise QtModuleNotInstalledError( + name="QtWebEngine", + missing_package="PyQtWebEngine", + ) from error +elif PYQT6: + raise QtModuleNotInQtVersionError(name="QtWebEngine") +elif PYSIDE2: + from PySide2.QtWebEngine import * +elif PYSIDE6: + raise QtModuleNotInQtVersionError(name="QtWebEngine") diff --git a/python3.10libs/qtpy/QtWebEngineCore.py b/python3.10libs/qtpy/QtWebEngineCore.py new file mode 100644 index 0000000..69aa4ee --- /dev/null +++ b/python3.10libs/qtpy/QtWebEngineCore.py @@ -0,0 +1,37 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtWebEngineCore classes and functions.""" + +from . 
import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtModuleNotInstalledError, +) + +if PYQT5: + try: + from PyQt5.QtWebEngineCore import * + except ModuleNotFoundError as error: + raise QtModuleNotInstalledError( + name="QtWebEngineCore", + missing_package="PyQtWebEngine", + ) from error +elif PYQT6: + try: + from PyQt6.QtWebEngineCore import * + except ModuleNotFoundError as error: + raise QtModuleNotInstalledError( + name="QtWebEngineCore", + missing_package="PyQt6-WebEngine", + ) from error +elif PYSIDE2: + from PySide2.QtWebEngineCore import * +elif PYSIDE6: + from PySide6.QtWebEngineCore import * diff --git a/python3.10libs/qtpy/QtWebEngineQuick.py b/python3.10libs/qtpy/QtWebEngineQuick.py new file mode 100644 index 0000000..717ac94 --- /dev/null +++ b/python3.10libs/qtpy/QtWebEngineQuick.py @@ -0,0 +1,32 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtWebEngineQuick classes and functions.""" + +from . 
import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtBindingMissingModuleError, + QtModuleNotInstalledError, +) + +if PYQT5: + raise QtBindingMissingModuleError(name="QtWebEngineQuick") +elif PYQT6: + try: + from PyQt6.QtWebEngineQuick import * + except ModuleNotFoundError as error: + raise QtModuleNotInstalledError( + name="QtWebEngineQuick", + missing_package="PyQt6-WebEngine", + ) from error +elif PYSIDE2: + raise QtBindingMissingModuleError(name="QtWebEngineQuick") +elif PYSIDE6: + from PySide6.QtWebEngineQuick import * diff --git a/python3.10libs/qtpy/QtWebEngineWidgets.py b/python3.10libs/qtpy/QtWebEngineWidgets.py new file mode 100644 index 0000000..a39885f --- /dev/null +++ b/python3.10libs/qtpy/QtWebEngineWidgets.py @@ -0,0 +1,70 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2014-2015 Colin Duquesnoy +# Copyright © 2009- The Spyder development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtWebEngineWidgets classes and functions.""" + +from . import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtModuleNotInstalledError, +) + +# To test if we are using WebEngine or WebKit +# NOTE: This constant is imported by other projects (e.g. Spyder), so please +# don't remove it. 
+WEBENGINE = True + + +if PYQT5: + try: + # Based on the work at https://github.com/spyder-ide/qtpy/pull/203 + from PyQt5.QtWebEngineWidgets import ( + QWebEnginePage, + QWebEngineProfile, + QWebEngineScript, + QWebEngineSettings, + QWebEngineView, + ) + except ModuleNotFoundError as error: + raise QtModuleNotInstalledError( + name="QtWebEngineWidgets", + missing_package="PyQtWebEngine", + ) from error +elif PYQT6: + try: + from PyQt6.QtWebEngineCore import ( + QWebEnginePage, + QWebEngineProfile, + QWebEngineScript, + QWebEngineSettings, + ) + from PyQt6.QtWebEngineWidgets import * + except ModuleNotFoundError as error: + raise QtModuleNotInstalledError( + name="QtWebEngineWidgets", + missing_package="PyQt6-WebEngine", + ) from error +elif PYSIDE2: + # Based on the work at https://github.com/spyder-ide/qtpy/pull/203 + from PySide2.QtWebEngineWidgets import ( + QWebEnginePage, + QWebEngineProfile, + QWebEngineScript, + QWebEngineSettings, + QWebEngineView, + ) +elif PYSIDE6: + from PySide6.QtWebEngineCore import ( + QWebEnginePage, + QWebEngineProfile, + QWebEngineScript, + QWebEngineSettings, + ) + from PySide6.QtWebEngineWidgets import * diff --git a/python3.10libs/qtpy/QtWebSockets.py b/python3.10libs/qtpy/QtWebSockets.py new file mode 100644 index 0000000..a9bd33d --- /dev/null +++ b/python3.10libs/qtpy/QtWebSockets.py @@ -0,0 +1,19 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtWebSockets classes and functions.""" + +from . 
import PYQT5, PYQT6, PYSIDE2, PYSIDE6 + +if PYQT5: + from PyQt5.QtWebSockets import * +elif PYQT6: + from PyQt6.QtWebSockets import * +elif PYSIDE2: + from PySide2.QtWebSockets import * +elif PYSIDE6: + from PySide6.QtWebSockets import * diff --git a/python3.10libs/qtpy/QtWidgets.py b/python3.10libs/qtpy/QtWidgets.py new file mode 100644 index 0000000..5e8e088 --- /dev/null +++ b/python3.10libs/qtpy/QtWidgets.py @@ -0,0 +1,217 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2014-2015 Colin Duquesnoy +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides widget classes and functions.""" +from functools import partialmethod + +from packaging.version import parse + +from . import PYQT5, PYQT6, PYSIDE2, PYSIDE6 +from . import QT_VERSION as _qt_version +from ._utils import ( + add_action, + getattr_missing_optional_dep, + possibly_static_exec, + static_method_kwargs_wrapper, +) + +_missing_optional_names = {} + + +def __getattr__(name): + """Custom getattr to chain and wrap errors due to missing optional deps.""" + raise getattr_missing_optional_dep( + name, + module_name=__name__, + optional_names=_missing_optional_names, + ) + + +if PYQT5: + from PyQt5.QtWidgets import * +elif PYQT6: + from PyQt6 import QtWidgets + from PyQt6.QtGui import ( + QAction, + QActionGroup, + QFileSystemModel, + QShortcut, + QUndoCommand, + ) + from PyQt6.QtWidgets import * + + # Attempt to import QOpenGLWidget, but if that fails, + # don't raise an exception until the name is explicitly accessed. 
+ # See https://github.com/spyder-ide/qtpy/pull/387/ + try: + from PyQt6.QtOpenGLWidgets import QOpenGLWidget + except ImportError as error: + _missing_optional_names["QOpenGLWidget"] = { + "name": "PyQt6.QtOpenGLWidgets", + "missing_package": "pyopengl", + "import_error": error, + } + + # Map missing/renamed methods + QTextEdit.setTabStopWidth = ( + lambda self, *args, **kwargs: self.setTabStopDistance(*args, **kwargs) + ) + QTextEdit.tabStopWidth = ( + lambda self, *args, **kwargs: self.tabStopDistance(*args, **kwargs) + ) + QTextEdit.print_ = lambda self, *args, **kwargs: self.print( + *args, + **kwargs, + ) + QPlainTextEdit.setTabStopWidth = ( + lambda self, *args, **kwargs: self.setTabStopDistance(*args, **kwargs) + ) + QPlainTextEdit.tabStopWidth = ( + lambda self, *args, **kwargs: self.tabStopDistance(*args, **kwargs) + ) + QPlainTextEdit.print_ = lambda self, *args, **kwargs: self.print( + *args, + **kwargs, + ) + QApplication.exec_ = lambda *args, **kwargs: possibly_static_exec( + QApplication, + *args, + **kwargs, + ) + QDialog.exec_ = lambda self, *args, **kwargs: self.exec(*args, **kwargs) + QMenu.exec_ = lambda *args, **kwargs: possibly_static_exec( + QMenu, + *args, + **kwargs, + ) + QLineEdit.getTextMargins = lambda self: ( + self.textMargins().left(), + self.textMargins().top(), + self.textMargins().right(), + self.textMargins().bottom(), + ) + + # Add removed definition for `QFileDialog.Options` as an alias of `QFileDialog.Option` + # passing as default value 0 in the same way PySide6 6.5+ does. 
+ # Note that for PyQt5 and PySide2 those definitions are two different classes + # (one is the flag definition and the other the enum definition) + QFileDialog.Options = lambda value=0: QFileDialog.Option(value) + + # Allow unscoped access for enums inside the QtWidgets module + from .enums_compat import promote_enums + + promote_enums(QtWidgets) + del QtWidgets +elif PYSIDE2: + from PySide2.QtWidgets import * +elif PYSIDE6: + from PySide6.QtGui import QAction, QActionGroup, QShortcut, QUndoCommand + from PySide6.QtWidgets import * + + # Attempt to import QOpenGLWidget, but if that fails, + # don't raise an exception until the name is explicitly accessed. + # See https://github.com/spyder-ide/qtpy/pull/387/ + try: + from PySide6.QtOpenGLWidgets import QOpenGLWidget + except ImportError as error: + _missing_optional_names["QOpenGLWidget"] = { + "name": "PySide6.QtOpenGLWidgets", + "missing_package": "pyopengl", + "import_error": error, + } + + # Map missing/renamed methods + QTextEdit.setTabStopWidth = ( + lambda self, *args, **kwargs: self.setTabStopDistance(*args, **kwargs) + ) + QTextEdit.tabStopWidth = ( + lambda self, *args, **kwargs: self.tabStopDistance(*args, **kwargs) + ) + QPlainTextEdit.setTabStopWidth = ( + lambda self, *args, **kwargs: self.setTabStopDistance(*args, **kwargs) + ) + QPlainTextEdit.tabStopWidth = ( + lambda self, *args, **kwargs: self.tabStopDistance(*args, **kwargs) + ) + QLineEdit.getTextMargins = lambda self: ( + self.textMargins().left(), + self.textMargins().top(), + self.textMargins().right(), + self.textMargins().bottom(), + ) + + # Map DeprecationWarning methods + QApplication.exec_ = lambda *args, **kwargs: possibly_static_exec( + QApplication, + *args, + **kwargs, + ) + QDialog.exec_ = lambda self, *args, **kwargs: self.exec(*args, **kwargs) + QMenu.exec_ = lambda *args, **kwargs: possibly_static_exec( + QMenu, + *args, + **kwargs, + ) + + # Passing as default value 0 in the same way PySide6 < 6.3.2 does for the 
`QFileDialog.Options` definition. + if parse(_qt_version) > parse("6.3"): + QFileDialog.Options = lambda value=0: QFileDialog.Option(value) + + +if PYSIDE2 or PYSIDE6: + # Make PySide2/6 `QFileDialog` static methods accept the `directory` kwarg as `dir` + QFileDialog.getExistingDirectory = static_method_kwargs_wrapper( + QFileDialog.getExistingDirectory, + "directory", + "dir", + ) + QFileDialog.getOpenFileName = static_method_kwargs_wrapper( + QFileDialog.getOpenFileName, + "directory", + "dir", + ) + QFileDialog.getOpenFileNames = static_method_kwargs_wrapper( + QFileDialog.getOpenFileNames, + "directory", + "dir", + ) + QFileDialog.getSaveFileName = static_method_kwargs_wrapper( + QFileDialog.getSaveFileName, + "directory", + "dir", + ) +else: + # Make PyQt5/6 `QFileDialog` static methods accept the `dir` kwarg as `directory` + QFileDialog.getExistingDirectory = static_method_kwargs_wrapper( + QFileDialog.getExistingDirectory, + "dir", + "directory", + ) + QFileDialog.getOpenFileName = static_method_kwargs_wrapper( + QFileDialog.getOpenFileName, + "dir", + "directory", + ) + QFileDialog.getOpenFileNames = static_method_kwargs_wrapper( + QFileDialog.getOpenFileNames, + "dir", + "directory", + ) + QFileDialog.getSaveFileName = static_method_kwargs_wrapper( + QFileDialog.getSaveFileName, + "dir", + "directory", + ) + +# Make `addAction` compatible with Qt6 >= 6.3 +if PYQT5 or PYSIDE2 or parse(_qt_version) < parse("6.3"): + QMenu.addAction = partialmethod(add_action, old_add_action=QMenu.addAction) + QToolBar.addAction = partialmethod( + add_action, + old_add_action=QToolBar.addAction, + ) diff --git a/python3.10libs/qtpy/QtWinExtras.py b/python3.10libs/qtpy/QtWinExtras.py new file mode 100644 index 0000000..bf2fb78 --- /dev/null +++ b/python3.10libs/qtpy/QtWinExtras.py @@ -0,0 +1,31 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# 
(see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides Windows-specific utilities""" + +import sys + +from . import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtModuleNotInOSError, + QtModuleNotInQtVersionError, +) + +if sys.platform == "win32": + if PYQT5: + from PyQt5.QtWinExtras import * + elif PYQT6: + raise QtModuleNotInQtVersionError(name="QtWinExtras") + elif PYSIDE2: + from PySide2.QtWinExtras import * + elif PYSIDE6: + raise QtModuleNotInQtVersionError(name="QtWinExtras") +else: + raise QtModuleNotInOSError(name="QtWinExtras") diff --git a/python3.10libs/qtpy/QtX11Extras.py b/python3.10libs/qtpy/QtX11Extras.py new file mode 100644 index 0000000..016727f --- /dev/null +++ b/python3.10libs/qtpy/QtX11Extras.py @@ -0,0 +1,31 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides Linux-specific utilities""" + +import sys + +from . 
import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtModuleNotInOSError, + QtModuleNotInQtVersionError, +) + +if sys.platform == "linux": + if PYQT5: + from PyQt5.QtX11Extras import * + elif PYQT6: + raise QtModuleNotInQtVersionError(name="QtX11Extras") + elif PYSIDE2: + from PySide2.QtX11Extras import * + elif PYSIDE6: + raise QtModuleNotInQtVersionError(name="QtX11Extras") +else: + raise QtModuleNotInOSError(name="QtX11Extras") diff --git a/python3.10libs/qtpy/QtXml.py b/python3.10libs/qtpy/QtXml.py new file mode 100644 index 0000000..5f1e3b8 --- /dev/null +++ b/python3.10libs/qtpy/QtXml.py @@ -0,0 +1,19 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtXml classes and functions.""" + +from . import PYQT5, PYQT6, PYSIDE2, PYSIDE6 + +if PYQT5: + from PyQt5.QtXml import * +elif PYQT6: + from PyQt6.QtXml import * +elif PYSIDE2: + from PySide2.QtXml import * +elif PYSIDE6: + from PySide6.QtXml import * diff --git a/python3.10libs/qtpy/QtXmlPatterns.py b/python3.10libs/qtpy/QtXmlPatterns.py new file mode 100644 index 0000000..a7e0b73 --- /dev/null +++ b/python3.10libs/qtpy/QtXmlPatterns.py @@ -0,0 +1,25 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides QtXmlPatterns classes and functions.""" + +from . 
import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtBindingMissingModuleError, +) + +if PYQT5: + from PyQt5.QtXmlPatterns import * +elif PYQT6: + raise QtBindingMissingModuleError(name="QtXmlPatterns") +elif PYSIDE2: + from PySide2.QtXmlPatterns import * +elif PYSIDE6: + raise QtBindingMissingModuleError(name="QtXmlPatterns") diff --git a/python3.10libs/qtpy/__init__.py b/python3.10libs/qtpy/__init__.py new file mode 100644 index 0000000..387cc67 --- /dev/null +++ b/python3.10libs/qtpy/__init__.py @@ -0,0 +1,338 @@ +# +# Copyright © 2009- The Spyder Development Team +# Copyright © 2014-2015 Colin Duquesnoy +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) + +""" +**QtPy** is a shim over the various Python Qt bindings. It is used to write +Qt binding independent libraries or applications. + +If one of the APIs has already been imported, then it will be used. + +Otherwise, the shim will automatically select the first available API (PyQt5, PySide2, +PyQt6 and PySide6); in that case, you can force the use of one +specific bindings (e.g. if your application is using one specific bindings and +you need to use library that use QtPy) by setting up the ``QT_API`` environment +variable. 
+ +PyQt5 +===== + +For PyQt5, you don't have to set anything as it will be used automatically:: + + >>> from qtpy import QtGui, QtWidgets, QtCore + >>> print(QtWidgets.QWidget) + +PySide2 +======= + +Set the QT_API environment variable to 'pyside2' before importing other +packages:: + + >>> import os + >>> os.environ['QT_API'] = 'pyside2' + >>> from qtpy import QtGui, QtWidgets, QtCore + >>> print(QtWidgets.QWidget) + +PyQt6 +===== + + >>> import os + >>> os.environ['QT_API'] = 'pyqt6' + >>> from qtpy import QtGui, QtWidgets, QtCore + >>> print(QtWidgets.QWidget) + +PySide6 +======= + + >>> import os + >>> os.environ['QT_API'] = 'pyside6' + >>> from qtpy import QtGui, QtWidgets, QtCore + >>> print(QtWidgets.QWidget) + +""" + +import contextlib +import os +import platform +import sys +import warnings + +from packaging.version import parse + +# Version of QtPy +__version__ = "2.4.1" + + +class PythonQtError(RuntimeError): + """Generic error superclass for QtPy.""" + + +class PythonQtWarning(RuntimeWarning): + """Warning class for QtPy.""" + + +class PythonQtValueError(ValueError): + """Error raised if an invalid QT_API is specified.""" + + +class QtBindingsNotFoundError(PythonQtError, ImportError): + """Error raised if no bindings could be selected.""" + + _msg = "No Qt bindings could be found" + + def __init__(self): + super().__init__(self._msg) + + +class QtModuleNotFoundError(ModuleNotFoundError, PythonQtError): + """Raised when a Python Qt binding submodule is not installed/supported.""" + + _msg = "The {name} module was not found."
+ _msg_binding = "{binding}" + _msg_extra = "" + + def __init__(self, *, name, msg=None, **msg_kwargs): + global API_NAME + binding = self._msg_binding.format(binding=API_NAME) + msg = msg or f"{self._msg} {self._msg_extra}".strip() + msg = msg.format(name=name, binding=binding, **msg_kwargs) + super().__init__(msg, name=name) + + +class QtModuleNotInOSError(QtModuleNotFoundError): + """Raised when a module is not supported on the current operating system.""" + + _msg = "{name} does not exist on this operating system." + + +class QtModuleNotInQtVersionError(QtModuleNotFoundError): + """Raised when a module is not implemented in the current Qt version.""" + + _msg = "{name} does not exist in {version}." + + def __init__(self, *, name, msg=None, **msg_kwargs): + global QT5, QT6 + version = "Qt5" if QT5 else "Qt6" + super().__init__(name=name, version=version) + + +class QtBindingMissingModuleError(QtModuleNotFoundError): + """Raised when a module is not supported by a given binding.""" + + _msg_extra = "It is not currently implemented in {binding}." + + +class QtModuleNotInstalledError(QtModuleNotFoundError): + """Raised when a module is supported by the binding, but not installed.""" + + _msg_extra = "It must be installed separately" + + def __init__(self, *, missing_package=None, **superclass_kwargs): + self.missing_package = missing_package + if missing_package is not None: + self._msg_extra += " as {missing_package}."
+ super().__init__(missing_package=missing_package, **superclass_kwargs) + + +# Qt API environment variable name +QT_API = "QT_API" + +# Names of the expected PyQt5 api +PYQT5_API = ["pyqt5"] + +PYQT6_API = ["pyqt6"] + +# Names of the expected PySide2 api +PYSIDE2_API = ["pyside2"] + +# Names of the expected PySide6 api +PYSIDE6_API = ["pyside6"] + +# Minimum supported versions of Qt and the bindings +QT5_VERSION_MIN = PYQT5_VERSION_MIN = "5.9.0" +PYSIDE2_VERSION_MIN = "5.12.0" +QT6_VERSION_MIN = PYQT6_VERSION_MIN = PYSIDE6_VERSION_MIN = "6.2.0" + +QT_VERSION_MIN = QT5_VERSION_MIN +PYQT_VERSION_MIN = PYQT5_VERSION_MIN +PYSIDE_VERSION_MIN = PYSIDE2_VERSION_MIN + +# Detecting if a binding was specified by the user +binding_specified = QT_API in os.environ + +API_NAMES = { + "pyqt5": "PyQt5", + "pyside2": "PySide2", + "pyqt6": "PyQt6", + "pyside6": "PySide6", +} +API = os.environ.get(QT_API, "pyqt5").lower() +initial_api = API +if API not in API_NAMES: + raise PythonQtValueError( + f"Specified QT_API={QT_API.lower()!r} is not in valid options: " + f"{API_NAMES}", + ) + +is_old_pyqt = is_pyqt46 = False +QT5 = PYQT5 = True +QT4 = QT6 = PYQT4 = PYQT6 = PYSIDE = PYSIDE2 = PYSIDE6 = False + +PYQT_VERSION = None +PYSIDE_VERSION = None +QT_VERSION = None + +# Unless `FORCE_QT_API` is set, use previously imported Qt Python bindings +if not os.environ.get("FORCE_QT_API"): + if "PyQt5" in sys.modules: + API = initial_api if initial_api in PYQT5_API else "pyqt5" + elif "PySide2" in sys.modules: + API = initial_api if initial_api in PYSIDE2_API else "pyside2" + elif "PyQt6" in sys.modules: + API = initial_api if initial_api in PYQT6_API else "pyqt6" + elif "PySide6" in sys.modules: + API = initial_api if initial_api in PYSIDE6_API else "pyside6" + +if API in PYQT5_API: + try: + from PyQt5.QtCore import ( + PYQT_VERSION_STR as PYQT_VERSION, + ) + from PyQt5.QtCore import ( + QT_VERSION_STR as QT_VERSION, + ) + + QT5 = PYQT5 = True + + if sys.platform == "darwin": + macos_version = 
parse(platform.mac_ver()[0]) + qt_ver = parse(QT_VERSION) + if macos_version < parse("10.10") and qt_ver >= parse("5.9"): + raise PythonQtError( + "Qt 5.9 or higher only works in " + "macOS 10.10 or higher. Your " + "program will fail in this " + "system.", + ) + elif macos_version < parse("10.11") and qt_ver >= parse("5.11"): + raise PythonQtError( + "Qt 5.11 or higher only works in " + "macOS 10.11 or higher. Your " + "program will fail in this " + "system.", + ) + + del macos_version + del qt_ver + except ImportError: + API = "pyside2" + else: + os.environ[QT_API] = API + +if API in PYSIDE2_API: + try: + from PySide2 import __version__ as PYSIDE_VERSION # analysis:ignore + from PySide2.QtCore import __version__ as QT_VERSION # analysis:ignore + + PYQT5 = False + QT5 = PYSIDE2 = True + + if sys.platform == "darwin": + macos_version = parse(platform.mac_ver()[0]) + qt_ver = parse(QT_VERSION) + if macos_version < parse("10.11") and qt_ver >= parse("5.11"): + raise PythonQtError( + "Qt 5.11 or higher only works in " + "macOS 10.11 or higher. 
Your " + "program will fail in this " + "system.", + ) + + del macos_version + del qt_ver + except ImportError: + API = "pyqt6" + else: + os.environ[QT_API] = API + +if API in PYQT6_API: + try: + from PyQt6.QtCore import ( + PYQT_VERSION_STR as PYQT_VERSION, + ) + from PyQt6.QtCore import ( + QT_VERSION_STR as QT_VERSION, + ) + + QT5 = PYQT5 = False + QT6 = PYQT6 = True + + except ImportError: + API = "pyside6" + else: + os.environ[QT_API] = API + +if API in PYSIDE6_API: + try: + from PySide6 import __version__ as PYSIDE_VERSION # analysis:ignore + from PySide6.QtCore import __version__ as QT_VERSION # analysis:ignore + + QT5 = PYQT5 = False + QT6 = PYSIDE6 = True + + except ImportError: + raise QtBindingsNotFoundError from None + else: + os.environ[QT_API] = API + + +# If a correct API name is passed to QT_API and it could not be found, +# switches to another and informs through the warning +if initial_api != API and binding_specified: + warnings.warn( + f"Selected binding {initial_api!r} could not be found; " + f"falling back to {API!r}", + PythonQtWarning, + stacklevel=2, + ) + + +# Set display name of the Qt API +API_NAME = API_NAMES[API] + +with contextlib.suppress(ImportError, PythonQtError): + # QtDataVisualization backward compatibility (QtDataVisualization vs. QtDatavisualization) + # Only available for Qt5 bindings > 5.9 on Windows + from . import QtDataVisualization as QtDatavisualization # analysis:ignore + + +def _warn_old_minor_version(name, old_version, min_version): + """Warn if using a Qt or binding version no longer supported by QtPy.""" + warning_message = ( + f"{name} version {old_version} is not supported by QtPy. " + "To ensure your application works correctly with QtPy, " + f"please upgrade to {name} {min_version} or later." 
+ ) + warnings.warn(warning_message, PythonQtWarning, stacklevel=2) + + +# Warn if using an End of Life or unsupported Qt API/binding minor version +if QT_VERSION: + if QT5 and (parse(QT_VERSION) < parse(QT5_VERSION_MIN)): + _warn_old_minor_version("Qt5", QT_VERSION, QT5_VERSION_MIN) + elif QT6 and (parse(QT_VERSION) < parse(QT6_VERSION_MIN)): + _warn_old_minor_version("Qt6", QT_VERSION, QT6_VERSION_MIN) + +if PYQT_VERSION: + if PYQT5 and (parse(PYQT_VERSION) < parse(PYQT5_VERSION_MIN)): + _warn_old_minor_version("PyQt5", PYQT_VERSION, PYQT5_VERSION_MIN) + elif PYQT6 and (parse(PYQT_VERSION) < parse(PYQT6_VERSION_MIN)): + _warn_old_minor_version("PyQt6", PYQT_VERSION, PYQT6_VERSION_MIN) +elif PYSIDE_VERSION: + if PYSIDE2 and (parse(PYSIDE_VERSION) < parse(PYSIDE2_VERSION_MIN)): + _warn_old_minor_version("PySide2", PYSIDE_VERSION, PYSIDE2_VERSION_MIN) + elif PYSIDE6 and (parse(PYSIDE_VERSION) < parse(PYSIDE6_VERSION_MIN)): + _warn_old_minor_version("PySide6", PYSIDE_VERSION, PYSIDE6_VERSION_MIN) diff --git a/python3.10libs/qtpy/__main__.py b/python3.10libs/qtpy/__main__.py new file mode 100644 index 0000000..a8f993c --- /dev/null +++ b/python3.10libs/qtpy/__main__.py @@ -0,0 +1,18 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The QtPy Contributors +# +# Released under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Dev CLI entry point for QtPy, a compat layer for the Python Qt bindings.""" + +import qtpy.cli + + +def main(): + return qtpy.cli.main() + + +if __name__ == "__main__": + main() diff --git a/python3.10libs/qtpy/_utils.py b/python3.10libs/qtpy/_utils.py new file mode 100644 index 0000000..ec851fa --- /dev/null +++ b/python3.10libs/qtpy/_utils.py @@ -0,0 +1,161 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2023- The Spyder Development Team +# +# 
Released under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides utility functions for use by QtPy itself.""" +from functools import wraps +from typing import TYPE_CHECKING + +import qtpy + +if TYPE_CHECKING: + from qtpy.QtWidgets import QAction + + +def _wrap_missing_optional_dep_error( + attr_error, + *, + import_error, + wrapper=qtpy.QtModuleNotInstalledError, + **wrapper_kwargs, +): + """Create a __cause__-chained wrapper error for a missing optional dep.""" + qtpy_error = wrapper(**wrapper_kwargs) + import_error.__cause__ = attr_error + qtpy_error.__cause__ = import_error + return qtpy_error + + +def getattr_missing_optional_dep(name, module_name, optional_names): + """Wrap AttributeError in a special error if it matches.""" + attr_error = AttributeError( + f"module {module_name!r} has no attribute {name!r}", + ) + if name in optional_names: + return _wrap_missing_optional_dep_error( + attr_error, + **optional_names[name], + ) + return attr_error + + +def possibly_static_exec(cls, *args, **kwargs): + """Call `self.exec` when `self` is given or a static method otherwise.""" + if not args and not kwargs: + # A special case (`cls.exec_()`) to avoid the function resolving error + return cls.exec() + if isinstance(args[0], cls): + if len(args) == 1 and not kwargs: + # A special case (`self.exec_()`) to avoid the function resolving error + return args[0].exec() + return args[0].exec(*args[1:], **kwargs) + + return cls.exec(*args, **kwargs) + + +def possibly_static_exec_(cls, *args, **kwargs): + """Call `self.exec` when `self` is given or a static method otherwise.""" + if not args and not kwargs: + # A special case (`cls.exec()`) to avoid the function resolving error + return cls.exec_() + if isinstance(args[0], cls): + if len(args) == 1 and not kwargs: + # A special case (`self.exec()`) to avoid the function resolving error + return args[0].exec_() + return 
args[0].exec_(*args[1:], **kwargs) + + return cls.exec_(*args, **kwargs) + + +def add_action(self, *args, old_add_action): + """Re-order arguments of `addAction` to backport compatibility with Qt>=6.3.""" + from qtpy.QtCore import QObject + from qtpy.QtGui import QIcon, QKeySequence + + action: QAction + icon: QIcon + text: str + shortcut: QKeySequence | QKeySequence.StandardKey | str | int + receiver: QObject + member: bytes + if all( + isinstance(arg, t) + for arg, t in zip( + args, + [ + str, + (QKeySequence, QKeySequence.StandardKey, str, int), + QObject, + bytes, + ], + ) + ): + if len(args) == 2: + text, shortcut = args + action = old_add_action(self, text) + action.setShortcut(shortcut) + elif len(args) == 3: + text, shortcut, receiver = args + action = old_add_action(self, text, receiver) + action.setShortcut(shortcut) + elif len(args) == 4: + text, shortcut, receiver, member = args + action = old_add_action(self, text, receiver, member, shortcut) + else: + return old_add_action(self, *args) + return action + if all( + isinstance(arg, t) + for arg, t in zip( + args, + [ + QIcon, + str, + (QKeySequence, QKeySequence.StandardKey, str, int), + QObject, + bytes, + ], + ) + ): + if len(args) == 3: + icon, text, shortcut = args + action = old_add_action(self, icon, text) + action.setShortcut(QKeySequence(shortcut)) + elif len(args) == 4: + icon, text, shortcut, receiver = args + action = old_add_action(self, icon, text, receiver) + action.setShortcut(QKeySequence(shortcut)) + elif len(args) == 5: + icon, text, shortcut, receiver, member = args + action = old_add_action( + self, + icon, + text, + receiver, + member, + QKeySequence(shortcut), + ) + else: + return old_add_action(self, *args) + return action + return old_add_action(self, *args) + + +def static_method_kwargs_wrapper(func, from_kwarg_name, to_kwarg_name): + """ + Helper function to manage `from_kwarg_name` to `to_kwarg_name` kwargs name changes in static methods. 
+ + Makes static methods accept the `from_kwarg_name` kwarg as `to_kwarg_name`. + """ + + @staticmethod + @wraps(func) + def _from_kwarg_name_to_kwarg_name_(*args, **kwargs): + if from_kwarg_name in kwargs: + kwargs[to_kwarg_name] = kwargs.pop(from_kwarg_name) + return func(*args, **kwargs) + + return _from_kwarg_name_to_kwarg_name_ diff --git a/python3.10libs/qtpy/cli.py b/python3.10libs/qtpy/cli.py new file mode 100644 index 0000000..95da025 --- /dev/null +++ b/python3.10libs/qtpy/cli.py @@ -0,0 +1,166 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The QtPy Contributors +# +# Released under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provide a CLI to allow configuring developer settings, including mypy.""" + +# Standard library imports +import argparse +import json +import textwrap + + +def print_version(): + """Print the current version of the package.""" + import qtpy + + print("QtPy version", qtpy.__version__) + + +def get_api_status(): + """Get the status of each Qt API usage.""" + import qtpy + + return {name: name == qtpy.API for name in qtpy.API_NAMES} + + +def generate_mypy_args(): + """Generate a string with always-true/false args to pass to mypy.""" + options = {False: "--always-false", True: "--always-true"} + + apis_active = get_api_status() + return " ".join( + f"{options[is_active]}={name.upper()}" + for name, is_active in apis_active.items() + ) + + +def generate_pyright_config_json(): + """Generate Pyright config to be used in `pyrightconfig.json`.""" + apis_active = get_api_status() + + return json.dumps( + { + "defineConstant": { + name.upper(): is_active + for name, is_active in apis_active.items() + }, + }, + ) + + +def generate_pyright_config_toml(): + """Generate a Pyright config to be used in `pyproject.toml`.""" + apis_active = get_api_status() + + return 
"[tool.pyright.defineConstant]\n" + "\n".join( + f"{name.upper()} = {str(is_active).lower()}" + for name, is_active in apis_active.items() + ) + + +def print_mypy_args(): + """Print the generated mypy args to stdout.""" + print(generate_mypy_args()) + + +def print_pyright_config_json(): + """Print the generated Pyright JSON config to stdout.""" + print(generate_pyright_config_json()) + + +def print_pyright_config_toml(): + """Print the generated Pyright TOML config to stdout.""" + print(generate_pyright_config_toml()) + + +def print_pyright_configs(): + """Print the generated Pyright configs to stdout.""" + print("pyrightconfig.json:") + print_pyright_config_json() + print() + print("pyproject.toml:") + print_pyright_config_toml() + + +def generate_arg_parser(): + """Generate the argument parser for the dev CLI for QtPy.""" + parser = argparse.ArgumentParser( + description="Features to support development with QtPy.", + ) + parser.set_defaults(func=parser.print_help) + + parser.add_argument( + "--version", + action="store_const", + dest="func", + const=print_version, + help="If passed, will print the version and exit", + ) + + cli_subparsers = parser.add_subparsers( + title="Subcommands", + help="Subcommand to run", + metavar="Subcommand", + ) + + # Parser for the MyPy args subcommand + mypy_args_parser = cli_subparsers.add_parser( + name="mypy-args", + help="Generate command line arguments for using mypy with QtPy.", + formatter_class=argparse.RawTextHelpFormatter, + description=textwrap.dedent( + """ + Generate command line arguments for using mypy with QtPy. + + This will generate strings similar to the following + which help guide mypy through which library QtPy would have used + so that mypy can get the proper underlying type hints. 
+ + --always-false=PYQT5 --always-false=PYQT6 --always-true=PYSIDE2 --always-false=PYSIDE6 + + It can be used as follows on Bash or a similar shell: + + mypy --package mypackage $(qtpy mypy-args) + """, + ), + ) + mypy_args_parser.set_defaults(func=print_mypy_args) + + # Parser for the Pyright config subcommand + pyright_config_parser = cli_subparsers.add_parser( + name="pyright-config", + help="Generate Pyright config for using Pyright with QtPy.", + formatter_class=argparse.RawTextHelpFormatter, + description=textwrap.dedent( + """ + Generate Pyright config for using Pyright with QtPy. + + This will generate config sections to be included in a Pyright + config file (either `pyrightconfig.json` or `pyproject.toml`) + which help guide Pyright through which library QtPy would have used + so that Pyright can get the proper underlying type hints. + + """, + ), + ) + pyright_config_parser.set_defaults(func=print_pyright_configs) + + return parser + + +def main(args=None): + """Run the development CLI for QtPy.""" + parser = generate_arg_parser() + parsed_args = parser.parse_args(args=args) + + reserved_params = {"func"} + cleaned_args = { + key: value + for key, value in vars(parsed_args).items() + if key not in reserved_params + } + parsed_args.func(**cleaned_args) diff --git a/python3.10libs/qtpy/compat.py b/python3.10libs/qtpy/compat.py new file mode 100644 index 0000000..4c6d428 --- /dev/null +++ b/python3.10libs/qtpy/compat.py @@ -0,0 +1,202 @@ +# +# Copyright © 2009- The Spyder Development Team +# Licensed under the terms of the MIT License + +""" +Compatibility functions +""" +import sys + +from . 
import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, +) +from .QtWidgets import QFileDialog + +TEXT_TYPES = (str,) + + +def is_text_string(obj): + """Return True if `obj` is a text string, False if it is anything else, + like binary data.""" + return isinstance(obj, str) + + +def to_text_string(obj, encoding=None): + """Convert `obj` to (unicode) text string""" + if encoding is None: + return str(obj) + if isinstance(obj, str): + # In case this function is not used properly, this could happen + return obj + + return str(obj, encoding) + + +# ============================================================================= +# QVariant conversion utilities +# ============================================================================= +PYQT_API_1 = False + + +def to_qvariant(obj=None): # analysis:ignore + """Convert Python object to QVariant + This is a transitional function from PyQt API#1 (QVariant exist) + to PyQt API#2 and Pyside (QVariant does not exist)""" + return obj + + +def from_qvariant(qobj=None, pytype=None): # analysis:ignore + """Convert QVariant object to Python object + This is a transitional function from PyQt API #1 (QVariant exist) + to PyQt API #2 and Pyside (QVariant does not exist)""" + return qobj + + +# ============================================================================= +# Wrappers around QFileDialog static methods +# ============================================================================= +def getexistingdirectory( + parent=None, + caption="", + basedir="", + options=QFileDialog.ShowDirsOnly, +): + """Wrapper around QtGui.QFileDialog.getExistingDirectory static method + Compatible with PyQt >=v4.4 (API #1 and #2) and PySide >=v1.0""" + # Calling QFileDialog static method + if sys.platform == "win32": + # On Windows platforms: redirect standard outputs + _temp1, _temp2 = sys.stdout, sys.stderr + sys.stdout, sys.stderr = None, None + try: + result = QFileDialog.getExistingDirectory( + parent, + caption, + basedir, + options, + ) + 
finally: + if sys.platform == "win32": + # On Windows platforms: restore standard outputs + sys.stdout, sys.stderr = _temp1, _temp2 + if not is_text_string(result): + # PyQt API #1 + result = to_text_string(result) + return result + + +def _qfiledialog_wrapper( + attr, + parent=None, + caption="", + basedir="", + filters="", + selectedfilter="", + options=None, +): + if options is None: + options = QFileDialog.Option(0) + + func = getattr(QFileDialog, attr) + + # Calling QFileDialog static method + if sys.platform == "win32": + # On Windows platforms: redirect standard outputs + _temp1, _temp2 = sys.stdout, sys.stderr + sys.stdout, sys.stderr = None, None + result = func(parent, caption, basedir, filters, selectedfilter, options) + if sys.platform == "win32": + # On Windows platforms: restore standard outputs + sys.stdout, sys.stderr = _temp1, _temp2 + + output, selectedfilter = result + + # Always returns the tuple (output, selectedfilter) + return output, selectedfilter + + +def getopenfilename( + parent=None, + caption="", + basedir="", + filters="", + selectedfilter="", + options=None, +): + """Wrapper around QtGui.QFileDialog.getOpenFileName static method + Returns a tuple (filename, selectedfilter) -- when dialog box is canceled, + returns a tuple of empty strings + Compatible with PyQt >=v4.4 (API #1 and #2) and PySide >=v1.0""" + return _qfiledialog_wrapper( + "getOpenFileName", + parent=parent, + caption=caption, + basedir=basedir, + filters=filters, + selectedfilter=selectedfilter, + options=options, + ) + + +def getopenfilenames( + parent=None, + caption="", + basedir="", + filters="", + selectedfilter="", + options=None, +): + """Wrapper around QtGui.QFileDialog.getOpenFileNames static method + Returns a tuple (filenames, selectedfilter) -- when dialog box is canceled, + returns a tuple (empty list, empty string) + Compatible with PyQt >=v4.4 (API #1 and #2) and PySide >=v1.0""" + return _qfiledialog_wrapper( + "getOpenFileNames", + parent=parent, + 
caption=caption, + basedir=basedir, + filters=filters, + selectedfilter=selectedfilter, + options=options, + ) + + +def getsavefilename( + parent=None, + caption="", + basedir="", + filters="", + selectedfilter="", + options=None, +): + """Wrapper around QtGui.QFileDialog.getSaveFileName static method + Returns a tuple (filename, selectedfilter) -- when dialog box is canceled, + returns a tuple of empty strings + Compatible with PyQt >=v4.4 (API #1 and #2) and PySide >=v1.0""" + return _qfiledialog_wrapper( + "getSaveFileName", + parent=parent, + caption=caption, + basedir=basedir, + filters=filters, + selectedfilter=selectedfilter, + options=options, + ) + + +# ============================================================================= +def isalive(obj): + """Wrapper around sip.isdeleted and shiboken.isValid which tests whether + an object is currently alive.""" + if PYQT5 or PYQT6: + from . import sip + + return not sip.isdeleted(obj) + if PYSIDE2 or PYSIDE6: + from . import shiboken + + return shiboken.isValid(obj) + return None diff --git a/python3.10libs/qtpy/enums_compat.py b/python3.10libs/qtpy/enums_compat.py new file mode 100644 index 0000000..89a7d11 --- /dev/null +++ b/python3.10libs/qtpy/enums_compat.py @@ -0,0 +1,39 @@ +# Copyright © 2009- The Spyder Development Team +# Copyright © 2012- University of North Carolina at Chapel Hill +# Luke Campagnola ('luke.campagnola@%s.com' % 'gmail') +# Ogi Moore ('ognyan.moore@%s.com' % 'gmail') +# KIU Shueng Chuan ('nixchuan@%s.com' % 'gmail') +# Licensed under the terms of the MIT License + +""" +Compatibility functions for scoped and unscoped enum access. +""" + +from . import PYQT6 + +if PYQT6: + import enum + + from . import sip + + def promote_enums(module): + """ + Search enums in the given module and allow unscoped access. + + Taken from: + https://github.com/pyqtgraph/pyqtgraph/blob/pyqtgraph-0.12.1/pyqtgraph/Qt.py#L331-L377 + and adapted to also copy enum values aliased under different names. 
+ + """ + class_names = [name for name in dir(module) if name.startswith("Q")] + for class_name in class_names: + klass = getattr(module, class_name) + if not isinstance(klass, sip.wrappertype): + continue + attrib_names = [name for name in dir(klass) if name[0].isupper()] + for attrib_name in attrib_names: + attrib = getattr(klass, attrib_name) + if not isinstance(attrib, enum.EnumMeta): + continue + for name, value in attrib.__members__.items(): + setattr(klass, name, value) diff --git a/python3.10libs/qtpy/py.typed b/python3.10libs/qtpy/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/python3.10libs/qtpy/shiboken.py b/python3.10libs/qtpy/shiboken.py new file mode 100644 index 0000000..3e20a0c --- /dev/null +++ b/python3.10libs/qtpy/shiboken.py @@ -0,0 +1,23 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides access to shiboken.""" + +from . import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtBindingMissingModuleError, +) + +if PYQT5 or PYQT6: + raise QtBindingMissingModuleError(name="shiboken") +elif PYSIDE2: + from shiboken2 import * +elif PYSIDE6: + from shiboken6 import * diff --git a/python3.10libs/qtpy/sip.py b/python3.10libs/qtpy/sip.py new file mode 100644 index 0000000..205538c --- /dev/null +++ b/python3.10libs/qtpy/sip.py @@ -0,0 +1,23 @@ +# ----------------------------------------------------------------------------- +# Copyright © 2009- The Spyder Development Team +# +# Licensed under the terms of the MIT License +# (see LICENSE.txt for details) +# ----------------------------------------------------------------------------- + +"""Provides access to sip.""" + +from . 
import ( + PYQT5, + PYQT6, + PYSIDE2, + PYSIDE6, + QtBindingMissingModuleError, +) + +if PYQT5: + from PyQt5.sip import * +elif PYQT6: + from PyQt6.sip import * +elif PYSIDE2 or PYSIDE6: + raise QtBindingMissingModuleError(name="sip") diff --git a/python3.10libs/qtpy/tests/__init__.py b/python3.10libs/qtpy/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python3.10libs/qtpy/tests/conftest.py b/python3.10libs/qtpy/tests/conftest.py new file mode 100644 index 0000000..f5c5f9a --- /dev/null +++ b/python3.10libs/qtpy/tests/conftest.py @@ -0,0 +1,86 @@ +import os + +import pytest + + +def pytest_configure(config): + """Configure the test environment.""" + + if "USE_QT_API" in os.environ: + os.environ["QT_API"] = os.environ["USE_QT_API"].lower() + + # We need to import qtpy here to make sure that the API versions get set + # straight away. + import qtpy + + +def pytest_report_header(config): + """Insert a customized header into the test report.""" + + versions = os.linesep + versions += "PyQt5: " + + try: + from PyQt5 import Qt + + versions += f"PyQt: {Qt.PYQT_VERSION_STR} - Qt: {Qt.QT_VERSION_STR}" + except ImportError: + versions += "not installed" + except AttributeError: + versions += "unknown version" + + versions += os.linesep + versions += "PySide2: " + + try: + import PySide2 + from PySide2 import QtCore + + versions += f"PySide: {PySide2.__version__} - Qt: {QtCore.__version__}" + except ImportError: + versions += "not installed" + except AttributeError: + versions += "unknown version" + + versions += os.linesep + versions += "PyQt6: " + + try: + from PyQt6 import QtCore + + versions += ( + f"PyQt: {QtCore.PYQT_VERSION_STR} - Qt: {QtCore.QT_VERSION_STR}" + ) + except ImportError: + versions += "not installed" + except AttributeError: + versions += "unknown version" + + versions += os.linesep + versions += "PySide6: " + + try: + import PySide6 + from PySide6 import QtCore + + versions += f"PySide: {PySide6.__version__} - Qt: 
{QtCore.__version__}" + except ImportError: + versions += "not installed" + except AttributeError: + versions += "unknown version" + + versions += os.linesep + + return versions + + +@pytest.fixture +def pdf_writer(qtbot): + from pathlib import Path + + from qtpy import QtGui + + output_path = Path("test.pdf") + device = QtGui.QPdfWriter(str(output_path)) + yield device, output_path + output_path.unlink() diff --git a/python3.10libs/qtpy/tests/optional_deps/__init__.py b/python3.10libs/qtpy/tests/optional_deps/__init__.py new file mode 100644 index 0000000..2c95744 --- /dev/null +++ b/python3.10libs/qtpy/tests/optional_deps/__init__.py @@ -0,0 +1,30 @@ +"""Package used for testing the deferred import error mechanism.""" + + +# See https://github.com/spyder-ide/qtpy/pull/387/ + + +from qtpy._utils import getattr_missing_optional_dep + +from .optional_dep import ExampleClass + +_missing_optional_names = {} + + +try: + from .optional_dep import MissingClass +except ImportError as error: + _missing_optional_names["MissingClass"] = { + "name": "optional_dep.MissingClass", + "missing_package": "test_package_please_ignore", + "import_error": error, + } + + +def __getattr__(name): + """Custom getattr to chain and wrap errors due to missing optional deps.""" + raise getattr_missing_optional_dep( + name, + module_name=__name__, + optional_names=_missing_optional_names, + ) diff --git a/python3.10libs/qtpy/tests/optional_deps/optional_dep.py b/python3.10libs/qtpy/tests/optional_deps/optional_dep.py new file mode 100644 index 0000000..818152b --- /dev/null +++ b/python3.10libs/qtpy/tests/optional_deps/optional_dep.py @@ -0,0 +1,5 @@ +"""Test module with an optional dependency that may be missing.""" + + +class ExampleClass: + pass diff --git a/python3.10libs/qtpy/tests/test.ui b/python3.10libs/qtpy/tests/test.ui new file mode 100644 index 0000000..8f0a67c --- /dev/null +++ b/python3.10libs/qtpy/tests/test.ui @@ -0,0 +1,35 @@ + + + Form + + + + 0 + 0 + 400 + 300 + + + + Form + 
+ + + + + + + Ceci n'est pas un bouton + + + + + + + + + + + + + diff --git a/python3.10libs/qtpy/tests/test_cli.py b/python3.10libs/qtpy/tests/test_cli.py new file mode 100644 index 0000000..4b44950 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_cli.py @@ -0,0 +1,156 @@ +"""Test the QtPy CLI.""" + +import subprocess +import sys +import textwrap + +import pytest + +import qtpy + +SUBCOMMANDS = [ + [], + ["mypy-args"], +] + + +@pytest.mark.parametrize( + argnames=["subcommand"], + argvalues=[[subcommand] for subcommand in SUBCOMMANDS], + ids=[" ".join(subcommand) for subcommand in SUBCOMMANDS], +) +def test_cli_help_does_not_fail(subcommand): + subprocess.run( + [sys.executable, "-m", "qtpy", *subcommand, "--help"], + check=True, + ) + + +def test_cli_version(): + output = subprocess.run( + [sys.executable, "-m", "qtpy", "--version"], + capture_output=True, + check=True, + encoding="utf-8", + ) + assert output.stdout.strip().split()[-1] == qtpy.__version__ + + +def test_cli_mypy_args(): + output = subprocess.run( + [sys.executable, "-m", "qtpy", "mypy-args"], + capture_output=True, + check=True, + encoding="utf-8", + ) + + if qtpy.PYQT5: + expected = " ".join( + [ + "--always-true=PYQT5", + "--always-false=PYSIDE2", + "--always-false=PYQT6", + "--always-false=PYSIDE6", + ], + ) + elif qtpy.PYSIDE2: + expected = " ".join( + [ + "--always-false=PYQT5", + "--always-true=PYSIDE2", + "--always-false=PYQT6", + "--always-false=PYSIDE6", + ], + ) + elif qtpy.PYQT6: + expected = " ".join( + [ + "--always-false=PYQT5", + "--always-false=PYSIDE2", + "--always-true=PYQT6", + "--always-false=PYSIDE6", + ], + ) + elif qtpy.PYSIDE6: + expected = " ".join( + [ + "--always-false=PYQT5", + "--always-false=PYSIDE2", + "--always-false=PYQT6", + "--always-true=PYSIDE6", + ], + ) + else: + pytest.fail("No Qt bindings detected") + + assert output.stdout.strip() == expected.strip() + + +def test_cli_pyright_config(): + output = subprocess.run( + [sys.executable, "-m", "qtpy", 
"pyright-config"], + capture_output=True, + check=True, + encoding="utf-8", + ) + + if qtpy.PYQT5: + expected = textwrap.dedent( + """ + pyrightconfig.json: + {"defineConstant": {"PYQT5": true, "PYSIDE2": false, "PYQT6": false, "PYSIDE6": false}} + + pyproject.toml: + [tool.pyright.defineConstant] + PYQT5 = true + PYSIDE2 = false + PYQT6 = false + PYSIDE6 = false + """, + ) + elif qtpy.PYSIDE2: + expected = textwrap.dedent( + """ + pyrightconfig.json: + {"defineConstant": {"PYQT5": false, "PYSIDE2": true, "PYQT6": false, "PYSIDE6": false}} + + pyproject.toml: + [tool.pyright.defineConstant] + PYQT5 = false + PYSIDE2 = true + PYQT6 = false + PYSIDE6 = false + """, + ) + elif qtpy.PYQT6: + expected = textwrap.dedent( + """ + pyrightconfig.json: + {"defineConstant": {"PYQT5": false, "PYSIDE2": false, "PYQT6": true, "PYSIDE6": false}} + + pyproject.toml: + [tool.pyright.defineConstant] + PYQT5 = false + PYSIDE2 = false + PYQT6 = true + PYSIDE6 = false + """, + ) + elif qtpy.PYSIDE6: + expected = textwrap.dedent( + """ + pyrightconfig.json: + {"defineConstant": {"PYQT5": false, "PYSIDE2": false, "PYQT6": false, "PYSIDE6": true}} + + pyproject.toml: + [tool.pyright.defineConstant] + PYQT5 = false + PYSIDE2 = false + PYQT6 = false + PYSIDE6 = true + """, + ) + else: + pytest.fail("No valid API to test") + + assert output.stdout.strip() == expected.strip() diff --git a/python3.10libs/qtpy/tests/test_compat.py b/python3.10libs/qtpy/tests/test_compat.py new file mode 100644 index 0000000..1e1fc28 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_compat.py @@ -0,0 +1,24 @@ +"""Test the compat module.""" +import sys + +import pytest + +from qtpy import QtWidgets, compat +from qtpy.tests.utils import not_using_conda + + +@pytest.mark.skipif( + ( + (sys.version_info.major == 3 and sys.version_info.minor == 7) + and sys.platform.startswith("win") + and not not_using_conda() + ), + reason="sip not included in Python3.7 on Windows", +) +def test_isalive(qtbot): + """Test 
compat.isalive""" + test_widget = QtWidgets.QWidget() + assert compat.isalive(test_widget) is True + with qtbot.waitSignal(test_widget.destroyed): + test_widget.deleteLater() + assert compat.isalive(test_widget) is False diff --git a/python3.10libs/qtpy/tests/test_custom.ui b/python3.10libs/qtpy/tests/test_custom.ui new file mode 100644 index 0000000..f74b5c7 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_custom.ui @@ -0,0 +1,42 @@ + + + Form + + + + 0 + 0 + 400 + 300 + + + + Form + + + + + + + + Ceci n'est pas un bouton + + + + + + + + + + + + + _QComboBoxSubclass + QComboBox +
qcombobox_subclass
+
+
+ + +
diff --git a/python3.10libs/qtpy/tests/test_macos_checks.py b/python3.10libs/qtpy/tests/test_macos_checks.py new file mode 100644 index 0000000..e9698b5 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_macos_checks.py @@ -0,0 +1,106 @@ +import contextlib +import platform +import sys +from unittest import mock + +import pytest + +from qtpy import PYQT5, PYSIDE2 + + +@pytest.mark.skipif(not PYQT5, reason="Targeted to PyQt5") +@mock.patch.object(platform, "mac_ver") +def test_qt59_exception(mac_ver, monkeypatch): + # Remove qtpy to reimport it again + with contextlib.suppress(KeyError): + del sys.modules["qtpy"] + + # Patch stdlib to emulate a macOS system + monkeypatch.setattr("sys.platform", "darwin") + mac_ver.return_value = ("10.9.2",) + + # Patch Qt version + monkeypatch.setattr("PyQt5.QtCore.QT_VERSION_STR", "5.9.1") + + # This should raise an Exception + with pytest.raises(Exception) as e: + import qtpy + + assert "10.10" in str(e.value) + assert "5.9" in str(e.value) + + +@pytest.mark.skipif(not PYQT5, reason="Targeted to PyQt5") +@mock.patch.object(platform, "mac_ver") +def test_qt59_no_exception(mac_ver, monkeypatch): + # Remove qtpy to reimport it again + with contextlib.suppress(KeyError): + del sys.modules["qtpy"] + + # Patch stdlib to emulate a macOS system + monkeypatch.setattr("sys.platform", "darwin") + mac_ver.return_value = ("10.10.1",) + + # Patch Qt version + monkeypatch.setattr("PyQt5.QtCore.QT_VERSION_STR", "5.9.5") + + # This should not raise an Exception + try: + import qtpy + except Exception: # noqa: BLE001 + pytest.fail("Error!") + + +@pytest.mark.skipif( + not (PYQT5 or PYSIDE2), + reason="Targeted to PyQt5 or PySide2", +) +@mock.patch.object(platform, "mac_ver") +def test_qt511_exception(mac_ver, monkeypatch): + # Remove qtpy to reimport it again + with contextlib.suppress(KeyError): + del sys.modules["qtpy"] + + # Patch stdlib to emulate a macOS system + monkeypatch.setattr("sys.platform", "darwin") + mac_ver.return_value = 
("10.10.3",) + + # Patch Qt version + if PYQT5: + monkeypatch.setattr("PyQt5.QtCore.QT_VERSION_STR", "5.11.1") + else: + monkeypatch.setattr("PySide2.QtCore.__version__", "5.11.1") + + # This should raise an Exception + with pytest.raises(Exception) as e: + import qtpy + + assert "10.11" in str(e.value) + assert "5.11" in str(e.value) + + +@pytest.mark.skipif( + not (PYQT5 or PYSIDE2), + reason="Targeted to PyQt5 or PySide2", +) +@mock.patch.object(platform, "mac_ver") +def test_qt511_no_exception(mac_ver, monkeypatch): + # Remove qtpy to reimport it again + with contextlib.suppress(KeyError): + del sys.modules["qtpy"] + + # Patch stdlib to emulate a macOS system + monkeypatch.setattr("sys.platform", "darwin") + mac_ver.return_value = ("10.13.2",) + + # Patch Qt version + if PYQT5: + monkeypatch.setattr("PyQt5.QtCore.QT_VERSION_STR", "5.11.1") + else: + monkeypatch.setattr("PySide2.QtCore.__version__", "5.11.1") + + # This should not raise an Exception + try: + import qtpy + except Exception: # noqa: BLE001 + pytest.fail("Error!") diff --git a/python3.10libs/qtpy/tests/test_main.py b/python3.10libs/qtpy/tests/test_main.py new file mode 100644 index 0000000..771c489 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_main.py @@ -0,0 +1,142 @@ +import contextlib +import os +import subprocess +import sys + +import pytest + +from qtpy import API_NAMES, QtCore, QtGui, QtWidgets + +with contextlib.suppress(Exception): + # removed in qt 6.0 + from qtpy import QtWebEngineWidgets + + +def assert_pyside2(): + """ + Make sure that we are using PySide + """ + import PySide2 + + assert QtCore.QEvent is PySide2.QtCore.QEvent + assert QtGui.QPainter is PySide2.QtGui.QPainter + assert QtWidgets.QWidget is PySide2.QtWidgets.QWidget + assert ( + QtWebEngineWidgets.QWebEnginePage + is PySide2.QtWebEngineWidgets.QWebEnginePage + ) + assert os.environ["QT_API"] == "pyside2" + + +def assert_pyside6(): + """ + Make sure that we are using PySide + """ + import PySide6 + + assert 
QtCore.QEvent is PySide6.QtCore.QEvent + assert QtGui.QPainter is PySide6.QtGui.QPainter + assert QtWidgets.QWidget is PySide6.QtWidgets.QWidget + # Only valid for qt>=6.2 + # assert QtWebEngineWidgets.QWebEnginePage is PySide6.QtWebEngineCore.QWebEnginePage + assert os.environ["QT_API"] == "pyside6" + + +def assert_pyqt5(): + """ + Make sure that we are using PyQt5 + """ + import PyQt5 + + assert QtCore.QEvent is PyQt5.QtCore.QEvent + assert QtGui.QPainter is PyQt5.QtGui.QPainter + assert QtWidgets.QWidget is PyQt5.QtWidgets.QWidget + assert os.environ["QT_API"] == "pyqt5" + + +def assert_pyqt6(): + """ + Make sure that we are using PyQt6 + """ + import PyQt6 + + assert QtCore.QEvent is PyQt6.QtCore.QEvent + assert QtGui.QPainter is PyQt6.QtGui.QPainter + assert QtWidgets.QWidget is PyQt6.QtWidgets.QWidget + assert os.environ["QT_API"] == "pyqt6" + + +def test_qt_api(): + """ + If QT_API is specified, we check that the correct Qt wrapper was used + """ + + QT_API = os.environ.get("QT_API", "").lower() + + if QT_API == "pyqt5": + assert_pyqt5() + elif QT_API == "pyside2": + assert_pyside2() + elif QT_API == "pyqt6": + assert_pyqt6() + elif QT_API == "pyside6": + assert_pyside6() + else: + # If the tests are run locally, USE_QT_API and QT_API may not be + # defined, but we still want to make sure qtpy is behaving sensibly. + # We should then be loading, in order of decreasing preference, PyQt5, + # PySide2, PyQt6 and PySide6. + try: + import PyQt5 + except ImportError: + try: + import PySide2 + except ImportError: + try: + import PyQt6 + except ImportError: + import PySide6 + + assert_pyside6() + else: + assert_pyqt6() + else: + assert_pyside2() + else: + assert_pyqt5() + + +@pytest.mark.parametrize("api", API_NAMES.values()) +def test_qt_api_environ(api): + """ + If no QT_API is specified but some Qt is imported, ensure QT_API is set properly. 
+ """ + mod = f"{api}.QtCore" + pytest.importorskip(mod, reason=f"Requires {api}") + # clean env + env = os.environ.copy() + for key in ("QT_API", "USE_QT_API"): + if key in env: + del env[key] + cmd = f""" +import {mod} +from qtpy import API +import os +print(API) +print(os.environ['QT_API']) +""" + output = subprocess.check_output([sys.executable, "-c", cmd], env=env) + got_api, env_qt_api = output.strip().decode("utf-8").splitlines() + assert got_api == api.lower() + assert env_qt_api == api.lower() + # Also ensure we raise a nice error + env["QT_API"] = "bad" + cmd = """ +try: + import qtpy +except ValueError as exc: + assert 'Specified QT_API' in str(exc), str(exc) +else: + raise AssertionError('QtPy imported despite bad QT_API') +""" + subprocess.check_call([sys.executable, "-Oc", cmd], env=env) diff --git a/python3.10libs/qtpy/tests/test_missing_optional_deps.py b/python3.10libs/qtpy/tests/test_missing_optional_deps.py new file mode 100644 index 0000000..e1ca82b --- /dev/null +++ b/python3.10libs/qtpy/tests/test_missing_optional_deps.py @@ -0,0 +1,22 @@ +"""Test the deferred import error mechanism""" + + +# See https://github.com/spyder-ide/qtpy/pull/387/ + + +import pytest + +from qtpy import QtModuleNotInstalledError + + +def test_missing_optional_deps(): + """Test importing a module that uses the deferred import error mechanism""" + from . import optional_deps + + assert optional_deps.ExampleClass is not None + + with pytest.raises(QtModuleNotInstalledError) as excinfo: + from .optional_deps import MissingClass + + msg = "The optional_dep.MissingClass module was not found. It must be installed separately as test_package_please_ignore." 
+ assert msg == str(excinfo.value) diff --git a/python3.10libs/qtpy/tests/test_qdesktopservice_split.py b/python3.10libs/qtpy/tests/test_qdesktopservice_split.py new file mode 100644 index 0000000..9819288 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qdesktopservice_split.py @@ -0,0 +1,22 @@ +"""Test QDesktopServices split in Qt5.""" + + +import pytest + + +def test_qstandarpath(): + """Test the qtpy.QStandardPaths namespace""" + from qtpy.QtCore import QStandardPaths + + assert QStandardPaths.StandardLocation is not None + + # Attributes from QDesktopServices shouldn't be in QStandardPaths + with pytest.raises(AttributeError): + QStandardPaths.setUrlHandler # noqa: B018 + + +def test_qdesktopservice(): + """Test the qtpy.QDesktopServices namespace""" + from qtpy.QtGui import QDesktopServices + + assert QDesktopServices.setUrlHandler is not None diff --git a/python3.10libs/qtpy/tests/test_qsci.py b/python3.10libs/qtpy/tests/test_qsci.py new file mode 100644 index 0000000..8f00158 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qsci.py @@ -0,0 +1,68 @@ +"""Test Qsci.""" + +import pytest + +from qtpy import PYSIDE2, PYSIDE6 +from qtpy.tests.utils import using_conda + + +@pytest.mark.skipif( + PYSIDE2 or PYSIDE6 or using_conda(), + reason="Qsci bindings not available under PySide 2/6 and conda installations", +) +def test_qsci(): + """Test the qtpy.Qsci namespace""" + Qsci = pytest.importorskip("qtpy.Qsci") + assert Qsci.QSCINTILLA_VERSION is not None + assert Qsci.QSCINTILLA_VERSION_STR is not None + assert Qsci.QsciAPIs is not None + assert Qsci.QsciAbstractAPIs is not None + assert Qsci.QsciCommand is not None + assert Qsci.QsciCommandSet is not None + assert Qsci.QsciDocument is not None + assert Qsci.QsciLexer is not None + assert Qsci.QsciLexerAVS is not None + assert Qsci.QsciLexerBash is not None + assert Qsci.QsciLexerBatch is not None + assert Qsci.QsciLexerCMake is not None + assert Qsci.QsciLexerCPP is not None + assert Qsci.QsciLexerCSS is not 
None + assert Qsci.QsciLexerCSharp is not None + assert Qsci.QsciLexerCoffeeScript is not None + assert Qsci.QsciLexerCustom is not None + assert Qsci.QsciLexerD is not None + assert Qsci.QsciLexerDiff is not None + assert Qsci.QsciLexerFortran is not None + assert Qsci.QsciLexerFortran77 is not None + assert Qsci.QsciLexerHTML is not None + assert Qsci.QsciLexerIDL is not None + assert Qsci.QsciLexerJSON is not None + assert Qsci.QsciLexerJava is not None + assert Qsci.QsciLexerJavaScript is not None + assert Qsci.QsciLexerLua is not None + assert Qsci.QsciLexerMakefile is not None + assert Qsci.QsciLexerMarkdown is not None + assert Qsci.QsciLexerMatlab is not None + assert Qsci.QsciLexerOctave is not None + assert Qsci.QsciLexerPO is not None + assert Qsci.QsciLexerPOV is not None + assert Qsci.QsciLexerPascal is not None + assert Qsci.QsciLexerPerl is not None + assert Qsci.QsciLexerPostScript is not None + assert Qsci.QsciLexerProperties is not None + assert Qsci.QsciLexerPython is not None + assert Qsci.QsciLexerRuby is not None + assert Qsci.QsciLexerSQL is not None + assert Qsci.QsciLexerSpice is not None + assert Qsci.QsciLexerTCL is not None + assert Qsci.QsciLexerTeX is not None + assert Qsci.QsciLexerVHDL is not None + assert Qsci.QsciLexerVerilog is not None + assert Qsci.QsciLexerXML is not None + assert Qsci.QsciLexerYAML is not None + assert Qsci.QsciMacro is not None + assert Qsci.QsciPrinter is not None + assert Qsci.QsciScintilla is not None + assert Qsci.QsciScintillaBase is not None + assert Qsci.QsciStyle is not None + assert Qsci.QsciStyledText is not None diff --git a/python3.10libs/qtpy/tests/test_qt3danimation.py b/python3.10libs/qtpy/tests/test_qt3danimation.py new file mode 100644 index 0000000..23171b0 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qt3danimation.py @@ -0,0 +1,22 @@ +import pytest + + +def test_qt3danimation(): + """Test the qtpy.Qt3DAnimation namespace""" + Qt3DAnimation = pytest.importorskip("qtpy.Qt3DAnimation") + 
+ assert Qt3DAnimation.QAnimationController is not None + assert Qt3DAnimation.QAdditiveClipBlend is not None + assert Qt3DAnimation.QAbstractClipBlendNode is not None + assert Qt3DAnimation.QAbstractAnimation is not None + assert Qt3DAnimation.QKeyframeAnimation is not None + assert Qt3DAnimation.QAbstractAnimationClip is not None + assert Qt3DAnimation.QAbstractClipAnimator is not None + assert Qt3DAnimation.QClipAnimator is not None + assert Qt3DAnimation.QAnimationGroup is not None + assert Qt3DAnimation.QLerpClipBlend is not None + assert Qt3DAnimation.QMorphingAnimation is not None + assert Qt3DAnimation.QAnimationAspect is not None + assert Qt3DAnimation.QVertexBlendAnimation is not None + assert Qt3DAnimation.QBlendedClipAnimator is not None + assert Qt3DAnimation.QMorphTarget is not None diff --git a/python3.10libs/qtpy/tests/test_qt3dcore.py b/python3.10libs/qtpy/tests/test_qt3dcore.py new file mode 100644 index 0000000..7cae916 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qt3dcore.py @@ -0,0 +1,45 @@ +import pytest + +from qtpy import PYQT6, PYSIDE6 + + +@pytest.mark.skipif(PYQT6, reason="Not complete in PyQt6") +@pytest.mark.skipif(PYSIDE6, reason="Not complete in PySide6") +def test_qt3dcore(): + """Test the qtpy.Qt3DCore namespace""" + Qt3DCore = pytest.importorskip("qtpy.Qt3DCore") + + assert Qt3DCore.QPropertyValueAddedChange is not None + assert Qt3DCore.QSkeletonLoader is not None + assert Qt3DCore.QPropertyNodeRemovedChange is not None + assert Qt3DCore.QPropertyUpdatedChange is not None + assert Qt3DCore.QAspectEngine is not None + assert Qt3DCore.QPropertyValueAddedChangeBase is not None + assert Qt3DCore.QStaticPropertyValueRemovedChangeBase is not None + assert Qt3DCore.QPropertyNodeAddedChange is not None + assert Qt3DCore.QDynamicPropertyUpdatedChange is not None + assert Qt3DCore.QStaticPropertyUpdatedChangeBase is not None + assert Qt3DCore.ChangeFlags is not None + assert Qt3DCore.QAbstractAspect is not None + assert 
Qt3DCore.QBackendNode is not None + assert Qt3DCore.QTransform is not None + assert Qt3DCore.QPropertyUpdatedChangeBase is not None + assert Qt3DCore.QNodeId is not None + assert Qt3DCore.QJoint is not None + assert Qt3DCore.QSceneChange is not None + assert Qt3DCore.QNodeIdTypePair is not None + assert Qt3DCore.QAbstractSkeleton is not None + assert Qt3DCore.QComponentRemovedChange is not None + assert Qt3DCore.QComponent is not None + assert Qt3DCore.QEntity is not None + assert Qt3DCore.QNodeCommand is not None + assert Qt3DCore.QNode is not None + assert Qt3DCore.QPropertyValueRemovedChange is not None + assert Qt3DCore.QPropertyValueRemovedChangeBase is not None + assert Qt3DCore.QComponentAddedChange is not None + assert Qt3DCore.QNodeCreatedChangeBase is not None + assert Qt3DCore.QNodeDestroyedChange is not None + assert Qt3DCore.QArmature is not None + assert Qt3DCore.QStaticPropertyValueAddedChangeBase is not None + assert Qt3DCore.ChangeFlag is not None + assert Qt3DCore.QSkeleton is not None diff --git a/python3.10libs/qtpy/tests/test_qt3dextras.py b/python3.10libs/qtpy/tests/test_qt3dextras.py new file mode 100644 index 0000000..ba3f0e1 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qt3dextras.py @@ -0,0 +1,44 @@ +import pytest + + +def test_qt3dextras(): + """Test the qtpy.Qt3DExtras namespace""" + Qt3DExtras = pytest.importorskip("qtpy.Qt3DExtras") + + assert Qt3DExtras.QTextureMaterial is not None + assert Qt3DExtras.QPhongAlphaMaterial is not None + assert Qt3DExtras.QOrbitCameraController is not None + assert Qt3DExtras.QAbstractSpriteSheet is not None + assert Qt3DExtras.QNormalDiffuseMapMaterial is not None + assert Qt3DExtras.QDiffuseSpecularMaterial is not None + assert Qt3DExtras.QSphereGeometry is not None + assert Qt3DExtras.QCuboidGeometry is not None + assert Qt3DExtras.QForwardRenderer is not None + assert Qt3DExtras.QPhongMaterial is not None + assert Qt3DExtras.QSpriteGrid is not None + assert Qt3DExtras.QDiffuseMapMaterial is not 
None + assert Qt3DExtras.QConeGeometry is not None + assert Qt3DExtras.QSpriteSheetItem is not None + assert Qt3DExtras.QPlaneGeometry is not None + assert Qt3DExtras.QSphereMesh is not None + assert Qt3DExtras.QNormalDiffuseSpecularMapMaterial is not None + assert Qt3DExtras.QCuboidMesh is not None + assert Qt3DExtras.QGoochMaterial is not None + assert Qt3DExtras.QText2DEntity is not None + assert Qt3DExtras.QTorusMesh is not None + assert Qt3DExtras.Qt3DWindow is not None + assert Qt3DExtras.QPerVertexColorMaterial is not None + assert Qt3DExtras.QExtrudedTextGeometry is not None + assert Qt3DExtras.QSkyboxEntity is not None + assert Qt3DExtras.QAbstractCameraController is not None + assert Qt3DExtras.QExtrudedTextMesh is not None + assert Qt3DExtras.QCylinderGeometry is not None + assert Qt3DExtras.QTorusGeometry is not None + assert Qt3DExtras.QMorphPhongMaterial is not None + assert Qt3DExtras.QPlaneMesh is not None + assert Qt3DExtras.QDiffuseSpecularMapMaterial is not None + assert Qt3DExtras.QSpriteSheet is not None + assert Qt3DExtras.QConeMesh is not None + assert Qt3DExtras.QFirstPersonCameraController is not None + assert Qt3DExtras.QMetalRoughMaterial is not None + assert Qt3DExtras.QCylinderMesh is not None diff --git a/python3.10libs/qtpy/tests/test_qt3dinput.py b/python3.10libs/qtpy/tests/test_qt3dinput.py new file mode 100644 index 0000000..562055e --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qt3dinput.py @@ -0,0 +1,29 @@ +import pytest + + +def test_qt3dinput(): + """Test the qtpy.Qt3DInput namespace""" + Qt3DInput = pytest.importorskip("qtpy.Qt3DInput") + + assert Qt3DInput.QAxisAccumulator is not None + assert Qt3DInput.QInputSettings is not None + assert Qt3DInput.QAnalogAxisInput is not None + assert Qt3DInput.QAbstractAxisInput is not None + assert Qt3DInput.QMouseHandler is not None + assert Qt3DInput.QButtonAxisInput is not None + assert Qt3DInput.QInputSequence is not None + assert Qt3DInput.QWheelEvent is not None + assert 
Qt3DInput.QActionInput is not None + assert Qt3DInput.QKeyboardDevice is not None + assert Qt3DInput.QMouseDevice is not None + assert Qt3DInput.QAxis is not None + assert Qt3DInput.QInputChord is not None + assert Qt3DInput.QMouseEvent is not None + assert Qt3DInput.QKeyboardHandler is not None + assert Qt3DInput.QKeyEvent is not None + assert Qt3DInput.QAbstractActionInput is not None + assert Qt3DInput.QInputAspect is not None + assert Qt3DInput.QLogicalDevice is not None + assert Qt3DInput.QAction is not None + assert Qt3DInput.QAbstractPhysicalDevice is not None + assert Qt3DInput.QAxisSetting is not None diff --git a/python3.10libs/qtpy/tests/test_qt3dlogic.py b/python3.10libs/qtpy/tests/test_qt3dlogic.py new file mode 100644 index 0000000..c325bf7 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qt3dlogic.py @@ -0,0 +1,9 @@ +import pytest + + +def test_qt3dlogic(): + """Test the qtpy.Qt3DLogic namespace""" + Qt3DLogic = pytest.importorskip("qtpy.Qt3DLogic") + + assert Qt3DLogic.QLogicAspect is not None + assert Qt3DLogic.QFrameAction is not None diff --git a/python3.10libs/qtpy/tests/test_qt3drender.py b/python3.10libs/qtpy/tests/test_qt3drender.py new file mode 100644 index 0000000..7f5d450 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qt3drender.py @@ -0,0 +1,120 @@ +import pytest + +from qtpy import PYQT6, PYSIDE6 + + +@pytest.mark.skipif(PYQT6, reason="Not complete in PyQt6") +@pytest.mark.skipif(PYSIDE6, reason="Not complete in PySide6") +def test_qt3drender(): + """Test the qtpy.Qt3DRender namespace""" + Qt3DRender = pytest.importorskip("qtpy.Qt3DRender") + + assert Qt3DRender.QPointSize is not None + assert Qt3DRender.QFrustumCulling is not None + assert Qt3DRender.QPickPointEvent is not None + assert Qt3DRender.QRenderPassFilter is not None + assert Qt3DRender.QMesh is not None + assert Qt3DRender.QRayCaster is not None + assert Qt3DRender.QStencilMask is not None + assert Qt3DRender.QPickLineEvent is not None + assert 
Qt3DRender.QPickTriangleEvent is not None + assert Qt3DRender.QRenderState is not None + assert Qt3DRender.QTextureWrapMode is not None + assert Qt3DRender.QRenderPass is not None + assert Qt3DRender.QGeometryRenderer is not None + assert Qt3DRender.QAttribute is not None + assert Qt3DRender.QStencilOperation is not None + assert Qt3DRender.QScissorTest is not None + assert Qt3DRender.QTextureCubeMapArray is not None + assert Qt3DRender.QRenderTarget is not None + assert Qt3DRender.QStencilTest is not None + assert Qt3DRender.QTextureData is not None + assert Qt3DRender.QBuffer is not None + assert Qt3DRender.QLineWidth is not None + assert Qt3DRender.QLayer is not None + assert Qt3DRender.QTextureRectangle is not None + assert Qt3DRender.QRenderTargetSelector is not None + assert Qt3DRender.QPickingSettings is not None + assert Qt3DRender.QCullFace is not None + assert Qt3DRender.QAbstractFunctor is not None + assert Qt3DRender.PropertyReaderInterface is not None + assert Qt3DRender.QMaterial is not None + assert Qt3DRender.QAlphaCoverage is not None + assert Qt3DRender.QClearBuffers is not None + assert Qt3DRender.QAlphaTest is not None + assert Qt3DRender.QStencilOperationArguments is not None + assert Qt3DRender.QTexture2DMultisample is not None + assert Qt3DRender.QLevelOfDetailSwitch is not None + assert Qt3DRender.QRenderStateSet is not None + assert Qt3DRender.QViewport is not None + assert Qt3DRender.QObjectPicker is not None + assert Qt3DRender.QPolygonOffset is not None + assert Qt3DRender.QRenderSettings is not None + assert Qt3DRender.QFrontFace is not None + assert Qt3DRender.QTexture3D is not None + assert Qt3DRender.QTextureBuffer is not None + assert Qt3DRender.QTechniqueFilter is not None + assert Qt3DRender.QLayerFilter is not None + assert Qt3DRender.QFilterKey is not None + assert Qt3DRender.QRenderSurfaceSelector is not None + assert Qt3DRender.QEnvironmentLight is not None + assert Qt3DRender.QMemoryBarrier is not None + assert 
Qt3DRender.QNoDepthMask is not None + assert Qt3DRender.QBlitFramebuffer is not None + assert Qt3DRender.QGraphicsApiFilter is not None + assert Qt3DRender.QAbstractTexture is not None + assert Qt3DRender.QRenderCaptureReply is not None + assert Qt3DRender.QAbstractLight is not None + assert Qt3DRender.QAbstractRayCaster is not None + assert Qt3DRender.QDirectionalLight is not None + assert Qt3DRender.QDispatchCompute is not None + assert Qt3DRender.QBufferDataGenerator is not None + assert Qt3DRender.QPointLight is not None + assert Qt3DRender.QStencilTestArguments is not None + assert Qt3DRender.QTexture1D is not None + assert Qt3DRender.QCameraSelector is not None + assert Qt3DRender.QProximityFilter is not None + assert Qt3DRender.QTexture1DArray is not None + assert Qt3DRender.QBlendEquation is not None + assert Qt3DRender.QTextureImageDataGenerator is not None + assert Qt3DRender.QSpotLight is not None + assert Qt3DRender.QEffect is not None + assert Qt3DRender.QSeamlessCubemap is not None + assert Qt3DRender.QTexture2DMultisampleArray is not None + assert Qt3DRender.QComputeCommand is not None + assert Qt3DRender.QFrameGraphNode is not None + assert Qt3DRender.QSortPolicy is not None + assert Qt3DRender.QTextureImageData is not None + assert Qt3DRender.QCamera is not None + assert Qt3DRender.QGeometry is not None + assert Qt3DRender.QScreenRayCaster is not None + assert Qt3DRender.QClipPlane is not None + assert Qt3DRender.QMultiSampleAntiAliasing is not None + assert Qt3DRender.QRayCasterHit is not None + assert Qt3DRender.QAbstractTextureImage is not None + assert Qt3DRender.QNoDraw is not None + assert Qt3DRender.QPickEvent is not None + assert Qt3DRender.QRenderCapture is not None + assert Qt3DRender.QDepthTest is not None + assert Qt3DRender.QParameter is not None + assert Qt3DRender.QLevelOfDetail is not None + assert Qt3DRender.QGeometryFactory is not None + assert Qt3DRender.QTexture2D is not None + assert Qt3DRender.QRenderAspect is not None + 
assert Qt3DRender.QPaintedTextureImage is not None + assert Qt3DRender.QDithering is not None + assert Qt3DRender.QTextureGenerator is not None + assert Qt3DRender.QBlendEquationArguments is not None + assert Qt3DRender.QLevelOfDetailBoundingSphere is not None + assert Qt3DRender.QColorMask is not None + assert Qt3DRender.QSceneLoader is not None + assert Qt3DRender.QTextureLoader is not None + assert Qt3DRender.QShaderProgram is not None + assert Qt3DRender.QTextureCubeMap is not None + assert Qt3DRender.QTexture2DArray is not None + assert Qt3DRender.QTextureImage is not None + assert Qt3DRender.QCameraLens is not None + assert Qt3DRender.QRenderTargetOutput is not None + assert Qt3DRender.QShaderProgramBuilder is not None + assert Qt3DRender.QTechnique is not None + assert Qt3DRender.QShaderData is not None diff --git a/python3.10libs/qtpy/tests/test_qtaxcontainer.py b/python3.10libs/qtpy/tests/test_qtaxcontainer.py new file mode 100644 index 0000000..6e31a15 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtaxcontainer.py @@ -0,0 +1,9 @@ +import pytest + + +def test_qtaxcontainer(): + """Test the qtpy.QtAxContainer namespace""" + QtAxContainer = pytest.importorskip("qtpy.QtAxContainer") + + assert QtAxContainer.QAxSelect is not None + assert QtAxContainer.QAxWidget is not None diff --git a/python3.10libs/qtpy/tests/test_qtbluetooth.py b/python3.10libs/qtpy/tests/test_qtbluetooth.py new file mode 100644 index 0000000..f9294e9 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtbluetooth.py @@ -0,0 +1,14 @@ +import pytest + + +def test_qtbluetooth(): + """Test the qtpy.QtBluetooth namespace""" + QtBluetooth = pytest.importorskip("qtpy.QtBluetooth") + + assert QtBluetooth.QBluetooth is not None + assert QtBluetooth.QBluetoothDeviceInfo is not None + assert QtBluetooth.QBluetoothServer is not None + assert QtBluetooth.QBluetoothSocket is not None + assert QtBluetooth.QBluetoothAddress is not None + assert QtBluetooth.QBluetoothUuid is not None + assert 
QtBluetooth.QBluetoothServiceDiscoveryAgent is not None diff --git a/python3.10libs/qtpy/tests/test_qtcharts.py b/python3.10libs/qtpy/tests/test_qtcharts.py new file mode 100644 index 0000000..4cce6f9 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtcharts.py @@ -0,0 +1,15 @@ +import pytest + +from qtpy import PYSIDE2, PYSIDE6 + + +@pytest.mark.skipif( + not (PYSIDE2 or PYSIDE6), + reason="Only available by default in PySide", +) +def test_qtcharts(): + """Test the qtpy.QtCharts namespace""" + QtCharts = pytest.importorskip("qtpy.QtCharts") + + assert QtCharts.QChart is not None + assert QtCharts.QtCharts.QChart is not None diff --git a/python3.10libs/qtpy/tests/test_qtconcurrent.py b/python3.10libs/qtpy/tests/test_qtconcurrent.py new file mode 100644 index 0000000..114de18 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtconcurrent.py @@ -0,0 +1,17 @@ +import pytest +from packaging.version import parse + +from qtpy import PYSIDE2, PYSIDE_VERSION + + +def test_qtconcurrent(): + """Test the qtpy.QtConcurrent namespace""" + QtConcurrent = pytest.importorskip("qtpy.QtConcurrent") + + assert QtConcurrent.QtConcurrent is not None + + if PYSIDE2 and parse(PYSIDE_VERSION) >= parse("5.15.2"): + assert QtConcurrent.QFutureQString is not None + assert QtConcurrent.QFutureVoid is not None + assert QtConcurrent.QFutureWatcherQString is not None + assert QtConcurrent.QFutureWatcherVoid is not None diff --git a/python3.10libs/qtpy/tests/test_qtcore.py b/python3.10libs/qtpy/tests/test_qtcore.py new file mode 100644 index 0000000..4f98156 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtcore.py @@ -0,0 +1,211 @@ +"""Test QtCore.""" + +import enum +import sys +from datetime import date, datetime, time + +import pytest +from packaging.version import parse + +from qtpy import ( + PYQT5, + PYQT6, + PYQT_VERSION, + PYSIDE2, + PYSIDE_VERSION, + QtCore, +) + +_now = datetime.now() +# Make integer milliseconds; `floor` here, don't `round`! 
+NOW = _now.replace(microsecond=(_now.microsecond // 1000 * 1000)) + + +def test_qtmsghandler(): + """Test qtpy.QtMsgHandler""" + assert QtCore.qInstallMessageHandler is not None + + +@pytest.mark.parametrize("method", ["toPython", "toPyDateTime"]) +def test_QDateTime_toPython_and_toPyDateTime(method): + """Test `QDateTime.toPython` and `QDateTime.toPyDateTime`""" + q_datetime = QtCore.QDateTime(NOW) + py_datetime = getattr(q_datetime, method)() + assert isinstance(py_datetime, datetime) + assert py_datetime == NOW + + +@pytest.mark.parametrize("method", ["toPython", "toPyDate"]) +def test_QDate_toPython_and_toPyDate(method): + """Test `QDate.toPython` and `QDate.toPyDate`""" + q_date = QtCore.QDateTime(NOW).date() + py_date = getattr(q_date, method)() + assert isinstance(py_date, date) + assert py_date == NOW.date() + + +@pytest.mark.parametrize("method", ["toPython", "toPyTime"]) +def test_QTime_toPython_and_toPyTime(method): + """Test `QTime.toPython` and `QTime.toPyTime`""" + q_time = QtCore.QDateTime(NOW).time() + py_time = getattr(q_time, method)() + assert isinstance(py_time, time) + assert py_time == NOW.time() + + +def test_qeventloop_exec(qtbot): + """Test `QEventLoop.exec_` and `QEventLoop.exec`""" + assert QtCore.QEventLoop.exec_ is not None + assert QtCore.QEventLoop.exec is not None + event_loop = QtCore.QEventLoop(None) + QtCore.QTimer.singleShot(100, event_loop.quit) + event_loop.exec_() + QtCore.QTimer.singleShot(100, event_loop.quit) + event_loop.exec() + + +def test_qthread_exec(): + """Test `QThread.exec_` and `QThread.exec`""" + assert QtCore.QThread.exec_ is not None + assert QtCore.QThread.exec is not None + + +@pytest.mark.skipif( + PYSIDE2 and parse(PYSIDE_VERSION) < parse("5.15"), + reason="QEnum macro doesn't seem to be present on PySide2 <5.15", +) +def test_qenum(): + """Test QEnum macro""" + + class EnumTest(QtCore.QObject): + class Position(enum.IntEnum): + West = 0 + North = 1 + South = 2 + East = 3 + + QtCore.QEnum(Position) + + 
obj = EnumTest() + assert obj.metaObject().enumerator(0).name() == "Position" + + +def test_QLibraryInfo_location_and_path(): + """Test `QLibraryInfo.location` and `QLibraryInfo.path`""" + assert QtCore.QLibraryInfo.location is not None + assert ( + QtCore.QLibraryInfo.location(QtCore.QLibraryInfo.PrefixPath) + is not None + ) + assert QtCore.QLibraryInfo.path is not None + assert QtCore.QLibraryInfo.path(QtCore.QLibraryInfo.PrefixPath) is not None + + +def test_QLibraryInfo_LibraryLocation_and_LibraryPath(): + """Test `QLibraryInfo.LibraryLocation` and `QLibraryInfo.LibraryPath`""" + assert QtCore.QLibraryInfo.LibraryLocation is not None + assert QtCore.QLibraryInfo.LibraryPath is not None + + +def test_QCoreApplication_exec_(qapp): + """Test `QtCore.QCoreApplication.exec_`""" + assert QtCore.QCoreApplication.exec_ is not None + app = QtCore.QCoreApplication.instance() or QtCore.QCoreApplication( + [sys.executable, __file__], + ) + assert app is not None + QtCore.QTimer.singleShot(100, QtCore.QCoreApplication.instance().quit) + QtCore.QCoreApplication.exec_() + app = QtCore.QCoreApplication.instance() or QtCore.QCoreApplication( + [sys.executable, __file__], + ) + assert app is not None + QtCore.QTimer.singleShot(100, QtCore.QCoreApplication.instance().quit) + app.exec_() + + +def test_QCoreApplication_exec(qapp): + """Test `QtCore.QCoreApplication.exec`""" + assert QtCore.QCoreApplication.exec is not None + app = QtCore.QCoreApplication.instance() or QtCore.QCoreApplication( + [sys.executable, __file__], + ) + assert app is not None + QtCore.QTimer.singleShot(100, QtCore.QCoreApplication.instance().quit) + QtCore.QCoreApplication.exec() + app = QtCore.QCoreApplication.instance() or QtCore.QCoreApplication( + [sys.executable, __file__], + ) + assert app is not None + QtCore.QTimer.singleShot(100, QtCore.QCoreApplication.instance().quit) + app.exec() + + +@pytest.mark.skipif( + PYQT5 or PYQT6, + reason="Doesn't seem to be present on PyQt5 and PyQt6", +) +def 
test_qtextstreammanipulator_exec(): + """Test `QTextStreamManipulator.exec_` and `QTextStreamManipulator.exec`""" + assert QtCore.QTextStreamManipulator.exec_ is not None + assert QtCore.QTextStreamManipulator.exec is not None + + +@pytest.mark.skipif( + PYSIDE2 or PYQT6, + reason="Doesn't seem to be present on PySide2 and PyQt6", +) +def test_QtCore_SignalInstance(): + class ClassWithSignal(QtCore.QObject): + signal = QtCore.Signal() + + instance = ClassWithSignal() + + assert isinstance(instance.signal, QtCore.SignalInstance) + + +@pytest.mark.skipif( + PYQT5 and PYQT_VERSION.startswith("5.9"), + reason="A specific setup with at least sip 4.9.9 is needed for PyQt5 5.9.*" + "to work with scoped enum access", +) +def test_enum_access(): + """Test scoped and unscoped enum access for qtpy.QtCore.*.""" + assert ( + QtCore.QAbstractAnimation.Stopped + == QtCore.QAbstractAnimation.State.Stopped + ) + assert QtCore.QEvent.ActionAdded == QtCore.QEvent.Type.ActionAdded + assert QtCore.Qt.AlignLeft == QtCore.Qt.AlignmentFlag.AlignLeft + assert QtCore.Qt.Key_Return == QtCore.Qt.Key.Key_Return + assert QtCore.Qt.transparent == QtCore.Qt.GlobalColor.transparent + assert QtCore.Qt.Widget == QtCore.Qt.WindowType.Widget + assert QtCore.Qt.BackButton == QtCore.Qt.MouseButton.BackButton + assert QtCore.Qt.XButton1 == QtCore.Qt.MouseButton.XButton1 + assert ( + QtCore.Qt.BackgroundColorRole + == QtCore.Qt.ItemDataRole.BackgroundColorRole + ) + assert QtCore.Qt.TextColorRole == QtCore.Qt.ItemDataRole.TextColorRole + assert QtCore.Qt.MidButton == QtCore.Qt.MouseButton.MiddleButton + + +@pytest.mark.skipif( + PYSIDE2 and PYSIDE_VERSION.startswith("5.12.0"), + reason="Utility functions unavailable for PySide2 5.12.0", +) +def test_qtgui_namespace_mightBeRichText(): + """ + Test included elements (mightBeRichText) from module QtGui. 
+ + See: https://doc.qt.io/qt-5/qt-sub-qtgui.html + """ + assert QtCore.Qt.mightBeRichText is not None + + +def test_itemflags_typedef(): + """ + Test existence of `QFlags` typedef `ItemFlags` that was removed from PyQt6 + """ + assert QtCore.Qt.ItemFlags is not None + assert QtCore.Qt.ItemFlags() == QtCore.Qt.ItemFlag(0) diff --git a/python3.10libs/qtpy/tests/test_qtdatavisualization.py b/python3.10libs/qtpy/tests/test_qtdatavisualization.py new file mode 100644 index 0000000..5bbf7ca --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtdatavisualization.py @@ -0,0 +1,86 @@ +import pytest + + +def test_qtdatavisualization(): + """Test the qtpy.QtDataVisualization namespace""" + # Using import skip here since with Python 3 you need to install another package + # besides the base `PyQt5` or `PySide2`. + # For example in the case of `PyQt5` you need `PyQtDataVisualization` + + # QtDataVisualization + QtDataVisualization = pytest.importorskip("qtpy.QtDataVisualization") + assert QtDataVisualization.QScatter3DSeries is not None + assert QtDataVisualization.QSurfaceDataItem is not None + assert QtDataVisualization.QSurface3DSeries is not None + assert QtDataVisualization.QAbstract3DInputHandler is not None + assert QtDataVisualization.QHeightMapSurfaceDataProxy is not None + assert QtDataVisualization.QAbstractDataProxy is not None + assert QtDataVisualization.Q3DCamera is not None + assert QtDataVisualization.QAbstract3DGraph is not None + assert QtDataVisualization.QCustom3DVolume is not None + assert QtDataVisualization.Q3DInputHandler is not None + assert QtDataVisualization.QBarDataProxy is not None + assert QtDataVisualization.QSurfaceDataProxy is not None + assert QtDataVisualization.QScatterDataItem is not None + assert QtDataVisualization.Q3DLight is not None + assert QtDataVisualization.QScatterDataProxy is not None + assert QtDataVisualization.QValue3DAxis is not None + assert QtDataVisualization.Q3DBars is not None + assert QtDataVisualization.QBarDataItem 
is not None + assert QtDataVisualization.QItemModelBarDataProxy is not None + assert QtDataVisualization.Q3DTheme is not None + assert QtDataVisualization.QCustom3DItem is not None + assert QtDataVisualization.QItemModelScatterDataProxy is not None + assert QtDataVisualization.QValue3DAxisFormatter is not None + assert QtDataVisualization.QItemModelSurfaceDataProxy is not None + assert QtDataVisualization.Q3DScatter is not None + assert QtDataVisualization.QTouch3DInputHandler is not None + assert QtDataVisualization.QBar3DSeries is not None + assert QtDataVisualization.QAbstract3DAxis is not None + assert QtDataVisualization.Q3DScene is not None + assert QtDataVisualization.QCategory3DAxis is not None + assert QtDataVisualization.QAbstract3DSeries is not None + assert QtDataVisualization.Q3DObject is not None + assert QtDataVisualization.QCustom3DLabel is not None + assert QtDataVisualization.Q3DSurface is not None + assert QtDataVisualization.QLogValue3DAxisFormatter is not None + + # QtDatavisualization + # import qtpy to get alias for `QtDataVisualization` with lower `v` + qtpy = pytest.importorskip("qtpy") + + assert qtpy.QtDatavisualization.QScatter3DSeries is not None + assert qtpy.QtDatavisualization.QSurfaceDataItem is not None + assert qtpy.QtDatavisualization.QSurface3DSeries is not None + assert qtpy.QtDatavisualization.QAbstract3DInputHandler is not None + assert qtpy.QtDatavisualization.QHeightMapSurfaceDataProxy is not None + assert qtpy.QtDatavisualization.QAbstractDataProxy is not None + assert qtpy.QtDatavisualization.Q3DCamera is not None + assert qtpy.QtDatavisualization.QAbstract3DGraph is not None + assert qtpy.QtDatavisualization.QCustom3DVolume is not None + assert qtpy.QtDatavisualization.Q3DInputHandler is not None + assert qtpy.QtDatavisualization.QBarDataProxy is not None + assert qtpy.QtDatavisualization.QSurfaceDataProxy is not None + assert qtpy.QtDatavisualization.QScatterDataItem is not None + assert 
qtpy.QtDatavisualization.Q3DLight is not None + assert qtpy.QtDatavisualization.QScatterDataProxy is not None + assert qtpy.QtDatavisualization.QValue3DAxis is not None + assert qtpy.QtDatavisualization.Q3DBars is not None + assert qtpy.QtDatavisualization.QBarDataItem is not None + assert qtpy.QtDatavisualization.QItemModelBarDataProxy is not None + assert qtpy.QtDatavisualization.Q3DTheme is not None + assert qtpy.QtDatavisualization.QCustom3DItem is not None + assert qtpy.QtDatavisualization.QItemModelScatterDataProxy is not None + assert qtpy.QtDatavisualization.QValue3DAxisFormatter is not None + assert qtpy.QtDatavisualization.QItemModelSurfaceDataProxy is not None + assert qtpy.QtDatavisualization.Q3DScatter is not None + assert qtpy.QtDatavisualization.QTouch3DInputHandler is not None + assert qtpy.QtDatavisualization.QBar3DSeries is not None + assert qtpy.QtDatavisualization.QAbstract3DAxis is not None + assert qtpy.QtDatavisualization.Q3DScene is not None + assert qtpy.QtDatavisualization.QCategory3DAxis is not None + assert qtpy.QtDatavisualization.QAbstract3DSeries is not None + assert qtpy.QtDatavisualization.Q3DObject is not None + assert qtpy.QtDatavisualization.QCustom3DLabel is not None + assert qtpy.QtDatavisualization.Q3DSurface is not None + assert qtpy.QtDatavisualization.QLogValue3DAxisFormatter is not None diff --git a/python3.10libs/qtpy/tests/test_qtdbus.py b/python3.10libs/qtpy/tests/test_qtdbus.py new file mode 100644 index 0000000..5594692 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtdbus.py @@ -0,0 +1,11 @@ +import pytest + + +def test_qtdbus(): + """Test the qtpy.QtDBus namespace""" + QtDBus = pytest.importorskip("qtpy.QtDBus") + + assert QtDBus.QDBusAbstractAdaptor is not None + assert QtDBus.QDBusAbstractInterface is not None + assert QtDBus.QDBusArgument is not None + assert QtDBus.QDBusConnection is not None diff --git a/python3.10libs/qtpy/tests/test_qtdesigner.py b/python3.10libs/qtpy/tests/test_qtdesigner.py new file 
mode 100644 index 0000000..206390d --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtdesigner.py @@ -0,0 +1,29 @@ +import pytest + +from qtpy import PYSIDE2 + + +@pytest.mark.skipif(PYSIDE2, reason="QtDesigner is not available in PySide2") +def test_qtdesigner(): + """Test the qtpy.QtDesigner namespace.""" + QtDesigner = pytest.importorskip("qtpy.QtDesigner") + + assert QtDesigner.QAbstractExtensionFactory is not None + assert QtDesigner.QAbstractExtensionManager is not None + assert QtDesigner.QDesignerActionEditorInterface is not None + assert QtDesigner.QDesignerContainerExtension is not None + assert QtDesigner.QDesignerCustomWidgetCollectionInterface is not None + assert QtDesigner.QDesignerCustomWidgetInterface is not None + assert QtDesigner.QDesignerFormEditorInterface is not None + assert QtDesigner.QDesignerFormWindowCursorInterface is not None + assert QtDesigner.QDesignerFormWindowInterface is not None + assert QtDesigner.QDesignerFormWindowManagerInterface is not None + assert QtDesigner.QDesignerMemberSheetExtension is not None + assert QtDesigner.QDesignerObjectInspectorInterface is not None + assert QtDesigner.QDesignerPropertyEditorInterface is not None + assert QtDesigner.QDesignerPropertySheetExtension is not None + assert QtDesigner.QDesignerTaskMenuExtension is not None + assert QtDesigner.QDesignerWidgetBoxInterface is not None + assert QtDesigner.QExtensionFactory is not None + assert QtDesigner.QExtensionManager is not None + assert QtDesigner.QFormBuilder is not None diff --git a/python3.10libs/qtpy/tests/test_qtgui.py b/python3.10libs/qtpy/tests/test_qtgui.py new file mode 100644 index 0000000..1aa4f47 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtgui.py @@ -0,0 +1,202 @@ +"""Test QtGui.""" + +import sys + +import pytest + +from qtpy import ( + PYQT5, + PYQT_VERSION, + PYSIDE2, + PYSIDE6, + QtCore, + QtGui, + QtWidgets, +) +from qtpy.tests.utils import not_using_conda + + +def test_qfontmetrics_width(qtbot): + """Test 
QFontMetrics and QFontMetricsF width""" + assert QtGui.QFontMetrics.width is not None + assert QtGui.QFontMetricsF.width is not None + font = QtGui.QFont("times", 24) + font_metrics = QtGui.QFontMetrics(font) + font_metricsF = QtGui.QFontMetricsF(font) + width = font_metrics.width("Test") + widthF = font_metricsF.width("Test") + assert width in range(40, 62) + assert 39 <= widthF <= 63 + + +def test_qdrag_functions(qtbot): + """Test functions mapping for QtGui.QDrag.""" + assert QtGui.QDrag.exec_ is not None + drag = QtGui.QDrag(None) + drag.exec_() + + +def test_QGuiApplication_exec_(): + """Test `QtGui.QGuiApplication.exec_`""" + assert QtGui.QGuiApplication.exec_ is not None + app = QtGui.QGuiApplication.instance() or QtGui.QGuiApplication( + [sys.executable, __file__], + ) + assert app is not None + QtCore.QTimer.singleShot(100, QtGui.QGuiApplication.instance().quit) + QtGui.QGuiApplication.exec_() + app = QtGui.QGuiApplication.instance() or QtGui.QGuiApplication( + [sys.executable, __file__], + ) + assert app is not None + QtCore.QTimer.singleShot(100, QtGui.QGuiApplication.instance().quit) + app.exec_() + + +def test_what_moved_to_qtgui_in_qt6(): + """Test what has been moved to QtGui in Qt6""" + assert QtGui.QAction is not None + assert QtGui.QActionGroup is not None + assert QtGui.QFileSystemModel is not None + assert QtGui.QShortcut is not None + assert QtGui.QUndoCommand is not None + + +def test_qtextdocument_functions(pdf_writer): + """Test functions mapping for QtGui.QTextDocument.""" + assert QtGui.QTextDocument.print_ is not None + text_document = QtGui.QTextDocument("Test") + print_device, output_path = pdf_writer + text_document.print_(print_device) + assert output_path.exists() + + +@pytest.mark.skipif( + PYQT5 and PYQT_VERSION.startswith("5.9"), + reason="A specific setup with at least sip 4.9.9 is needed for PyQt5 5.9.*" + "to work with scoped enum access", +) +def test_enum_access(): + """Test scoped and unscoped enum access for 
qtpy.QtGui.*.""" + assert QtGui.QColor.Rgb == QtGui.QColor.Spec.Rgb + assert QtGui.QFont.AllUppercase == QtGui.QFont.Capitalization.AllUppercase + assert QtGui.QIcon.Normal == QtGui.QIcon.Mode.Normal + assert QtGui.QImage.Format_Invalid == QtGui.QImage.Format.Format_Invalid + + +@pytest.mark.skipif( + sys.platform == "darwin" and sys.version_info[:2] == (3, 7), + reason="Stalls on macOS CI with Python 3.7", +) +def test_QSomethingEvent_pos_functions(qtbot): + """ + Test `QMouseEvent.pos` and related functions removed in Qt 6, + and `QMouseEvent.position`, etc., missing from Qt 5. + """ + + class Window(QtWidgets.QMainWindow): + def mouseDoubleClickEvent(self, event: QtGui.QMouseEvent) -> None: + assert event.globalPos() - event.pos() == self.mapToParent( + QtCore.QPoint(0, 0), + ) + assert event.pos().x() == event.x() + assert event.pos().y() == event.y() + assert event.globalPos().x() == event.globalX() + assert event.globalPos().y() == event.globalY() + assert event.position().x() == event.pos().x() + assert event.position().y() == event.pos().y() + assert event.globalPosition().x() == event.globalPos().x() + assert event.globalPosition().y() == event.globalPos().y() + + event.accept() + + window = Window() + window.setMinimumSize(320, 240) # ensure the window is of sufficient size + window.show() + + with qtbot.waitExposed(window): + qtbot.mouseMove(window, QtCore.QPoint(42, 6 * 9)) + qtbot.mouseDClick(window, QtCore.Qt.LeftButton) + + # the rest of the functions are not actually tested + # QSinglePointEvent (Qt6) child classes checks + for _class in ("QNativeGestureEvent", "QEnterEvent", "QTabletEvent"): + for _function in ( + "pos", + "x", + "y", + "globalPos", + "globalX", + "globalY", + "position", + "globalPosition", + ): + assert hasattr(getattr(QtGui, _class), _function) + + # QHoverEvent checks + for _function in ("pos", "x", "y", "position"): + assert hasattr(QtGui.QHoverEvent, _function) + + # QDropEvent and child classes checks + for _class in 
("QDropEvent", "QDragMoveEvent", "QDragEnterEvent"): + for _function in ("pos", "posF", "position"): + assert hasattr(getattr(QtGui, _class), _function) + + +@pytest.mark.skipif( + not (PYSIDE2 or PYSIDE6), + reason="PySide{2,6} specific test", +) +def test_qtextcursor_moveposition(): + """Test monkeypatched QTextCursor.movePosition""" + doc = QtGui.QTextDocument("foo bar baz") + cursor = QtGui.QTextCursor(doc) + + assert not cursor.movePosition(QtGui.QTextCursor.Start) + assert cursor.movePosition( + QtGui.QTextCursor.EndOfWord, + mode=QtGui.QTextCursor.KeepAnchor, + ) + assert cursor.selectedText() == "foo" + + assert cursor.movePosition(QtGui.QTextCursor.Start) + assert cursor.movePosition( + QtGui.QTextCursor.WordRight, + n=2, + mode=QtGui.QTextCursor.KeepAnchor, + ) + assert cursor.selectedText() == "foo bar " + + assert cursor.movePosition(QtGui.QTextCursor.Start) + assert cursor.position() == cursor.anchor() + assert cursor.movePosition( + QtGui.QTextCursor.NextWord, + QtGui.QTextCursor.KeepAnchor, + 3, + ) + assert cursor.selectedText() == "foo bar baz" + + +def test_opengl_imports(): + """ + Test for presence of QOpenGL* classes. + + These classes were members of QtGui in Qt5, but moved to QtOpenGL in Qt6. + QtPy makes them available in QtGui to maintain compatibility. 
+ """ + + assert QtGui.QOpenGLBuffer is not None + assert QtGui.QOpenGLContext is not None + assert QtGui.QOpenGLContextGroup is not None + assert QtGui.QOpenGLDebugLogger is not None + assert QtGui.QOpenGLDebugMessage is not None + assert QtGui.QOpenGLFramebufferObject is not None + assert QtGui.QOpenGLFramebufferObjectFormat is not None + assert QtGui.QOpenGLPixelTransferOptions is not None + assert QtGui.QOpenGLShader is not None + assert QtGui.QOpenGLShaderProgram is not None + assert QtGui.QOpenGLTexture is not None + assert QtGui.QOpenGLTextureBlitter is not None + assert QtGui.QOpenGLVersionProfile is not None + assert QtGui.QOpenGLVertexArrayObject is not None + assert QtGui.QOpenGLWindow is not None diff --git a/python3.10libs/qtpy/tests/test_qthelp.py b/python3.10libs/qtpy/tests/test_qthelp.py new file mode 100644 index 0000000..1107bc5 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qthelp.py @@ -0,0 +1,18 @@ +"""Test for QtHelp namespace.""" + + +def test_qthelp(): + """Test the qtpy.QtHelp namespace.""" + from qtpy import QtHelp + + assert QtHelp.QHelpContentItem is not None + assert QtHelp.QHelpContentModel is not None + assert QtHelp.QHelpContentWidget is not None + assert QtHelp.QHelpEngine is not None + assert QtHelp.QHelpEngineCore is not None + assert QtHelp.QHelpIndexModel is not None + assert QtHelp.QHelpIndexWidget is not None + assert QtHelp.QHelpSearchEngine is not None + assert QtHelp.QHelpSearchQuery is not None + assert QtHelp.QHelpSearchQueryWidget is not None + assert QtHelp.QHelpSearchResultWidget is not None diff --git a/python3.10libs/qtpy/tests/test_qtlocation.py b/python3.10libs/qtpy/tests/test_qtlocation.py new file mode 100644 index 0000000..f23a388 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtlocation.py @@ -0,0 +1,54 @@ +import pytest + +from qtpy import PYQT5, PYSIDE2 + + +@pytest.mark.skipif( + not (PYQT5 or PYSIDE2), + reason="Only available in Qt5 bindings", +) +def test_qtlocation(): + """Test the 
qtpy.QtLocation namespace""" + from qtpy import QtLocation + + if PYSIDE2: + assert QtLocation.QGeoServiceProviderFactory is not None + + assert QtLocation.QGeoCodeReply is not None + assert QtLocation.QGeoCodingManager is not None + assert QtLocation.QGeoCodingManagerEngine is not None + assert QtLocation.QGeoManeuver is not None + assert QtLocation.QGeoRoute is not None + assert QtLocation.QGeoRouteReply is not None + assert QtLocation.QGeoRouteRequest is not None + assert QtLocation.QGeoRouteSegment is not None + assert QtLocation.QGeoRoutingManager is not None + assert QtLocation.QGeoRoutingManagerEngine is not None + assert QtLocation.QGeoServiceProvider is not None + assert QtLocation.QPlace is not None + assert QtLocation.QPlaceAttribute is not None + assert QtLocation.QPlaceCategory is not None + assert QtLocation.QPlaceContactDetail is not None + assert QtLocation.QPlaceContent is not None + assert QtLocation.QPlaceContentReply is not None + assert QtLocation.QPlaceContentRequest is not None + assert QtLocation.QPlaceDetailsReply is not None + assert QtLocation.QPlaceEditorial is not None + assert QtLocation.QPlaceIcon is not None + assert QtLocation.QPlaceIdReply is not None + assert QtLocation.QPlaceImage is not None + assert QtLocation.QPlaceManager is not None + assert QtLocation.QPlaceManagerEngine is not None + assert QtLocation.QPlaceMatchReply is not None + assert QtLocation.QPlaceMatchRequest is not None + assert QtLocation.QPlaceProposedSearchResult is not None + assert QtLocation.QPlaceRatings is not None + assert QtLocation.QPlaceReply is not None + assert QtLocation.QPlaceResult is not None + assert QtLocation.QPlaceReview is not None + assert QtLocation.QPlaceSearchReply is not None + assert QtLocation.QPlaceSearchRequest is not None + assert QtLocation.QPlaceSearchResult is not None + assert QtLocation.QPlaceSearchSuggestionReply is not None + assert QtLocation.QPlaceSupplier is not None + assert QtLocation.QPlaceUser is not None diff --git 
a/python3.10libs/qtpy/tests/test_qtmacextras.py b/python3.10libs/qtpy/tests/test_qtmacextras.py new file mode 100644 index 0000000..1f33b61 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtmacextras.py @@ -0,0 +1,23 @@ +import sys + +import pytest + +from qtpy import PYQT6, PYSIDE6 +from qtpy.tests.utils import using_conda + + +@pytest.mark.skipif( + PYQT6 or PYSIDE6, + reason="Not available on Qt6-based bindings", +) +@pytest.mark.skipif( + sys.platform != "darwin" or using_conda(), + reason="Only available in Qt5 bindings > 5.9 with pip on mac in CIs", +) +def test_qtmacextras(): + """Test the qtpy.QtMacExtras namespace""" + QtMacExtras = pytest.importorskip("qtpy.QtMacExtras") + + assert QtMacExtras.QMacPasteboardMime is not None + assert QtMacExtras.QMacToolBar is not None + assert QtMacExtras.QMacToolBarItem is not None diff --git a/python3.10libs/qtpy/tests/test_qtmultimedia.py b/python3.10libs/qtpy/tests/test_qtmultimedia.py new file mode 100644 index 0000000..354bcd6 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtmultimedia.py @@ -0,0 +1,18 @@ +import sys + +import pytest + +from qtpy import PYQT6, PYSIDE6 + + +def test_qtmultimedia(): + """Test the qtpy.QtMultimedia namespace""" + from qtpy import QtMultimedia + + assert QtMultimedia.QAudio is not None + assert QtMultimedia.QAudioInput is not None + + if not (PYSIDE6 or PYQT6): + assert QtMultimedia.QAbstractVideoBuffer is not None + assert QtMultimedia.QAudioDeviceInfo is not None + assert QtMultimedia.QSound is not None diff --git a/python3.10libs/qtpy/tests/test_qtmultimediawidgets.py b/python3.10libs/qtpy/tests/test_qtmultimediawidgets.py new file mode 100644 index 0000000..6f80e4d --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtmultimediawidgets.py @@ -0,0 +1,15 @@ +"""Test QtMultimediaWidgets.""" + + +from qtpy import PYQT5, PYSIDE2 + + +def test_qtmultimediawidgets(): + """Test the qtpy.QtMultimediaWidgets namespace""" + from qtpy import QtMultimediaWidgets + + if PYQT5 or PYSIDE2: + 
assert QtMultimediaWidgets.QCameraViewfinder is not None + # assert QtMultimediaWidgets.QVideoWidgetControl is not None + assert QtMultimediaWidgets.QGraphicsVideoItem is not None + assert QtMultimediaWidgets.QVideoWidget is not None diff --git a/python3.10libs/qtpy/tests/test_qtnetwork.py b/python3.10libs/qtpy/tests/test_qtnetwork.py new file mode 100644 index 0000000..77b91d2 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtnetwork.py @@ -0,0 +1,40 @@ +from qtpy import PYQT6, PYSIDE2, PYSIDE6, QtNetwork + + +def test_qtnetwork(): + """Test the qtpy.QtNetwork namespace""" + assert QtNetwork.QAbstractNetworkCache is not None + assert QtNetwork.QNetworkCacheMetaData is not None + if not PYSIDE2: + assert QtNetwork.QHttpMultiPart is not None + assert QtNetwork.QHttpPart is not None + assert QtNetwork.QNetworkAccessManager is not None + assert QtNetwork.QNetworkCookie is not None + assert QtNetwork.QNetworkCookieJar is not None + assert QtNetwork.QNetworkDiskCache is not None + assert QtNetwork.QNetworkReply is not None + assert QtNetwork.QNetworkRequest is not None + if not (PYSIDE6 or PYQT6): + assert QtNetwork.QNetworkConfigurationManager is not None + assert QtNetwork.QNetworkConfiguration is not None + assert QtNetwork.QNetworkSession is not None + assert QtNetwork.QAuthenticator is not None + assert QtNetwork.QHostAddress is not None + assert QtNetwork.QHostInfo is not None + assert QtNetwork.QNetworkAddressEntry is not None + assert QtNetwork.QNetworkInterface is not None + assert QtNetwork.QNetworkProxy is not None + assert QtNetwork.QNetworkProxyFactory is not None + assert QtNetwork.QNetworkProxyQuery is not None + assert QtNetwork.QAbstractSocket is not None + assert QtNetwork.QLocalServer is not None + assert QtNetwork.QLocalSocket is not None + assert QtNetwork.QTcpServer is not None + assert QtNetwork.QTcpSocket is not None + assert QtNetwork.QUdpSocket is not None + assert QtNetwork.QSslCertificate is not None + assert QtNetwork.QSslCipher is not 
None + assert QtNetwork.QSslConfiguration is not None + assert QtNetwork.QSslError is not None + assert QtNetwork.QSslKey is not None + assert QtNetwork.QSslSocket is not None diff --git a/python3.10libs/qtpy/tests/test_qtnetworkauth.py b/python3.10libs/qtpy/tests/test_qtnetworkauth.py new file mode 100644 index 0000000..ff9b923 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtnetworkauth.py @@ -0,0 +1,16 @@ +import pytest + +from qtpy import PYQT5, PYQT6, PYSIDE2 + + +@pytest.mark.skipif(PYSIDE2, reason="Not available for PySide2") +def test_qtnetworkauth(): + """Test the qtpy.QtNetworkAuth namespace""" + QtNetworkAuth = pytest.importorskip("qtpy.QtNetworkAuth") + + assert QtNetworkAuth.QAbstractOAuth is not None + assert QtNetworkAuth.QAbstractOAuth2 is not None + assert QtNetworkAuth.QAbstractOAuthReplyHandler is not None + assert QtNetworkAuth.QOAuth1 is not None + assert QtNetworkAuth.QOAuth1Signature is not None + assert QtNetworkAuth.QOAuth2AuthorizationCodeFlow is not None diff --git a/python3.10libs/qtpy/tests/test_qtopengl.py b/python3.10libs/qtpy/tests/test_qtopengl.py new file mode 100644 index 0000000..93bb5bb --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtopengl.py @@ -0,0 +1,21 @@ +def test_qtopengl(): + """Test the qtpy.QtOpenGL namespace""" + from qtpy import QtOpenGL + + assert QtOpenGL.QOpenGLBuffer is not None + assert QtOpenGL.QOpenGLContext is not None + assert QtOpenGL.QOpenGLContextGroup is not None + assert QtOpenGL.QOpenGLDebugLogger is not None + assert QtOpenGL.QOpenGLDebugMessage is not None + assert QtOpenGL.QOpenGLFramebufferObject is not None + assert QtOpenGL.QOpenGLFramebufferObjectFormat is not None + assert QtOpenGL.QOpenGLPixelTransferOptions is not None + assert QtOpenGL.QOpenGLShader is not None + assert QtOpenGL.QOpenGLShaderProgram is not None + assert QtOpenGL.QOpenGLTexture is not None + assert QtOpenGL.QOpenGLTextureBlitter is not None + assert QtOpenGL.QOpenGLVersionProfile is not None + assert 
QtOpenGL.QOpenGLVertexArrayObject is not None + assert QtOpenGL.QOpenGLWindow is not None + # We do not test for QOpenGLTimeMonitor or QOpenGLTimerQuery as + # they are not present on some architectures such as armhf diff --git a/python3.10libs/qtpy/tests/test_qtopenglwidgets.py b/python3.10libs/qtpy/tests/test_qtopenglwidgets.py new file mode 100644 index 0000000..2271e92 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtopenglwidgets.py @@ -0,0 +1,11 @@ +import pytest + +from qtpy import PYQT5, PYSIDE2 + + +@pytest.mark.skipif(PYSIDE2 or PYQT5, reason="Not available in PySide2/PyQt5") +def test_qtopenglwidgets(): + """Test the qtpy.QtOpenGLWidgets namespace""" + from qtpy import QtOpenGLWidgets + + assert QtOpenGLWidgets.QOpenGLWidget is not None diff --git a/python3.10libs/qtpy/tests/test_qtpdf.py b/python3.10libs/qtpy/tests/test_qtpdf.py new file mode 100644 index 0000000..f9611b1 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtpdf.py @@ -0,0 +1,10 @@ +import pytest + + +def test_qtpdf(): + """Test the qtpy.QtPdf namespace""" + QtPdf = pytest.importorskip("qtpy.QtPdf") + + assert QtPdf.QPdfDocument is not None + assert QtPdf.QPdfLink is not None + assert QtPdf.QPdfSelection is not None diff --git a/python3.10libs/qtpy/tests/test_qtpdfwidgets.py b/python3.10libs/qtpy/tests/test_qtpdfwidgets.py new file mode 100644 index 0000000..55f508c --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtpdfwidgets.py @@ -0,0 +1,8 @@ +import pytest + + +def test_qtpdfwidgets(): + """Test the qtpy.QtPdfWidgets namespace""" + QtPdfWidgets = pytest.importorskip("qtpy.QtPdfWidgets") + + assert QtPdfWidgets.QPdfView is not None diff --git a/python3.10libs/qtpy/tests/test_qtpositioning.py b/python3.10libs/qtpy/tests/test_qtpositioning.py new file mode 100644 index 0000000..adf8f45 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtpositioning.py @@ -0,0 +1,33 @@ +import pytest + +from qtpy import QT6 +from qtpy.tests.utils import using_conda + + +@pytest.mark.skipif( + QT6 
and using_conda(), + reason="QPositioning bindings not included in Conda qt-main >= 6.4.3.", +) +def test_qtpositioning(): + """Test the qtpy.QtPositioning namespace""" + from qtpy import QtPositioning + + assert QtPositioning.QGeoAddress is not None + assert QtPositioning.QGeoAreaMonitorInfo is not None + assert QtPositioning.QGeoAreaMonitorSource is not None + assert QtPositioning.QGeoCircle is not None + assert QtPositioning.QGeoCoordinate is not None + assert QtPositioning.QGeoLocation is not None + assert QtPositioning.QGeoPath is not None + # CI for 3.7 uses Qt 5.9 + # assert QtPositioning.QGeoPolygon is not None # New in Qt 5.10 + assert QtPositioning.QGeoPositionInfo is not None + assert QtPositioning.QGeoPositionInfoSource is not None + # QGeoPositionInfoSourceFactory is not available in PyQt + # assert QtPositioning.QGeoPositionInfoSourceFactory is not None # New in Qt 5.2 + # assert QtPositioning.QGeoPositionInfoSourceFactoryV2 is not None # New in Qt 5.14 + assert QtPositioning.QGeoRectangle is not None + assert QtPositioning.QGeoSatelliteInfo is not None + assert QtPositioning.QGeoSatelliteInfoSource is not None + assert QtPositioning.QGeoShape is not None + assert QtPositioning.QNmeaPositionInfoSource is not None diff --git a/python3.10libs/qtpy/tests/test_qtprintsupport.py b/python3.10libs/qtpy/tests/test_qtprintsupport.py new file mode 100644 index 0000000..6a36aa1 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtprintsupport.py @@ -0,0 +1,36 @@ +"""Test QtPrintSupport.""" + +import sys + +import pytest + +from qtpy import QtPrintSupport + + +def test_qtprintsupport(): + """Test the qtpy.QtPrintSupport namespace""" + assert QtPrintSupport.QAbstractPrintDialog is not None + assert QtPrintSupport.QPageSetupDialog is not None + assert QtPrintSupport.QPrintDialog is not None + assert QtPrintSupport.QPrintPreviewDialog is not None + assert QtPrintSupport.QPrintEngine is not None + assert QtPrintSupport.QPrinter is not None + assert 
QtPrintSupport.QPrinterInfo is not None + assert QtPrintSupport.QPrintPreviewWidget is not None + + +def test_qpagesetupdialog_exec_(): + """Test qtpy.QtPrintSupport.QPageSetupDialog exec_""" + assert QtPrintSupport.QPageSetupDialog.exec_ is not None + + +def test_qprintdialog_exec_(): + """Test qtpy.QtPrintSupport.QPrintDialog exec_""" + assert QtPrintSupport.QPrintDialog.exec_ is not None + + +def test_qprintpreviewwidget_print_(qtbot): + """Test qtpy.QtPrintSupport.QPrintPreviewWidget print_""" + assert QtPrintSupport.QPrintPreviewWidget.print_ is not None + preview_widget = QtPrintSupport.QPrintPreviewWidget() + preview_widget.print_() diff --git a/python3.10libs/qtpy/tests/test_qtpurchasing.py b/python3.10libs/qtpy/tests/test_qtpurchasing.py new file mode 100644 index 0000000..d4c5173 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtpurchasing.py @@ -0,0 +1,10 @@ +import pytest + + +def test_qtpurchasing(): + """Test the qtpy.QtPurchasing namespace""" + QtPurchasing = pytest.importorskip("qtpy.QtPurchasing") + + assert QtPurchasing.QInAppProduct is not None + assert QtPurchasing.QInAppStore is not None + assert QtPurchasing.QInAppTransaction is not None diff --git a/python3.10libs/qtpy/tests/test_qtqml.py b/python3.10libs/qtpy/tests/test_qtqml.py new file mode 100644 index 0000000..9baf91b --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtqml.py @@ -0,0 +1,32 @@ +from qtpy import PYSIDE2, PYSIDE6 + + +def test_qtqml(): + """Test the qtpy.QtQml namespace""" + from qtpy import QtQml + + assert QtQml.QJSEngine is not None + assert QtQml.QJSValue is not None + assert QtQml.QJSValueIterator is not None + assert QtQml.QQmlAbstractUrlInterceptor is not None + assert QtQml.QQmlApplicationEngine is not None + assert QtQml.QQmlComponent is not None + assert QtQml.QQmlContext is not None + assert QtQml.QQmlEngine is not None + assert QtQml.QQmlImageProviderBase is not None + assert QtQml.QQmlError is not None + assert QtQml.QQmlExpression is not None + assert 
QtQml.QQmlExtensionPlugin is not None + assert QtQml.QQmlFileSelector is not None + assert QtQml.QQmlIncubationController is not None + assert QtQml.QQmlIncubator is not None + if not (PYSIDE2 or PYSIDE6): + # https://wiki.qt.io/Qt_for_Python_Missing_Bindings#QtQml + assert QtQml.QQmlListProperty is not None + assert QtQml.QQmlListReference is not None + assert QtQml.QQmlNetworkAccessManagerFactory is not None + assert QtQml.QQmlParserStatus is not None + assert QtQml.QQmlProperty is not None + assert QtQml.QQmlPropertyValueSource is not None + assert QtQml.QQmlScriptString is not None + assert QtQml.QQmlPropertyMap is not None diff --git a/python3.10libs/qtpy/tests/test_qtquick.py b/python3.10libs/qtpy/tests/test_qtquick.py new file mode 100644 index 0000000..ee7c1ed --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtquick.py @@ -0,0 +1,48 @@ +from qtpy import PYQT5, PYSIDE2 + + +def test_qtquick(): + """Test the qtpy.QtQuick namespace""" + from qtpy import QtQuick + + if PYQT5: + assert QtQuick.QQuickCloseEvent is not None + assert QtQuick.QSGFlatColorMaterial is not None + assert QtQuick.QSGImageNode is not None + assert QtQuick.QSGMaterial is not None + assert QtQuick.QSGMaterialShader is not None + assert QtQuick.QSGOpaqueTextureMaterial is not None + assert QtQuick.QSGRectangleNode is not None + assert QtQuick.QSGRenderNode is not None + assert QtQuick.QSGRendererInterface is not None + assert QtQuick.QSGTextureMaterial is not None + assert QtQuick.QSGVertexColorMaterial is not None + + assert QtQuick.QQuickAsyncImageProvider is not None + assert QtQuick.QQuickFramebufferObject is not None + assert QtQuick.QQuickImageProvider is not None + assert QtQuick.QQuickImageResponse is not None + assert QtQuick.QQuickItem is not None + assert QtQuick.QQuickItemGrabResult is not None + assert QtQuick.QQuickPaintedItem is not None + assert QtQuick.QQuickRenderControl is not None + assert QtQuick.QQuickTextDocument is not None + assert QtQuick.QQuickTextureFactory is 
not None + assert QtQuick.QQuickView is not None + assert QtQuick.QQuickWindow is not None + if PYQT5 or PYSIDE2: + assert QtQuick.QSGAbstractRenderer is not None + assert QtQuick.QSGEngine is not None + assert QtQuick.QSGBasicGeometryNode is not None + assert QtQuick.QSGClipNode is not None + assert QtQuick.QSGDynamicTexture is not None + assert QtQuick.QSGGeometry is not None + assert QtQuick.QSGGeometryNode is not None + assert QtQuick.QSGMaterialType is not None + assert QtQuick.QSGNode is not None + assert QtQuick.QSGOpacityNode is not None + assert QtQuick.QSGSimpleRectNode is not None + assert QtQuick.QSGSimpleTextureNode is not None + assert QtQuick.QSGTexture is not None + assert QtQuick.QSGTextureProvider is not None + assert QtQuick.QSGTransformNode is not None diff --git a/python3.10libs/qtpy/tests/test_qtquick3d.py b/python3.10libs/qtpy/tests/test_qtquick3d.py new file mode 100644 index 0000000..ca614bd --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtquick3d.py @@ -0,0 +1,10 @@ +import pytest + + +def test_qtquick3d(): + """Test the qtpy.QtQuick3D namespace""" + QtQuick3D = pytest.importorskip("qtpy.QtQuick3D") + + assert QtQuick3D.QQuick3D is not None + assert QtQuick3D.QQuick3DGeometry is not None + assert QtQuick3D.QQuick3DObject is not None diff --git a/python3.10libs/qtpy/tests/test_qtquickcontrols2.py b/python3.10libs/qtpy/tests/test_qtquickcontrols2.py new file mode 100644 index 0000000..a77ef00 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtquickcontrols2.py @@ -0,0 +1,8 @@ +import pytest + + +def test_qtquickcontrols2(): + """Test the qtpy.QtQuickControls2 namespace""" + QtQuickControls2 = pytest.importorskip("qtpy.QtQuickControls2") + + assert QtQuickControls2.QQuickStyle is not None diff --git a/python3.10libs/qtpy/tests/test_qtquickwidgets.py b/python3.10libs/qtpy/tests/test_qtquickwidgets.py new file mode 100644 index 0000000..4765cc1 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtquickwidgets.py @@ -0,0 +1,5 @@ +def 
test_qtquickwidgets(): + """Test the qtpy.QtQuickWidgets namespace""" + from qtpy import QtQuickWidgets + + assert QtQuickWidgets.QQuickWidget is not None diff --git a/python3.10libs/qtpy/tests/test_qtremoteobjects.py b/python3.10libs/qtpy/tests/test_qtremoteobjects.py new file mode 100644 index 0000000..db009ea --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtremoteobjects.py @@ -0,0 +1,12 @@ +import pytest + + +def test_qtremoteobjects(): + """Test the qtpy.QtRemoteObjects namespace""" + QtRemoteObjects = pytest.importorskip("qtpy.QtRemoteObjects") + + assert QtRemoteObjects.QRemoteObjectAbstractPersistedStore is not None + assert QtRemoteObjects.QRemoteObjectDynamicReplica is not None + assert QtRemoteObjects.QRemoteObjectHost is not None + assert QtRemoteObjects.QRemoteObjectHostBase is not None + assert QtRemoteObjects.QRemoteObjectNode is not None diff --git a/python3.10libs/qtpy/tests/test_qtscxml.py b/python3.10libs/qtpy/tests/test_qtscxml.py new file mode 100644 index 0000000..4003379 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtscxml.py @@ -0,0 +1,10 @@ +import pytest + + +def test_qtscxml(): + """Test the qtpy.QtScxml namespace""" + QtScxml = pytest.importorskip("qtpy.QtScxml") + + assert QtScxml.QScxmlCompiler is not None + assert QtScxml.QScxmlDynamicScxmlServiceFactory is not None + assert QtScxml.QScxmlExecutableContent is not None diff --git a/python3.10libs/qtpy/tests/test_qtsensors.py b/python3.10libs/qtpy/tests/test_qtsensors.py new file mode 100644 index 0000000..b15dc36 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtsensors.py @@ -0,0 +1,10 @@ +import pytest + + +def test_qtsensors(): + """Test the qtpy.QtSensors namespace""" + QtSensors = pytest.importorskip("qtpy.QtSensors") + + assert QtSensors.QAccelerometer is not None + assert QtSensors.QAccelerometerFilter is not None + assert QtSensors.QAccelerometerReading is not None diff --git a/python3.10libs/qtpy/tests/test_qtserialport.py 
b/python3.10libs/qtpy/tests/test_qtserialport.py new file mode 100644 index 0000000..a50c0ea --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtserialport.py @@ -0,0 +1,12 @@ +import pytest + +from qtpy import PYSIDE2 + + +@pytest.mark.skipif(PYSIDE2, reason="Not available in CI") +def test_qtserialport(): + """Test the qtpy.QtSerialPort namespace""" + QtSerialPort = pytest.importorskip("qtpy.QtSerialPort") + + assert QtSerialPort.QSerialPort is not None + assert QtSerialPort.QSerialPortInfo is not None diff --git a/python3.10libs/qtpy/tests/test_qtsql.py b/python3.10libs/qtpy/tests/test_qtsql.py new file mode 100644 index 0000000..5be5ea4 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtsql.py @@ -0,0 +1,87 @@ +"""Test QtSql.""" + +import sys + +import pytest + +from qtpy import PYSIDE2, PYSIDE_VERSION, QtSql + + +@pytest.fixture +def database_connection(): + """Create a database connection""" + connection = QtSql.QSqlDatabase.addDatabase("QSQLITE") + yield connection + connection.close() + + +def test_qtsql(): + """Test the qtpy.QtSql namespace""" + assert QtSql.QSqlDatabase is not None + assert QtSql.QSqlDriverCreatorBase is not None + assert QtSql.QSqlDriver is not None + assert QtSql.QSqlError is not None + assert QtSql.QSqlField is not None + assert QtSql.QSqlIndex is not None + assert QtSql.QSqlQuery is not None + assert QtSql.QSqlRecord is not None + assert QtSql.QSqlResult is not None + assert QtSql.QSqlQueryModel is not None + assert QtSql.QSqlRelationalDelegate is not None + assert QtSql.QSqlRelation is not None + assert QtSql.QSqlRelationalTableModel is not None + assert QtSql.QSqlTableModel is not None + + # Following modules are not (yet) part of any wrapper: + # QSqlDriverCreator, QSqlDriverPlugin + + +@pytest.mark.skipif( + sys.platform == "win32" and PYSIDE2 and PYSIDE_VERSION.startswith("5.13"), + reason="SQLite driver unavailable on PySide 5.13.2 with Windows", +) +def test_qtsql_members_aliases(database_connection): + """ + Test aliased 
methods over qtpy.QtSql members including: + + * qtpy.QtSql.QSqlDatabase.exec_ + * qtpy.QtSql.QSqlQuery.exec_ + * qtpy.QtSql.QSqlResult.exec_ + """ + assert QtSql.QSqlDatabase.exec_ is not None + assert QtSql.QSqlQuery.exec_ is not None + assert QtSql.QSqlResult.exec_ is not None + + assert database_connection.open() + database_connection.setDatabaseName("test.sqlite") + QtSql.QSqlDatabase.exec_( + database_connection, + """ + CREATE TABLE test ( + id INTEGER PRIMARY KEY AUTOINCREMENT UNIQUE NOT NULL, + name VARCHAR(40) NOT NULL + ) + """, + ) + + # Created table 'test' and 'sqlite_sequence' + assert len(database_connection.tables()) == 2 + + insert_table_query = QtSql.QSqlQuery() + assert insert_table_query.exec_( + """ + INSERT INTO test (name) VALUES ( + "TESTING" + ) + """, + ) + + select_table_query = QtSql.QSqlQuery() + select_table_query.prepare( + """ + SELECT * FROM test + """, + ) + select_table_query.exec_() + record = select_table_query.record() + assert not record.isEmpty() diff --git a/python3.10libs/qtpy/tests/test_qtstatemachine.py b/python3.10libs/qtpy/tests/test_qtstatemachine.py new file mode 100644 index 0000000..5fa986b --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtstatemachine.py @@ -0,0 +1,16 @@ +import pytest + + +def test_qtstatemachine(): + """Test the qtpy.QtStateMachine namespace""" + QtStateMachine = pytest.importorskip("qtpy.QtStateMachine") + + assert QtStateMachine.QAbstractState is not None + assert QtStateMachine.QAbstractTransition is not None + assert QtStateMachine.QEventTransition is not None + assert QtStateMachine.QFinalState is not None + assert QtStateMachine.QHistoryState is not None + assert QtStateMachine.QKeyEventTransition is not None + assert QtStateMachine.QMouseEventTransition is not None + assert QtStateMachine.QSignalTransition is not None + assert QtStateMachine.QState is not None diff --git a/python3.10libs/qtpy/tests/test_qtsvg.py b/python3.10libs/qtpy/tests/test_qtsvg.py new file mode 100644 index 
0000000..c39c95f --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtsvg.py @@ -0,0 +1,14 @@ +import pytest + +from qtpy import PYQT6, PYSIDE6 + + +def test_qtsvg(): + """Test the qtpy.QtSvg namespace""" + QtSvg = pytest.importorskip("qtpy.QtSvg") + + if not (PYSIDE6 or PYQT6): + assert QtSvg.QGraphicsSvgItem is not None + assert QtSvg.QSvgWidget is not None + assert QtSvg.QSvgGenerator is not None + assert QtSvg.QSvgRenderer is not None diff --git a/python3.10libs/qtpy/tests/test_qtsvgwidgets.py b/python3.10libs/qtpy/tests/test_qtsvgwidgets.py new file mode 100644 index 0000000..7533925 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtsvgwidgets.py @@ -0,0 +1,9 @@ +import pytest + + +def test_qtsvgwidgets(): + """Test the qtpy.QtSvgWidgets namespace""" + QtSvgWidgets = pytest.importorskip("qtpy.QtSvgWidgets") + + assert QtSvgWidgets.QGraphicsSvgItem is not None + assert QtSvgWidgets.QSvgWidget is not None diff --git a/python3.10libs/qtpy/tests/test_qttest.py b/python3.10libs/qtpy/tests/test_qttest.py new file mode 100644 index 0000000..2d67439 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qttest.py @@ -0,0 +1,29 @@ +import pytest +from packaging import version + +from qtpy import PYQT5, PYQT6, PYQT_VERSION, PYSIDE6, QtTest + + +def test_qttest(): + """Test the qtpy.QtTest namespace""" + assert QtTest.QTest is not None + + if PYQT5 or PYQT6 or PYSIDE6: + assert QtTest.QSignalSpy is not None + + if ( + (PYQT5 and version.parse(PYQT_VERSION) >= version.parse("5.11")) + or PYQT6 + or PYSIDE6 + ): + assert QtTest.QAbstractItemModelTester is not None + + +@pytest.mark.skipif( + PYQT5 and PYQT_VERSION.startswith("5.9"), + reason="A specific setup with at least sip 4.9.9 is needed for PyQt5 5.9.*" + "to work with scoped enum access", +) +def test_enum_access(): + """Test scoped and unscoped enum access for qtpy.QtTest.*.""" + assert QtTest.QTest.Click == QtTest.QTest.KeyAction.Click diff --git a/python3.10libs/qtpy/tests/test_qttexttospeech.py 
b/python3.10libs/qtpy/tests/test_qttexttospeech.py new file mode 100644 index 0000000..bcb97f0 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qttexttospeech.py @@ -0,0 +1,22 @@ +import pytest +from packaging import version + +from qtpy import PYQT5, PYQT_VERSION, PYSIDE2 + + +@pytest.mark.skipif( + not ( + (PYQT5 and version.parse(PYQT_VERSION) >= version.parse("5.15.1")) + or PYSIDE2 + ), + reason="Only available in Qt5 bindings (PyQt5 >= 5.15.1 or PySide2)", +) +def test_qttexttospeech(): + """Test the qtpy.QtTextToSpeech namespace.""" + from qtpy import QtTextToSpeech + + assert QtTextToSpeech.QTextToSpeech is not None + assert QtTextToSpeech.QVoice is not None + + if PYSIDE2: + assert QtTextToSpeech.QTextToSpeechEngine is not None diff --git a/python3.10libs/qtpy/tests/test_qtuitools.py b/python3.10libs/qtpy/tests/test_qtuitools.py new file mode 100644 index 0000000..13ee402 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtuitools.py @@ -0,0 +1,8 @@ +import pytest + + +def test_qtuitools(): + """Test the qtpy.QtUiTools namespace""" + QtUiTools = pytest.importorskip("qtpy.QtUiTools") + + assert QtUiTools.QUiLoader is not None diff --git a/python3.10libs/qtpy/tests/test_qtwebchannel.py b/python3.10libs/qtpy/tests/test_qtwebchannel.py new file mode 100644 index 0000000..8a364ba --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtwebchannel.py @@ -0,0 +1,6 @@ +def test_qtwebchannel(): + """Test the qtpy.QtWebChannel namespace""" + from qtpy import QtWebChannel + + assert QtWebChannel.QWebChannel is not None + assert QtWebChannel.QWebChannelAbstractTransport is not None diff --git a/python3.10libs/qtpy/tests/test_qtwebenginecore.py b/python3.10libs/qtpy/tests/test_qtwebenginecore.py new file mode 100644 index 0000000..8f2b8c9 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtwebenginecore.py @@ -0,0 +1,8 @@ +import pytest + + +def test_qtwebenginecore(): + """Test the qtpy.QtWebEngineCore namespace""" + QtWebEngineCore = 
pytest.importorskip("qtpy.QtWebEngineCore") + + assert QtWebEngineCore.QWebEngineHttpRequest is not None diff --git a/python3.10libs/qtpy/tests/test_qtwebenginequick.py b/python3.10libs/qtpy/tests/test_qtwebenginequick.py new file mode 100644 index 0000000..cfeda74 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtwebenginequick.py @@ -0,0 +1,13 @@ +import pytest + +from qtpy import PYQT5, PYSIDE2 + + +@pytest.mark.skipif(PYQT5 or PYSIDE2, reason="Only available in Qt6 bindings") +def test_qtwebenginequick(): + """Test the qtpy.QtWebEngineQuick namespace""" + + QtWebEngineQuick = pytest.importorskip("qtpy.QtWebEngineQuick") + + assert QtWebEngineQuick.QtWebEngineQuick is not None + assert QtWebEngineQuick.QQuickWebEngineProfile is not None diff --git a/python3.10libs/qtpy/tests/test_qtwebenginewidgets.py b/python3.10libs/qtpy/tests/test_qtwebenginewidgets.py new file mode 100644 index 0000000..c7c4d36 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtwebenginewidgets.py @@ -0,0 +1,24 @@ +import pytest +from packaging import version + +from qtpy import PYQT5, PYQT6, PYQT_VERSION, PYSIDE2, PYSIDE6, PYSIDE_VERSION + + +@pytest.mark.skipif( + not ( + (PYQT6 and version.parse(PYQT_VERSION) >= version.parse("6.2")) + or (PYSIDE6 and version.parse(PYSIDE_VERSION) >= version.parse("6.2")) + or PYQT5 + or PYSIDE2 + ), + reason="Only available in Qt<6,>=6.2 bindings", +) +def test_qtwebenginewidgets(): + """Test the qtpy.QtWebEngineWidget namespace""" + + QtWebEngineWidgets = pytest.importorskip("qtpy.QtWebEngineWidgets") + + assert QtWebEngineWidgets.QWebEnginePage is not None + assert QtWebEngineWidgets.QWebEngineView is not None + assert QtWebEngineWidgets.QWebEngineSettings is not None + assert QtWebEngineWidgets.QWebEngineScript is not None diff --git a/python3.10libs/qtpy/tests/test_qtwebsockets.py b/python3.10libs/qtpy/tests/test_qtwebsockets.py new file mode 100644 index 0000000..ae16902 --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtwebsockets.py @@ -0,0 
+1,9 @@ +def test_qtwebsockets(): + """Test the qtpy.QtWebSockets namespace""" + from qtpy import QtWebSockets + + assert QtWebSockets.QMaskGenerator is not None + assert QtWebSockets.QWebSocket is not None + assert QtWebSockets.QWebSocketCorsAuthenticator is not None + assert QtWebSockets.QWebSocketProtocol is not None + assert QtWebSockets.QWebSocketServer is not None diff --git a/python3.10libs/qtpy/tests/test_qtwidgets.py b/python3.10libs/qtpy/tests/test_qtwidgets.py new file mode 100644 index 0000000..70c937a --- /dev/null +++ b/python3.10libs/qtpy/tests/test_qtwidgets.py @@ -0,0 +1,291 @@ +"""Test QtWidgets.""" +import contextlib +import sys +from time import sleep + +import pytest +from pytestqt.exceptions import TimeoutError + +from qtpy import ( + PYQT5, + PYQT6, + PYQT_VERSION, + PYSIDE2, + PYSIDE6, + QtCore, + QtGui, + QtWidgets, +) +from qtpy.tests.utils import not_using_conda, using_conda + + +def test_qtextedit_functions(qtbot, pdf_writer): + """Test functions mapping for QtWidgets.QTextEdit.""" + assert QtWidgets.QTextEdit.setTabStopWidth + assert QtWidgets.QTextEdit.tabStopWidth + assert QtWidgets.QTextEdit.print_ + textedit_widget = QtWidgets.QTextEdit(None) + textedit_widget.setTabStopWidth(90) + assert textedit_widget.tabStopWidth() == 90 + print_device, output_path = pdf_writer + textedit_widget.print_(print_device) + assert output_path.exists() + + +def test_qlineedit_functions(): + """Test functions mapping for QtWidgets.QLineEdit""" + assert QtWidgets.QLineEdit.getTextMargins + + +def test_what_moved_to_qtgui_in_qt6(): + """Test that we move back what has been moved to QtGui in Qt6""" + assert QtWidgets.QAction is not None + assert QtWidgets.QActionGroup is not None + assert QtWidgets.QFileSystemModel is not None + assert QtWidgets.QShortcut is not None + assert QtWidgets.QUndoCommand is not None + + +def test_qplaintextedit_functions(qtbot, pdf_writer): + """Test functions mapping for QtWidgets.QPlainTextEdit.""" + assert 
QtWidgets.QPlainTextEdit.setTabStopWidth + assert QtWidgets.QPlainTextEdit.tabStopWidth + assert QtWidgets.QPlainTextEdit.print_ + plaintextedit_widget = QtWidgets.QPlainTextEdit(None) + plaintextedit_widget.setTabStopWidth(90) + assert plaintextedit_widget.tabStopWidth() == 90 + print_device, output_path = pdf_writer + plaintextedit_widget.print_(print_device) + assert output_path.exists() + + +def test_QApplication_exec_(): + """Test `QtWidgets.QApplication.exec_`""" + assert QtWidgets.QApplication.exec_ is not None + app = QtWidgets.QApplication.instance() or QtWidgets.QApplication( + [sys.executable, __file__], + ) + assert app is not None + QtCore.QTimer.singleShot(100, QtWidgets.QApplication.instance().quit) + QtWidgets.QApplication.exec_() + app = QtWidgets.QApplication.instance() or QtWidgets.QApplication( + [sys.executable, __file__], + ) + assert app is not None + QtCore.QTimer.singleShot(100, QtWidgets.QApplication.instance().quit) + app.exec_() + + +@pytest.mark.skipif( + sys.platform == "darwin" and sys.version_info[:2] == (3, 7), + reason="Stalls on macOS CI with Python 3.7", +) +def test_qdialog_functions(qtbot): + """Test functions mapping for QtWidgets.QDialog.""" + assert QtWidgets.QDialog.exec_ + dialog = QtWidgets.QDialog(None) + QtCore.QTimer.singleShot(100, dialog.accept) + dialog.exec_() + + +@pytest.mark.skipif( + sys.platform == "darwin" and sys.version_info[:2] == (3, 7), + reason="Stalls on macOS CI with Python 3.7", +) +def test_qdialog_subclass(qtbot): + """Test functions mapping for QtWidgets.QDialog when using a subclass""" + assert QtWidgets.QDialog.exec_ + + class CustomDialog(QtWidgets.QDialog): + def __init__(self): + super().__init__(None) + self.setWindowTitle("Testing") + + assert CustomDialog.exec_ + dialog = CustomDialog() + QtCore.QTimer.singleShot(100, dialog.accept) + dialog.exec_() + + +@pytest.mark.skipif( + sys.platform == "darwin" and sys.version_info[:2] == (3, 7), + reason="Stalls on macOS CI with Python 3.7", +) 
+def test_QMenu_functions(qtbot): + """Test functions mapping for `QtWidgets.QMenu`.""" + # A window is required for static calls + window = QtWidgets.QMainWindow() + menu = QtWidgets.QMenu(window) + menu.addAction("QtPy") + menu.addAction("QtPy with a shortcut", QtGui.QKeySequence.UnknownKey) + menu.addAction( + QtGui.QIcon(), + "QtPy with an icon and a shortcut", + QtGui.QKeySequence.UnknownKey, + ) + window.show() + + with qtbot.waitExposed(window): + # Call `exec_` of a `QMenu` instance + QtCore.QTimer.singleShot(100, menu.close) + menu.exec_() + + # Call static `QMenu.exec_` + QtCore.QTimer.singleShot( + 100, + lambda: qtbot.keyClick( + QtWidgets.QApplication.widgetAt(1, 1), + QtCore.Qt.Key_Escape, + ), + ) + QtWidgets.QMenu.exec_(menu.actions(), QtCore.QPoint(1, 1)) + + +@pytest.mark.skipif( + sys.platform == "darwin" and sys.version_info[:2] == (3, 7), + reason="Stalls on macOS CI with Python 3.7", +) +def test_QToolBar_functions(qtbot): + """Test `QtWidgets.QToolBar.addAction` compatibility with Qt6 arguments' order.""" + toolbar = QtWidgets.QToolBar() + toolbar.addAction("QtPy with a shortcut", QtGui.QKeySequence.UnknownKey) + toolbar.addAction( + QtGui.QIcon(), + "QtPy with an icon and a shortcut", + QtGui.QKeySequence.UnknownKey, + ) + + +@pytest.mark.skipif( + PYQT5 and PYQT_VERSION.startswith("5.9"), + reason="A specific setup with at least sip 4.9.9 is needed for PyQt5 5.9.*" + "to work with scoped enum access", +) +def test_enum_access(): + """Test scoped and unscoped enum access for qtpy.QtWidgets.*.""" + assert ( + QtWidgets.QFileDialog.AcceptOpen + == QtWidgets.QFileDialog.AcceptMode.AcceptOpen + ) + assert ( + QtWidgets.QMessageBox.InvalidRole + == QtWidgets.QMessageBox.ButtonRole.InvalidRole + ) + assert QtWidgets.QStyle.State_None == QtWidgets.QStyle.StateFlag.State_None + assert ( + QtWidgets.QSlider.TicksLeft + == QtWidgets.QSlider.TickPosition.TicksAbove + ) + assert ( + QtWidgets.QStyle.SC_SliderGroove + == 
QtWidgets.QStyle.SubControl.SC_SliderGroove + ) + + +def test_opengl_imports(): + """ + Test for presence of QOpenGLWidget. + + QOpenGLWidget was a member of QtWidgets in Qt5, but moved to QtOpenGLWidgets in Qt6. + QtPy makes QOpenGLWidget available in QtWidgets to maintain compatibility. + """ + assert QtWidgets.QOpenGLWidget is not None + + +@pytest.mark.skipif( + sys.platform == "darwin" + and sys.version_info[:2] == (3, 7) + and (PYQT5 or PYSIDE2), + reason="Crashes on macOS with Python 3.7 with 'Illegal instruction: 4'", +) +@pytest.mark.parametrize("keyword", ["dir", "directory"]) +@pytest.mark.parametrize("instance", [True, False]) +def test_qfiledialog_dir_compat(tmp_path, qtbot, keyword, instance): + """ + This function is testing if the decorators that renamed the dir/directory + keyword are working. + + It may stop working if the Qt bindings do some overwriting of the methods + in constructor. It should not happen, but the PySide team + did similar things in the past (like overwriting enum module in + PySide6==6.3.2). + + keyword: str + The keyword that should be used in the function call. + instance: bool + If True, the function is called on the instance of the QFileDialog, + otherwise on the class. + """ + + class CloseThread(QtCore.QThread): + """ + On some implementations the `getExistingDirectory` functions starts own + event loop that will not trigger QTimer started before the call. Until + the dialog is closed the main event loop will be stopped. + + Because of this it is required to use the thread to interact with the + dialog. + """ + + def run(self, allow_restart=True): + sleep(0.5) + need_restart = allow_restart + app = QtWidgets.QApplication.instance() + for dlg in app.topLevelWidgets(): + if ( + not isinstance(dlg, QtWidgets.QFileDialog) + or dlg.isHidden() + ): + continue + # "when implemented this I try to use: + # * dlg.close() - On Qt6 it will close the dialog, but it will + # not restart the main event loop. 
+ * dlg.accept() - It ends with information that `accept` and + `reject` of such created dialog can not be called. + * accept dialog with enter - It works, but it cannot be + called too early after dialog is shown + qtbot.keyClick(dlg, QtCore.Qt.Key_Enter) + need_restart = False + sleep(0.1) + for dlg in app.topLevelWidgets(): + # As described above, it may happen that dialog is not closed after first using enter. + # in such case we call `run` function again. The 0.5s sleep is enough for the second enter to close the dialog. + if ( + not isinstance(dlg, QtWidgets.QFileDialog) + or dlg.isHidden() + ): + continue + self.run(allow_restart=False) + return + + if need_restart: + self.run() + + # We need to use the `DontUseNativeDialog` option to be able to interact + # with it from code. + try: + opt = QtWidgets.QFileDialog.Option.DontUseNativeDialog + except AttributeError: + # old qt5 bindings + opt = QtWidgets.QFileDialog.DontUseNativeDialog + + kwargs = { + "caption": "Select a directory", + keyword: str(tmp_path), + "options": opt, + } + + thr = CloseThread() + thr.start() + qtbot.waitUntil(thr.isRunning, timeout=1000) + dlg = QtWidgets.QFileDialog() if instance else QtWidgets.QFileDialog + dlg.getExistingDirectory(**kwargs) + qtbot.waitUntil(thr.isFinished, timeout=3000) + + +def test_qfiledialog_flags_typedef(): + """ + Test existence of `QFlags