diff --git a/.travis.yml b/.travis.yml index e9f8f87..7d17c19 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,5 +1,7 @@ +# Config file for automatic testing at travis-ci.org language: python -python: - - "2.7" -install: "pip install -r requirements.dev" -script: py.test \ No newline at end of file +env: + - TOXENV=py27 +install: + - pip install tox +script: tox diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 536dee0..78d21ca 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -13,8 +13,4 @@ Pull-requests are welcomed! ## Testing 1. Install dev requirements by running `pip install -r requirements.dev` -2. Run tests using `py.test [optional/path/to/tests]` - -You can also enable coverage reports when running tests by using ``--cov slashed/path`` option to specify a path to package report for which should be gathered, and ``--cov-report (html|xml|annotate)`` to specify type of coverage report you want to receive. - -Use `-v` to make tests output more verbose. +2. Run tests using `py.test --cov nefertari tests` diff --git a/README.md b/README.md index 647161b..10e29d8 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,5 @@ # `Nefertari` [![Build Status](https://travis-ci.org/brandicted/nefertari.svg?branch=master)](https://travis-ci.org/brandicted/nefertari) -[![Documentation Status](https://readthedocs.org/projects/nefertari/badge/?version=master)](https://readthedocs.org/projects/nefertari/?badge=master) +[![Documentation Status](https://readthedocs.org/projects/nefertari/badge/?version=master)](http://nefertari.readthedocs.org/en/master/) Nefertari is a REST API framework sitting on top of [Pyramid](https://github.com/Pylons/pyramid) and [ElasticSearch](https://www.elastic.co/downloads/elasticsearch). She currently offers two backend engines: [SQLA](https://github.com/brandicted/nefertari-sqla) and [MongoDB](https://github.com/brandicted/nefertari-mongodb). - -You can read the documentation on [readthedocs](https://nefertari.readthedocs.org/en/latest/). diff --git a/VERSION b/VERSION index 7dff5b8..9325c3c 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.2.1 \ No newline at end of file +0.3.0 \ No newline at end of file diff --git a/docs/source/acls.rst b/docs/source/acls.rst deleted file mode 100644 index 4a14094..0000000 --- a/docs/source/acls.rst +++ /dev/null @@ -1,12 +0,0 @@ -Authentication & Authorization -============================== - -Nefertari currently supports the default Pyramid "auth ticket" cookie method of authentication. - -For authorizing access to specific resources, Nefertari uses standard Pyramid access control lists. `See the documentation on Pyramid ACLs `_ to understand how to extend and customize them. - -ACL API -------- - -.. automodule:: nefertari.acl - :members: diff --git a/docs/source/auth.rst b/docs/source/auth.rst new file mode 100644 index 0000000..0eb6ad6 --- /dev/null +++ b/docs/source/auth.rst @@ -0,0 +1,38 @@ +Authentication & Security +========================= + +Set `auth = true` in you .ini file to enable authentication. + +Ticket Auth +----------- + +Nefertari currently supports the default Pyramid "auth ticket" cookie method of authentication. + +Token Auth +---------- + +(under development) + +Visible fields in views +----------------------- + +You can control which fields to display to both authenticated users and unauthenticated users by defining `_auth_fields` and `_public_fields` respectively in your models. + +ACL API +------- + +For authorizing access to specific resources, Nefertari uses standard Pyramid access control lists. `See the documentation on Pyramid ACLs `_ to understand how to extend and customize them. + +.. automodule:: nefertari.acl + :members: + +CORS +---- + +To enable CORS headers, set the following lines in your .ini file: + +.. code-block:: ini + + cors.enable = true + cors.allow_origins = http://localhost + cors.allow_credentials = true diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 203c547..f9afcbb 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,16 +1,21 @@ Changelog ========= +* :release:`0.3.0 <2015-05-18>` +* :support:`-` Step-by-step 'Getting started' guide +* :bug:`- major` Fixed several issues related to ElasticSearch indexing +* :support:`-` Increased test coverave +* :feature:`-` Added ability to PATCH/DELETE collections +* :feature:`-` Implemented API output control by field (apply_privacy wrapper) + * :release:`0.2.1 <2015-04-21>` -* :support:`0` Specify any field as primary, and have it respected by all relations endpoints. -* :feature:`0` Added DictField type. -* :support:`0` URL parsing for DictField and ListField values with _m=VERB options. +* :bug:`-` Fixed URL parsing for DictField and ListField values with _m=VERB options * :release:`0.2.0 <2015-04-07>` -* :feature:`0` Added script to index Elasticsearch models. -* :feature:`0` Started adding tests. -* :support:`0` Listing on PyPI. -* :support:`0` Improved docs. +* :feature:`-` Added script to index Elasticsearch models +* :feature:`-` Started adding tests +* :support:`-` Listing on PyPI +* :support:`-` Improved docs * :release:`0.1.1 <2015-04-01>` -* :support:`0` Initial release after two years of development as "Presto". Now with database engines! Originally extracted and generalized from the Brandicted API which only used MongoDB. \ No newline at end of file +* :support:`-` Initial release after two years of development as "Presto". Now with database engines! Originally extracted and generalized from the Brandicted API which only used MongoDB. diff --git a/docs/source/conf.py b/docs/source/conf.py index 688657c..9bf235d 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -12,19 +12,17 @@ # All configuration values have a default; values that are commented out # serve to show the default. -import sys import os -import shlex # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -#sys.path.insert(0, os.path.abspath('.')) +# sys.path.insert(0, os.path.abspath('.')) # -- General configuration ------------------------------------------------ # If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' +# needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom @@ -47,7 +45,7 @@ source_suffix = '.rst' # The encoding of source files. -#source_encoding = 'utf-8-sig' +# source_encoding = 'utf-8-sig' # The master toctree document. master_doc = 'index' @@ -75,9 +73,9 @@ # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: -#today = '' +# today = '' # Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' +# today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. @@ -85,27 +83,27 @@ # The reST default role (used for this markup: `text`) to use for all # documents. -#default_role = None +# default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True +# add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). -#add_module_names = True +# add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. -#show_authors = False +# show_authors = False # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'sphinx' # A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] +# modindex_common_prefix = [] # If true, keep warnings as "system message" paragraphs in the built documents. -#keep_warnings = False +# keep_warnings = False # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = False @@ -113,7 +111,8 @@ # -- Options for HTML output ---------------------------------------------- -# on_rtd is whether we are on readthedocs.org, this line of code grabbed from docs.readthedocs.org +# on_rtd is whether we are on readthedocs.org, this line of code grabbed from +# docs.readthedocs.org on_rtd = os.environ.get('READTHEDOCS', None) == 'True' if not on_rtd: # only import and set the theme if we're building docs locally @@ -121,7 +120,8 @@ html_theme = 'sphinx_rtd_theme' html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] -# otherwise, readthedocs.org uses their theme by default, so no need to specify it +# otherwise, readthedocs.org uses their theme by default, so no need to +# specify it # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. @@ -130,26 +130,26 @@ # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. -#html_theme_options = {} +# html_theme_options = {} # Add any paths that contain custom themes here, relative to this directory. -#html_theme_path = [] +# html_theme_path = [] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". -#html_title = None +# html_title = None # A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None +# html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +# html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. -#html_favicon = None +# html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, @@ -159,62 +159,62 @@ # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied # directly to the root of the documentation. -#html_extra_path = [] +# html_extra_path = [] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' +# html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. -#html_use_smartypants = True +# html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +# html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. -#html_additional_pages = {} +# html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +# html_domain_indices = True # If false, no index is generated. -#html_use_index = True +# html_use_index = True # If true, the index is split into individual pages for each letter. -#html_split_index = False +# html_split_index = False # If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True +# html_show_sourcelink = True # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True +# html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -#html_show_copyright = True +# html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. -#html_use_opensearch = '' +# html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None +# html_file_suffix = None # Language to be used for generating the HTML full-text search index. # Sphinx supports the following languages: # 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja' # 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr' -#html_search_language = 'en' +# html_search_language = 'en' # A dictionary with options for the search language support, empty by default. # Now only 'ja' uses this config value -#html_search_options = {'type': 'default'} +# html_search_options = {'type': 'default'} # The name of a javascript file (relative to the configuration directory) that # implements a search results scorer. If empty, the default will be used. -#html_search_scorer = 'scorer.js' +# html_search_scorer = 'scorer.js' # Output file base name for HTML help builder. htmlhelp_basename = 'Nefertaridoc' @@ -222,46 +222,46 @@ # -- Options for LaTeX output --------------------------------------------- latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). -#'papersize': 'letterpaper', + # The paper size ('letterpaper' or 'a4paper'). + # 'papersize': 'letterpaper', -# The font size ('10pt', '11pt' or '12pt'). -#'pointsize': '10pt', + # The font size ('10pt', '11pt' or '12pt'). + # 'pointsize': '10pt', -# Additional stuff for the LaTeX preamble. -#'preamble': '', + # Additional stuff for the LaTeX preamble. + # 'preamble': '', -# Latex figure (float) alignment -#'figure_align': 'htbp', + # Latex figure (float) alignment + # 'figure_align': 'htbp', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'Nefertari.tex', u'Nefertari Documentation', - u'Brandicted', 'manual'), + (master_doc, 'Nefertari.tex', u'Nefertari Documentation', + u'Brandicted', 'manual'), ] # The name of an image file (relative to this directory) to place at the top of # the title page. -#latex_logo = None +# latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. -#latex_use_parts = False +# latex_use_parts = False # If true, show page references after internal links. -#latex_show_pagerefs = False +# latex_show_pagerefs = False # If true, show URL addresses after external links. -#latex_show_urls = False +# latex_show_urls = False # Documents to append as an appendix to all manuals. -#latex_appendices = [] +# latex_appendices = [] # If false, no module index is generated. -#latex_domain_indices = True +# latex_domain_indices = True # -- Options for manual page output --------------------------------------- @@ -274,7 +274,7 @@ ] # If true, show URL addresses after external links. -#man_show_urls = False +# man_show_urls = False # -- Options for Texinfo output ------------------------------------------- @@ -283,19 +283,19 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'Nefertari', u'Nefertari Documentation', - author, 'Nefertari', 'One line description of project.', - 'Miscellaneous'), + (master_doc, 'Nefertari', u'Nefertari Documentation', + author, 'Nefertari', 'One line description of project.', + 'Miscellaneous'), ] # Documents to append as an appendix to all manuals. -#texinfo_appendices = [] +# texinfo_appendices = [] # If false, no module index is generated. -#texinfo_domain_indices = True +# texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. -#texinfo_show_urls = 'footnote' +# texinfo_show_urls = 'footnote' # If true, do not generate a @detailmenu in the "Top" node's menu. -#texinfo_no_detailmenu = False +# texinfo_no_detailmenu = False diff --git a/docs/source/database_backends.rst b/docs/source/database_backends.rst index 4cb7039..cb04852 100644 --- a/docs/source/database_backends.rst +++ b/docs/source/database_backends.rst @@ -1,25 +1,54 @@ Database Backends ================= -Nefertari implements database engines on top of two different ORMs: `SQLAlchemy `_ and `MongoEngine `_. As such, Nefertari can be used with exising models implemented using either mapper library. +Introduction +------------ -These two engines wrap the underlying APIs of each ORM and provide a standardized syntax for using either one, making it easy to switch between them with minimal changes. +Nefertari implements database engines on top of two different ORMs: `SQLAlchemy `_ and `MongoEngine `_. These two engines wrap the underlying APIs of each ORM and provide a standardized syntax for using either one, making it easy to switch between them with minimal changes. -Each Nefertari engine is developed in its own repository: +Each Nefertari engine is maintained in its own repository: -* `SQLA Engine `_ -* `MongoDB Engine `_ +* `Nefertari SQLA Engine `_ +* `Nefertari MongoDB Engine `_ -Nefertari can use `Elasticsearch `_ to read/GET any given resource. You can read more about **ESBaseDocument** in the `Wrapper API `_ section below. +Nefertari can either use `Elasticsearch `_ or the database engine itself to read (GET) any given resource. You can read more about **ESBaseDocument** in the `Wrapper API `_ section below. +Field abstractions +------------------ + +* BigIntegerField +* BooleanField +* DateField +* DateTimeField +* ChoiceField +* FloatField +* IntegerField +* IntervalField +* BinaryField +* DecimalField +* PickleField +* SmallIntegerField +* StringField +* TextField +* TimeField +* UnicodeField +* UnicodeTextField +* Relationship +* IdField +* ForeignKeyField +* ListField +* DictField Wrapper API ----------- -Both of the database engines used by Nefertari implement a similar "Wrapper API" for developers to use within a Nefertari project. Use the following base classes in your project to leverage the powers of Nefertari. To see them in action, check out the `example project `_. +Both of the database engines used by Nefertari implement a similar "Wrapper API" for developers to use within a Nefertari project. You can read more about either engine's in their respective documentation: + + * `Nefertari SQLA documentation `_ + * `Nefertari MongoDB documentation `_ **BaseMixin** - Mixin with a most of the API of *BaseDocument*. *BaseDocument* subclasses from this mixin. + Mixin with most of the API of *BaseDocument*. *BaseDocument* subclasses from this mixin. **BaseDocument** Base for regular models defined in your application. Just subclass it to define your model's fields. Relevant attributes: @@ -50,27 +79,3 @@ Both of the database engines used by Nefertari implement a similar "Wrapper API" **relationship_cls(field, model_cls)** Return class which is pointed to by relationship field *field* from model *model_cls*. - -Field abstractions -------------------- - -* BigIntegerField -* BooleanField -* DateField -* DateTimeField -* ChoiceField -* FloatField -* IntegerField -* IntervalField -* BinaryField -* DecimalField -* PickleField -* SmallIntegerField -* StringField -* TextField -* TimeField -* UnicodeField -* UnicodeTextField -* Relationship -* PrimaryKeyField -* ForeignKeyField diff --git a/docs/source/development_tools.rst b/docs/source/development_tools.rst index f0be22e..5e96a7e 100644 --- a/docs/source/development_tools.rst +++ b/docs/source/development_tools.rst @@ -2,7 +2,7 @@ Development Tools ================= Indexing in ElasticSearch ------------------------------ +------------------------- ``nefertari.index`` console script can be used to manually (re-)index models from your database engine to ElasticSearch. @@ -13,7 +13,7 @@ You can run it like so:: The available options are: --config specify ini file to use (required) ---models list of dotted paths of models to index. Models must be subclasses of ESBaseDocument. +--models list of models to index (e.g. User). Models must subclass ESBaseDocument. --params URL-encoded parameters for each module --quiet "quiet mode" (surpress output) --index Specify name of index. E.g. the slug at the end of http://localhost:9200/example_api diff --git a/docs/source/example_project.rst b/docs/source/example_project.rst index 912891c..fa5d651 100644 --- a/docs/source/example_project.rst +++ b/docs/source/example_project.rst @@ -1,4 +1,4 @@ Example Project =============== -For an example of how to use Nefertari see the `Example Project `_. \ No newline at end of file +For a more complete example of a Pyramid project using Nefertari, you can take a look at the `Example Project `_. diff --git a/docs/source/getting_started.rst b/docs/source/getting_started.rst index 95ef3ba..e1b630a 100644 --- a/docs/source/getting_started.rst +++ b/docs/source/getting_started.rst @@ -1,168 +1,147 @@ Getting started =============== -To get started, follow these steps: +**1. Create a Pyramid "starter" project** in a virtualenv directory (see the `pyramid documentation `_ if you've never done that before) -0. Install nefertari and either nefertari-sqla or nefertari-mongodb for the database backend you want to use:: +.. code-block:: shell - pip install nefertari nefertari-sqla nefertari-mongodb + $ mkvirtualenv MyProject + $ pip install nefertari + $ pcreate -s starter MyProject + $ cd MyProject + $ pip install -e . +Install the database backend of your choice, e.g. sqla or mongodb -1. `First, create a normal Pyramid app `_. In the "main" module, import nefertari and then declare your resources like so:: +.. code-block:: shell - from pyramid.config import Configurator - from pyramid.authorization import ACLAuthorizationPolicy - from pyramid.authentication import AuthTktAuthenticationPolicy - from nefertari.acl import RootACL - - - def main(global_config, **settings): - # Nefertari encourages using ACLAuthorizationPolicy and provides a few - # base ACL classes. Choice of authentication policy is completely - # up to you. - config = Configurator( - settings=settings, - authorization_policy=ACLAuthorizationPolicy(), - authentication_policy=AuthTktAuthenticationPolicy(), - root_factory=RootACL, - ) - - # Include 'nefertari.engine' to let her perform the engine setup - config.include('nefertari.engine') + $ pip install nefertari- - # Include nefertari and elasticsearch - config.include('nefertari') - config.include('nefertari.elasticsearch') - # Include your models modules after inclusion of 'nefertari.engine' - config.include('my_app.models') - - # Declare your resources - root = config.get_root_resource() - user = root.add('user', 'users', factory='my_app.acl.UsersACL') - user_story = user.add('story', 'stories') - user_story.add('likes') - - # Use the engine helper to bootstrap the db - from nefertari.engine import setup_database - setup_database(config) - - config.commit() - # Launch the server in the way that works for you - return config.make_wsgi_app() - - -And here is the content of our ``acl.py``. Check out ACLs that are included in Nefertari in :doc:`acls` section:: - - from nefertari.acl import GuestACL - from .models import User - - class UserACL(GuestACL): - __context_class__ = User - - -2. Add Nefertari settings to your settings file (e.g. ``local.ini``) under ``[app:your_app_name]`` section:: +**2. Add a few settings** to development.ini, inside the ``[app:main]`` section .. code-block:: ini - # Set 'nefertari.engine' to the engine you want (e.g. nefertari_sqla or nefertari_mongodb) - nefertari.engine = - # Elasticsearh settings elasticsearch.hosts = localhost:9200 elasticsearch.sniff = false - elasticsearch.index_name = my_app + elasticsearch.index_name = myproject elasticsearch.index.disable = false - # Dependine on the engine you chose, provide database-specific settings. - # E.g. for sqla: - sqlalchemy.url = postgresql://:/dbname + # disable authentication + auth = false + + # Set '' (e.g. nefertari_sqla or nefertari_mongodb) + nefertari.engine = + +.. code-block:: ini + + # For sqla: + sqlalchemy.url = postgresql://localhost:5432/myproject + +.. code-block:: ini # For mongo: mongodb.host = localhost mongodb.port = 27017 - mongodb.db = dbname - - # Other nefertari settings - # Auth enabled/disabled - auth = false - # Debug enabled/disabled - debug = true - # Max age of the static cache - static_cache_max_age = 7200 - # Max number of objects returned from public APIs - public_max_limit = 100 + mongodb.db = myproject -3. The corresponding views would look something like the following. Defined actions are: index (GET), show (GET), create(POST), update(PUT/PATCH), delete(DELETE):: +**3. Replace the file** `myproject/__init__.py` .. code-block:: python - from nefertari.view import BaseView - from nefertari.engine import JSONEncoder + from pyramid.config import Configurator - class UsersView(BaseView): - _model_class = User + def main(global_config, **settings): - def show(self, id): - return {} + config = Configurator(settings=settings) + config.include('nefertari.engine') + config.include('nefertari') + config.include('nefertari.elasticsearch') - def create(self): - return HTTPCreated() + # Include your `models` modules + config.include('myproject.models') - def index(self): - return {'data'=['item1', 'item2']} + root = config.get_root_resource() - def delete(self, id): - return HTTPOk() + from .models import Item + root.add( + 'myitem', 'myitems', + view='myproject.views.ItemsView') + # Use the engine helper to bootstrap the db + from nefertari.engine import setup_database + setup_database(config) + + config.commit() + # Launch the server in the way that works for you + return config.make_wsgi_app() + + +**4. Replace the file** `myproject/views.py` - class UserStoriesView(BaseView): - _model_class = UserStory +.. code-block:: python - def index(self, user_id): - # Get stories here - stories = [] - return dict(data=stories, count=len(stories)) + from nefertari.view import BaseView + from nefertari.elasticsearch import ES + from nefertari.json_httpexceptions import ( + JHTTPCreated, JHTTPOk) - def show(self, user_id, id): - # Get a particular story here - return story_dict + from .models import Item - def delete(self, user_id, id): - return HTTPOK() + class ItemsView(BaseView): + _model_class = Item - class UserStoryLikesView(BaseView): - _model_class = UserStoryLike + def index(self): + self._query_params.process_int_param('_limit', 20) + return ES('Item').get_collection(**self._query_params) - def show(self, user_id, story_id): - # Get a particular story like here - return user_story_like_dict + def show(self, **kwargs): + return ES('Item').get_resource(**kwargs) - def delete(self, user_id, story_id): - return HTTPOK() + def create(self): + story = Item(**self._json_params) + story.save() + pk_field = Item.pk_field() + return JHTTPCreated( + resource=story.to_dict(), + request=self.request, + ) + def update(self, **kwargs): + pk_field = Item.pk_field() + story = Item.get_resource(**kwargs).update(self._json_params) + return JHTTPOk() -Each view must define the following properties: + def delete(self, **kwargs): + Item._delete(**kwargs) + return JHTTPOk() - * *_model_class*: class of the model that is being served by this view. -Optional properties: +**5. Create the file** `myproject/models.py` + +.. code-block:: python + + from nefertari import engine as eng + from nefertari.engine import ESBaseDocument - * *_json_encoder*: encoder to encode objects to JSON. Database-specific encoders are available at ``nefertari.engine.JSONEncoder``. + def includeme(config): + pass -Your views should sit in a package and each module of that package should contain views for a particular root level route. In our example, the ``users`` route view must be at ``views.users.UsersView``. + class Item(ESBaseDocument): + __tablename__ = 'items' -If its not defined in your view, Nefertari will return HTTPMethodNotAllowed by default. -Note that in case of a singular resource (i.e. Likes), there is no "index" view and "show" returns only the one item. -Also, note that "delete", "update" and other actions that would normally require an id, do not in Nefertari, because there is only one object being referenced. + id = eng.IdField(primary_key=True) + name = eng.StringField() + description = eng.TextField() -4. Define your models using abstractions imported from 'nefertari.engine'. For more information on abstractions, see :doc:`engines/index` section. -5. Run your app with ``pserve settings_file.ini`` and request the routes you defined. +**5. Run your app** +.. code-block:: shell -In case you need to tunnel PUT,PATCH and DELETE via POST in a browser one must use "_method=" or the shorthand "_m" along with other POST parameters as if they were normal URL params. E.g. http://myapi.com/api/stories?_m=POST&name=stuff&user=bob". + $ pserve development.ini diff --git a/docs/source/index.rst b/docs/source/index.rst index 425ce8a..ae10d1d 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -7,28 +7,23 @@ Nefertari is a REST API framework for Pyramid that uses Elasticsearch for reads Nefertari currently powers `Brandicted `_ and `Cerri `_. She is fully production ready and actively maintained. -Contents -======== +Table of Content +================ .. toctree:: :maxdepth: 2 - overview getting_started + views database_backends - acls + auth making_requests development_tools example_project + why changelog -Indices and tables -================== - -* :ref:`genindex` -* :ref:`search` - .. image:: nefertari.jpg Image credit: Wikipedia \ No newline at end of file diff --git a/docs/source/making_requests.rst b/docs/source/making_requests.rst index 89a664f..443f315 100644 --- a/docs/source/making_requests.rst +++ b/docs/source/making_requests.rst @@ -7,7 +7,7 @@ Query syntax =============================== =========== url parameter description =============================== =========== -``_m=`` to tunnel any http method using GET, e.g. _m=POST +``_m=`` to tunnel any http method using GET, e.g. _m=POST [#]_ ``_limit=`` to limit the returned collection to results (default: 20, max limit: 100 for unauthenticated users) ``_sort=`` to sort collection by ``_start=`` to start collection from the th resource @@ -28,4 +28,23 @@ url parameter description ``_search_fields=`` use with ``?q=`` to restrict search to specific fields =============================== =========== -.. [#] The full syntax of ElasticSearch querying is beyond the scope of this documentation. You can read more on the `ElasticSearch Query String Query `_ page and more specifically on `Ranges `_ to do things like: ``?date=[2014-01-01 TO *]`` +.. [#] To update listfields and dictfields, you can use the following syntax: ``_m=PATCH&=&.=`` +.. [#] The full syntax of ElasticSearch querying is beyond the scope of this documentation. You can read more on the ElasticSearch Query String Query `documentation `_ to do things like fuzzy search: ``?name=fuzzy~`` or date range search: ``?date=[2015-01-01 TO *]`` + +update_many() +------------- + +If update_many() is defined in your view, you will be able to update a single field across an entire collection or filtered collection. E.g. + +PATCH `/api/?q=` + +.. code-block:: json + + { + "":"" + } + +delete_many() +------------- + +Similarly, if delete_many() is defined, you will be able to delete an entire collection or filtered collection. E.g. DELETE `/api/?_missing_=` diff --git a/docs/source/views.rst b/docs/source/views.rst new file mode 100644 index 0000000..c5a9cd9 --- /dev/null +++ b/docs/source/views.rst @@ -0,0 +1,24 @@ +Configuring views +================= + +Introduction +------------ + +It is recommended that your views reside in a package. In this case, each module of that package would contain all views of any given root-level route. Alternatively, ou can explicitly provide a view name, or a view class as ``view`` keyword argument to ``resource.add()`` in your project's ``main`` function. In the case of a singular resource, there is no need to define ``index`` and ``show`` returns only one item. + +* *index*: called upon ``GET`` request to a collection, e.g. ``/collection`` +* *show*: called upon ``GET`` request to a collection-item, e.g. ``/collection/`` +* *create*: called upon ``POST`` request to a collection +* *update*: called upon ``PATCH`` request to a collection-item +* *delete*: called upon ``DELETE`` request to a collection-item +* *update_many*: called upon ``PATCH`` request to a collection or filtered collection, e.g. ``/collection?_exists_=`` +* *delete_many*: called upon ``DELETE`` request to a collection or filtered collection + +Notes +----- + +When using SQLA, each view must define the following properties: + * *_model_class*: class of the model that is being served by this view. + +Optional properties: + * *_json_encoder*: encoder to encode objects to JSON. Database-specific encoders are available at ``nefertari.engine.JSONEncoder``. diff --git a/docs/source/overview.rst b/docs/source/why.rst similarity index 97% rename from docs/source/overview.rst rename to docs/source/why.rst index 643a14f..76fe8ec 100644 --- a/docs/source/overview.rst +++ b/docs/source/why.rst @@ -1,5 +1,5 @@ -Overview -======== +Why Nefertari? +============== Nefertari is a tool for making REST APIs using the Pyramid web framework. @@ -29,4 +29,4 @@ By making assumptions about sane defaults, we can eliminate the need for boilerp Nefertari is the meat and potatoes of our development stack. Her partner project, Ramses, is the seasoning/sugar/cherry on top! Ramses allows whole production-ready Nefertari apps to be generated at runtime from a simple YAML file specifying the endpoints desired. `Check it out. `_ -.. [#] For the record, DRF is pretty badass and we have great respect for its breadth and the hard work of its community. Laying out a ton of boilerplate can be considered to fall into "flat is better than nested" and might be best for some teams. \ No newline at end of file +.. [#] For the record, DRF is pretty badass and we have great respect for its breadth and the hard work of its community. Laying out a ton of boilerplate can be considered to fall into "flat is better than nested" and might be best for some teams. diff --git a/nefertari/__init__.py b/nefertari/__init__.py index 0816219..b038af6 100644 --- a/nefertari/__init__.py +++ b/nefertari/__init__.py @@ -12,19 +12,21 @@ def includeme(config): from nefertari.resource import get_root_resource, get_resource_map - from nefertari.renderers import JsonRendererFactory, NefertariJsonRendererFactory + from nefertari.renderers import ( + JsonRendererFactory, NefertariJsonRendererFactory) log.info("%s %s" % (APP_NAME, __version__)) config.add_directive('get_root_resource', get_root_resource) config.add_renderer('json', JsonRendererFactory) config.add_renderer('nefertari_json', NefertariJsonRendererFactory) - config.registry._root_resources = {} - config.registry._resources_map = {} + if not hasattr(config.registry, '_root_resources'): + config.registry._root_resources = {} + if not hasattr(config.registry, '_resources_map'): + config.registry._resources_map = {} config.add_request_method(get_resource_map, 'resource_map', reify=True) - config.add_tween('nefertari.tweens.post_tunneling') config.add_tween('nefertari.tweens.cache_control') config.add_route('options', '/*path', request_method='OPTIONS') diff --git a/nefertari/acl.py b/nefertari/acl.py index 60f6f25..6bdc488 100644 --- a/nefertari/acl.py +++ b/nefertari/acl.py @@ -1,9 +1,29 @@ -from pyramid.security import ( - ALL_PERMISSIONS, Allow, Everyone, Deny, - Authenticated) +from pyramid.security import ALL_PERMISSIONS, Allow, Everyone, Authenticated -class BaseACL(object): +class SelfParamMixin(object): + """ ACL mixin that implements method to translate input key value + to a user ID field, when key value equals :param_value: + + Value is only converted if user is logged in and :request.user: + is an instance of :__context_class__:, thus for routes that display + auth users. + """ + param_value = 'self' + + def resolve_self_key(self, key): + if key != self.param_value: + return key + user = getattr(self.request, 'user', None) + if not user or not self.__context_class__: + return key + if not isinstance(user, self.__context_class__): + return key + obj_id = getattr(user, user.pk_field()) or key + return obj_id + + +class BaseACL(SelfParamMixin): """ Base ACL class. Grants: @@ -30,10 +50,11 @@ def context_acl(self, obj): def __getitem__(self, key): assert(self.__context_class__) + key = self.resolve_self_key(key) - id_field = self.__context_class__.id_field() + pk_field = self.__context_class__.pk_field() obj = self.__context_class__.get( - __raise=True, **{id_field: key}) + __raise=True, **{pk_field: key}) obj.__acl__ = self.context_acl(obj) obj.__parent__ = self obj.__name__ = key @@ -67,8 +88,9 @@ def __init__(self, request): super(GuestACL, self).__init__(request) self.acl = (Allow, Everyone, ['index', 'show']) - def context_acl(self, context): + def context_acl(self, obj): return [ + (Allow, 'g:admin', ALL_PERMISSIONS), (Allow, Everyone, ['index', 'show']), ] @@ -82,10 +104,10 @@ class AuthenticatedReadACL(BaseACL): def __init__(self, request): super(AuthenticatedReadACL, self).__init__(request) - self.acl = (Allow, Authenticated, ['index', 'show']) + self.acl = (Allow, Authenticated, 'index') - def context_acl(self, context): + def context_acl(self, obj): return [ (Allow, 'g:admin', ALL_PERMISSIONS), - (Allow, Authenticated, ['index', 'show']), + (Allow, Authenticated, 'show'), ] diff --git a/nefertari/tests/__init__.py b/nefertari/authentication/__init__.py similarity index 100% rename from nefertari/tests/__init__.py rename to nefertari/authentication/__init__.py diff --git a/nefertari/authentication/models.py b/nefertari/authentication/models.py new file mode 100644 index 0000000..199b387 --- /dev/null +++ b/nefertari/authentication/models.py @@ -0,0 +1,235 @@ +import uuid +import logging + +import cryptacular.bcrypt +from pyramid.security import authenticated_userid, forget + +from nefertari.json_httpexceptions import JHTTPBadRequest +from nefertari import engine +from nefertari.utils import dictset + +log = logging.getLogger(__name__) +crypt = cryptacular.bcrypt.BCRYPTPasswordManager() + + +class AuthModelDefaultMixin(object): + """ Mixin that implements all methods required for Ticket and Token + auth systems to work. + + All implemented methods must be class methods. + """ + @classmethod + def get_resource(self, *args, **kwargs): + return super(AuthModelDefaultMixin, self).get_resource( + *args, **kwargs) + + @classmethod + def pk_field(self, *args, **kwargs): + return super(AuthModelDefaultMixin, self).pk_field(*args, **kwargs) + + @classmethod + def get_or_create(self, *args, **kwargs): + return super(AuthModelDefaultMixin, self).get_or_create( + *args, **kwargs) + + @classmethod + def is_admin(cls, user): + """ Determine if :user: is an admin. Used by `apply_privacy` wrapper. + """ + return 'admin' in user.groups + + @classmethod + def get_token_credentials(cls, username, request): + """ Get api token for user with username of :username: + + Used by Token-based auth as `credentials_callback` kwarg. + """ + try: + user = cls.get_resource(username=username) + except Exception as ex: + log.error(unicode(ex)) + forget(request) + else: + if user: + return user.api_key.token + + @classmethod + def get_groups_by_token(cls, username, token, request): + """ Get user's groups if user with :username: exists and their api key + token equals :token: + + Used by Token-based authentication as `check` kwarg. + """ + try: + user = cls.get_resource(username=username) + except Exception as ex: + log.error(unicode(ex)) + forget(request) + return + else: + if user and user.api_key.token == token: + return ['g:%s' % g for g in user.groups] + + @classmethod + def authenticate_by_password(cls, params): + """ Authenticate user with login and password from :params: + + Used both by Token and Ticket-based auths (called from views). + """ + def verify_password(user, password): + return crypt.check(user.password, password) + + success = False + user = None + login = params['login'].lower().strip() + key = 'email' if '@' in login else 'username' + try: + user = cls.get_resource(**{key: login}) + except Exception as ex: + log.error(unicode(ex)) + + if user: + password = params.get('password', None) + success = (password and verify_password(user, password)) + return success, user + + @classmethod + def get_groups_by_userid(cls, userid, request): + """ Return group identifiers of user with id :userid: + + Used by Ticket-based auth as `callback` kwarg. + """ + try: + user = cls.get_resource(**{cls.pk_field(): userid}) + except Exception as ex: + log.error(unicode(ex)) + forget(request) + else: + if user: + return ['g:%s' % g for g in user.groups] + + @classmethod + def create_account(cls, params): + """ Create auth user instance with data from :params:. + + Used by both Token and Ticket-based auths to register a user ( + called from views). + """ + user_params = dictset(params).subset( + ['username', 'email', 'password']) + try: + return cls.get_or_create( + email=user_params['email'], + defaults=user_params) + except JHTTPBadRequest: + raise JHTTPBadRequest('Failed to create account.') + + @classmethod + def get_authuser_by_userid(cls, request): + """ Get user by ID. + + Used by Ticket-based auth. Is added as request method to populate + `request.user`. + """ + _id = authenticated_userid(request) + if _id: + return cls.get_resource(**{cls.pk_field(): _id}) + + @classmethod + def get_authuser_by_name(cls, request): + """ Get user by username + + Used by Token-based auth. Is added as request method to populate + `request.user`. + """ + username = authenticated_userid(request) + if username: + return cls.get_resource(username=username) + + +def lower_strip(value): + return (value or '').lower().strip() + + +def encrypt_password(password): + """ Crypt :password: if it's not crypted yet. """ + if password and not crypt.match(password): + password = unicode(crypt.encode(password)) + return password + + +class AuthUser(AuthModelDefaultMixin, engine.BaseDocument): + """ Class that is meant to be User class in Auth system. + + Implements basic operations to support Pyramid Ticket-based and custom + ApiKey token-based authentication. + """ + __tablename__ = 'nefertari_authuser' + + id = engine.IdField(primary_key=True) + username = engine.StringField( + min_length=1, max_length=50, unique=True, + required=True, processors=[lower_strip]) + email = engine.StringField( + unique=True, required=True, processors=[lower_strip]) + password = engine.StringField( + min_length=3, required=True, processors=[encrypt_password]) + groups = engine.ListField( + item_type=engine.StringField, + choices=['admin', 'user'], default=['user']) + + +def create_apikey_token(): + """ Generate ApiKey.token using uuid library. """ + return uuid.uuid4().hex.replace('-', '') + + +def create_apikey_model(user_model): + """ Generate ApiKey model class and connect it with :user_model:. + + ApiKey is generated with relationship to user model class :user_model: + as a One-to-One relationship with a backreference. + ApiKey is set up to be auto-generated when a new :user_model: is created. + + Returns ApiKey document class. If ApiKey is already defined, it is not + generated. + + Arguments: + :user_model: Class that represents user model for which api keys will + be generated and with which ApiKey will have relationship. + """ + try: + return engine.get_document_cls('ApiKey') + except ValueError: + pass + + fk_kwargs = { + 'ref_column': None, + } + if hasattr(user_model, '__tablename__'): + fk_kwargs['ref_column'] = '.'.join([ + user_model.__tablename__, user_model.pk_field()]) + fk_kwargs['ref_column_type'] = user_model.pk_field_type() + + class ApiKey(engine.BaseDocument): + __tablename__ = 'nefertari_apikey' + + id = engine.IdField(primary_key=True) + token = engine.StringField(default=create_apikey_token) + user = engine.Relationship( + document=user_model.__name__, + uselist=False, + backref_name='api_key', + backref_uselist=False) + user_id = engine.ForeignKeyField( + ref_document=user_model.__name__, + **fk_kwargs) + + def reset_token(self): + self.update({'token': create_apikey_token()}) + return self.token + + # Setup ApiKey autogeneration on :user_model: creation + ApiKey.autogenerate_for(user_model, 'user') + + return ApiKey diff --git a/nefertari/authentication/policies.py b/nefertari/authentication/policies.py new file mode 100644 index 0000000..8df9368 --- /dev/null +++ b/nefertari/authentication/policies.py @@ -0,0 +1,97 @@ +from pyramid.authentication import CallbackAuthenticationPolicy + +from nefertari import engine +from .models import create_apikey_model + + +class ApiKeyAuthenticationPolicy(CallbackAuthenticationPolicy): + """ ApiKey authentication policy. + + Relies on `Authorization` header being used in request, e.g.: + `Authorization: ApiKey username:token` + + To use this policy, instantiate it with required arguments, as described + in `__init__` method and register it with Pyramid's + `Configurator.set_authentication_policy`. + + You may also find useful `nefertari.authentication.views. + TokenAuthenticationView` + view which offers basic functionality to create, claim, and reset the + token. + """ + def __init__(self, user_model, check=None, credentials_callback=None): + """ Init the policy. + + Arguments: + :user_model: String name or class of a User model for which ApiKey + model is to be generated + :check: A callback passed the username, api_key and the request, + expected to return None if user doesn't exist or a sequence of + principal identifiers (possibly empty) if the user does exist. + If callback is None, the username will be assumed to exist with + no principals. Optional. + :credentials_callback: A callback passed the username and current + request, expected to return and user's api key. + Is used to generate 'WWW-Authenticate' header with a value of + valid 'Authorization' request header that should be used to + perform requests. + """ + self.user_model = user_model + if isinstance(self.user_model, basestring): + self.user_model = engine.get_document_cls(self.user_model) + create_apikey_model(self.user_model) + + self.check = check + self.credentials_callback = credentials_callback + super(ApiKeyAuthenticationPolicy, self).__init__() + + def remember(self, request, username, **kw): + """ Returns 'WWW-Authenticate' header with a value that should be used + in 'Authorization' header. + """ + if self.credentials_callback: + token = self.credentials_callback(username, request) + api_key = 'ApiKey {}:{}'.format(username, token) + return [('WWW-Authenticate', api_key)] + + def forget(self, request): + """ Returns challenge headers. This should be attached to a response + to indicate that credentials are required.""" + return [('WWW-Authenticate', 'ApiKey realm="%s"' % self.realm)] + + def unauthenticated_userid(self, request): + """ Username parsed from the ``Authorization`` request header.""" + credentials = self._get_credentials(request) + if credentials: + return credentials[0] + + def callback(self, username, request): + """ Having :username: return user's identifiers or None. """ + credentials = self._get_credentials(request) + if credentials: + username, api_key = credentials + if self.check: + return self.check(username, api_key, request) + + def _get_credentials(self, request): + """ Extract username and api key token from 'Authorization' header """ + authorization = request.headers.get('Authorization') + if not authorization: + return None + try: + authmeth, authbytes = authorization.split(' ', 1) + except ValueError: # not enough values to unpack + return None + if authmeth.lower() != 'apikey': + return None + + try: + auth = authbytes.decode('utf-8') + except UnicodeDecodeError: + auth = authbytes.decode('latin-1') + + try: + username, api_key = auth.split(':', 1) + except ValueError: # not enough values to unpack + return None + return username, api_key diff --git a/nefertari/authentication/views.py b/nefertari/authentication/views.py new file mode 100644 index 0000000..dd67787 --- /dev/null +++ b/nefertari/authentication/views.py @@ -0,0 +1,115 @@ +from pyramid.security import remember, forget + +from nefertari.json_httpexceptions import ( + JHTTPFound, JHTTPConflict, JHTTPUnauthorized, JHTTPNotFound, JHTTPOk) +from nefertari.view import BaseView +from .models import AuthUser + + +class TicketAuthenticationView(BaseView): + """ View for auth operations to use with Pyramid ticket-based auth. + `login` (POST): Login the user with 'login' and 'password' + `logout`: Logout user + """ + _model_class = AuthUser + + def register(self): + """ Register new user by POSTing all required data. + + """ + user, created = self._model_class.create_account( + self._json_params) + + if not created: + raise JHTTPConflict('Looks like you already have an account.') + + pk_field = user.pk_field() + headers = remember(self.request, getattr(user, pk_field)) + return JHTTPOk('Registered', headers=headers) + + def login(self, **params): + self._json_params.update(params) + next = self._query_params.get('next', '') + login_url = self.request.route_url('login') + if next.startswith(login_url): + next = '' # never use the login form itself as next + + unauthorized_url = self._query_params.get('unauthorized', None) + success, user = self._model_class.authenticate_by_password( + self._json_params) + + if success: + pk_field = user.pk_field() + headers = remember(self.request, getattr(user, pk_field)) + if next: + raise JHTTPFound(location=next, headers=headers) + else: + return JHTTPOk('Logged in', headers=headers) + if user: + if unauthorized_url: + return JHTTPUnauthorized(location=unauthorized_url+'?error=1') + + raise JHTTPUnauthorized('Failed to Login.') + else: + raise JHTTPNotFound('User not found') + + def logout(self): + next = self._query_params.get('next') + headers = forget(self.request) + if next: + return JHTTPFound(location=next, headers=headers) + return JHTTPOk('Logged out', headers=headers) + + +class TokenAuthenticationView(BaseView): + """ View for auth operations to use with + `nefertari.authentication.policies.ApiKeyAuthenticationPolicy` + token-based auth. Implements methods: + """ + _model_class = AuthUser + + def register(self): + """ Register a new user by POSTing all required data. + + User's `Authorization` header value is returned in `WWW-Authenticate` + header. + """ + user, created = self._model_class.create_account(self._json_params) + + if not created: + raise JHTTPConflict('Looks like you already have an account.') + + headers = remember(self.request, user.username) + return JHTTPOk('Registered', headers=headers) + + def claim_token(self, **params): + """Claim current token by POSTing 'login' and 'password'. + + User's `Authorization` header value is returned in `WWW-Authenticate` + header. + """ + self._json_params.update(params) + success, self.user = self._model_class.authenticate_by_password( + self._json_params) + + if success: + headers = remember(self.request, self.user.username) + return JHTTPOk('Token claimed', headers=headers) + if self.user: + raise JHTTPUnauthorized('Wrong login or password') + else: + raise JHTTPNotFound('User not found') + + def reset_token(self, **params): + """ Reset current token by POSTing 'login' and 'password'. + + User's `Authorization` header value is returned in `WWW-Authenticate` + header. + """ + response = self.claim_token(**params) + if not self.user: + return response + + self.user.api_key.reset_token() + headers = remember(self.request, self.user.username) + return JHTTPOk('Registered', headers=headers) diff --git a/nefertari/elasticsearch.py b/nefertari/elasticsearch.py index 6e4ef72..525c9d4 100644 --- a/nefertari/elasticsearch.py +++ b/nefertari/elasticsearch.py @@ -5,8 +5,8 @@ from nefertari.utils import ( dictset, dict2obj, process_limit, split_strip) -from nefertari.json_httpexceptions import * -from nefertari.engine import ESJSONSerializer +from nefertari.json_httpexceptions import JHTTPBadRequest, JHTTPNotFound, exception_response +from nefertari import engine log = logging.getLogger(__name__) @@ -22,6 +22,10 @@ ] +class IndexNotFoundException(Exception): + pass + + class ESHttpConnection(elasticsearch.Urllib3HttpConnection): def perform_request(self, *args, **kw): try: @@ -33,10 +37,13 @@ def perform_request(self, *args, **kw): return super(ESHttpConnection, self).perform_request(*args, **kw) except Exception as e: - if e.status_code == 'N/A': - e.status_code = 400 + status_code = e.status_code + if status_code == 404: + raise IndexNotFoundException() + if status_code == 'N/A': + status_code = 400 raise exception_response( - e.status_code, + status_code, detail='elasticsearch error.', extra=dict(data=e)) @@ -46,13 +53,17 @@ def includeme(config): ES.setup(Settings) +def _bulk_body(body): + return ES.api.bulk(body=body) + + def apply_sort(_sort): _sort_param = [] if _sort: for each in [e.strip() for e in _sort.split(',')]: if each.startswith('-'): - _sort_param.append(each[1:]+':desc') + _sort_param.append(each[1:] + ':desc') elif each.startswith('+'): _sort_param.append(each[1:] + ':asc') else: @@ -66,7 +77,7 @@ def build_terms(name, values, operator='OR'): def build_qs(params, _raw_terms='', operator='AND'): - #if param is _all then remove it + # if param is _all then remove it params.pop_by_values('_all') terms = [] @@ -106,7 +117,8 @@ def setup(cls, settings): try: _hosts = ES.settings.hosts hosts = [] - for (host, port) in [split_strip(each, ':') for each in split_strip(_hosts)]: + for (host, port) in [ + split_strip(each, ':') for each in split_strip(_hosts)]: hosts.append(dict(host=host, port=port)) params = {} @@ -117,12 +129,13 @@ def setup(cls, settings): ) ES.api = elasticsearch.Elasticsearch( - hosts=hosts, serializer=ESJSONSerializer(), + hosts=hosts, serializer=engine.ESJSONSerializer(), connection_class=ESHttpConnection, **params) log.info('Including ElasticSearch. %s' % ES.settings) except KeyError as e: - raise Exception('Bad or missing settings for elasticsearch. %s' % e) + raise Exception( + 'Bad or missing settings for elasticsearch. %s' % e) def __init__(self, source='', index_name=None, chunk_size=100): self.doc_type = self.src2type(source) @@ -154,8 +167,9 @@ def prep_bulk_documents(self, action, documents): _docs = [] for doc in documents: if not isinstance(doc, dict): - raise ValueError('document type must be `dict` not a ' - '%s' % (type(doc))) + raise ValueError( + 'Document type must be `dict` not a `{}`'.format( + type(doc).__name__)) if '_type' in doc: _doc_type = self.src2type(doc['_type']) @@ -176,7 +190,9 @@ def prep_bulk_documents(self, action, documents): return _docs def _bulk(self, action, documents, chunk_size=None): - chunk_size = chunk_size or self.chunk_size + if chunk_size is None: + chunk_size = self.chunk_size + if not documents: log.debug('empty documents: %s' % self.doc_type) return @@ -198,34 +214,40 @@ def _bulk(self, action, documents, chunk_size=None): # meta, document, meta, ... self.process_chunks( documents=body, - operation=lambda b: ES.api.bulk(body=b), + operation=_bulk_body, chunk_size=chunk_size*2) else: - log.warning('empty body') + log.warning('Empty body') def index(self, documents, chunk_size=None): - """ Reindex all `document`. """ + """ Reindex all `document`s. """ self._bulk('index', documents, chunk_size) - def index_missing(self, documents, chunk_size=None): - """ Index documents from a `document` that are missing from index. + def index_missing_documents(self, documents, chunk_size=None): + """ Index documents that are missing from ES index. - To determine what documents are missing `mget` call with a list of - document IDs from `documents` is performed. Then `document` are - filtered to drop documents that were found. + Determines which documents are missing using ES `mget` call which + returns a list of document IDs as `documents`. Then missing + `documents` from that list are indexed. """ - log.info('Indexing documents of type `{}` missing from ' - 'index `{}`'.format(self.doc_type, self.index_name)) + log.info('Trying to index documents of type `{}` missing from ' + '`{}` index'.format(self.doc_type, self.index_name)) if not documents: log.info('No documents to index') return - response = ES.api.mget( + query_kwargs = dict( index=self.index_name, doc_type=self.doc_type, fields=['_id'], body={'ids': [d['id'] for d in documents]}, ) - indexed_ids = set(d['_id'] for d in response['docs'] if d.get('found')) + try: + response = ES.api.mget(**query_kwargs) + except IndexNotFoundException: + indexed_ids = set() + else: + indexed_ids = set( + d['_id'] for d in response['docs'] if d.get('found')) documents = [d for d in documents if str(d['id']) not in indexed_ids] if not documents: @@ -239,7 +261,8 @@ def delete(self, ids): if not isinstance(ids, list): ids = [ids] - self._bulk('delete', [{'id':_id, '_type': self.doc_type} for _id in ids]) + documents = [{'id': _id, '_type': self.doc_type} for _id in ids] + self._bulk('delete', documents) def get_by_ids(self, ids, **params): if not ids: @@ -268,9 +291,21 @@ def get_by_ids(self, ids, **params): ) if fields: params['fields'] = fields - - data = ES.api.mget(**params) documents = _ESDocs() + documents._nefertari_meta = dict( + start=_start, + fields=fields, + ) + + try: + data = ES.api.mget(**params) + except IndexNotFoundException: + if __raise_on_empty: + raise JHTTPNotFound( + '{}({}) resource not found (Index does not exist)'.format( + self.doc_type, params)) + documents._nefertari_meta.update(total=0) + return documents for _d in data['docs']: try: @@ -286,10 +321,8 @@ def get_by_ids(self, ids, **params): documents.append(dict2obj(dictset(_d))) - documents._nefertari_meta = dict( + documents._nefertari_meta.update( total=len(documents), - start=_start, - fields=fields, ) return documents @@ -314,15 +347,16 @@ def build_search_params(self, params): } } } - print query_string else: _params['body'] = {"query": {"match_all": {}}} - if '_limit' in params: - _params['from_'], _params['size'] = process_limit( - params.get('_start', None), - params.get('_page', None), - params['_limit']) + if '_limit' not in params: + raise JHTTPBadRequest('Missing _limit') + + _params['from_'], _params['size'] = process_limit( + params.get('_start', None), + params.get('_page', None), + params['_limit']) if '_sort' in params: _params['sort'] = apply_sort(params['_sort']) @@ -344,7 +378,10 @@ def do_count(self, params): params.pop('size', None) params.pop('from_', None) params.pop('sort', None) - return ES.api.count(**params)['count'] + try: + return ES.api.count(**params)['count'] + except IndexNotFoundException: + return 0 def get_collection(self, **params): __raise_on_empty = params.pop('__raise_on_empty', False) @@ -360,23 +397,34 @@ def get_collection(self, **params): # pop the fields before passing to search. # ES does not support passing names of nested structures _fields = _params.pop('fields', '') - data = ES.api.search(**_params) documents = _ESDocs() + documents._nefertari_meta = dict( + start=_params['from_'], + fields=_fields) + + try: + data = ES.api.search(**_params) + except IndexNotFoundException: + if __raise_on_empty: + raise JHTTPNotFound( + '{}({}) resource not found (Index does not exist)'.format( + self.doc_type, params)) + documents._nefertari_meta.update( + total=0, took=0) + return documents for da in data['hits']['hits']: - _d = da['fields'] if 'fields' in _params else da['_source'] + _d = da['fields'] if _fields else da['_source'] _d['_score'] = da['_score'] documents.append(dict2obj(_d)) - documents._nefertari_meta = dict( + documents._nefertari_meta.update( total=data['hits']['total'], - start=_params['from_'], - fields=_fields, took=data['took'], ) if not documents: - msg = "'%s(%s)' resource not found" % (self.doc_type, params) + msg = "%s(%s) resource not found" % (self.doc_type, params) if __raise_on_empty: raise JHTTPNotFound(msg) else: @@ -394,7 +442,15 @@ def get_resource(self, **kw): params.setdefault('ignore', 404) params.update(kw) - data = ES.api.get_source(**params) + try: + data = ES.api.get_source(**params) + except IndexNotFoundException: + if __raise: + raise JHTTPNotFound( + "{}({}) resource not found (Index does not exist)".format( + self.doc_type, params)) + data = {} + if not data: msg = "'%s(%s)' resource not found" % (self.doc_type, params) if __raise: diff --git a/nefertari/engine.py b/nefertari/engine.py index bde7b91..9b35a52 100644 --- a/nefertari/engine.py +++ b/nefertari/engine.py @@ -1,5 +1,5 @@ """ -Extend global scope with an engine-specific variables/objects. +Extend global scope with engine-specific variables/objects. Usage ----- @@ -17,17 +17,17 @@ ----- Db setup should be performed after loading models, as some engines require -models schema to be defined before creating the database. If your database -does not have above requirement, it's up to you to decide when to setup -db. +model schemas to be defined before creating the database. If your database +does not have the above requirement, it's up to you to decide when to set up +the db. -The specified engine module is also `config.include`d here thus running +The specified engine module is also `config.include`d here, thus running the engine's `icludeme` function and allowing setting up required state, performing some actions, etc. -Specified engine may be either module or package. -In case you build a custom engine, variables you expect to use from it, -should be importable from package itself. +The engine specified may be either a module or a package. +In case you build a custom engine, variables you expect to use from it +should be importable from the package itself. E.g. ``from your.package import BaseDocument`` nefertari relies on 'nefertari.engine' being included when configuring the app. diff --git a/nefertari/json_httpexceptions.py b/nefertari/json_httpexceptions.py index 843ff62..14b4ae0 100644 --- a/nefertari/json_httpexceptions.py +++ b/nefertari/json_httpexceptions.py @@ -4,7 +4,8 @@ from datetime import datetime from pyramid import httpexceptions as http_exc -from nefertari.utils import dictset, json_dumps +from nefertari.wrappers import apply_privacy + logger = logging.getLogger(__name__) @@ -25,7 +26,9 @@ def add_stack(): def create_json_response(obj, request=None, log_it=False, show_stack=False, **extra): + from nefertari.utils import json_dumps body = dict() + encoder = extra.pop('encoder', None) for attr in BASE_ATTRS: body[attr] = extra.pop(attr, None) or getattr(obj, attr, None) @@ -41,7 +44,7 @@ def create_json_response(obj, request=None, log_it=False, show_stack=False, body.update(extra) - obj.body = json_dumps(body) + obj.body = json_dumps(body, encoder=encoder) show_stack = log_it or show_stack status = obj.status_int @@ -62,6 +65,7 @@ def exception_response(status_code, **kw): class JBase(object): def __init__(self, *arg, **kw): + from nefertari.utils import dictset kw = dictset(kw) self.__class__.__base__.__init__( self, *arg, @@ -74,7 +78,9 @@ def __init__(self, *arg, **kw): http_exceptions = http_exc.status_map.values() + [ - http_exc.HTTPBadRequest, http_exc.HTTPInternalServerError] + http_exc.HTTPBadRequest, + http_exc.HTTPInternalServerError, +] for exc_cls in http_exceptions: @@ -90,9 +96,18 @@ def httperrors(context, request): class JHTTPCreated(http_exc.HTTPCreated): def __init__(self, *args, **kwargs): resource = kwargs.pop('resource', None) + encoder = kwargs.pop('encoder', None) + request = kwargs.pop('request', None) super(JHTTPCreated, self).__init__(*args, **kwargs) if resource and 'location' in kwargs: resource['self'] = kwargs['location'] - create_json_response(self, **dict(data=resource)) + auth = request and request.registry._root_resources.values()[0].auth + if resource and auth: + wrapper = apply_privacy(request=request) + resource = wrapper(result=resource) + + create_json_response( + self, data=resource, + encoder=encoder) diff --git a/nefertari/logstash.py b/nefertari/logstash.py index a47c29d..15f2e4a 100644 --- a/nefertari/logstash.py +++ b/nefertari/logstash.py @@ -23,7 +23,8 @@ def includeme(config): try: sock.sendto( 'PING', 0, - (Settings['logstash.host'], Settings.asint('logstash.port'))) + (Settings['logstash.host'], + Settings.asint('logstash.port'))) recv, svr = sock.recvfrom(255) sock.shutdown(2) except Exception as e: diff --git a/nefertari/renderers.py b/nefertari/renderers.py index 6c65c5a..15a1954 100644 --- a/nefertari/renderers.py +++ b/nefertari/renderers.py @@ -32,7 +32,7 @@ def __init__(self, info): def __call__(self, value, system): """ Call the renderer implementation with the value and the system value passed in as arguments and return - the result (a string or unicode object). The value is + the result (a string or unicode object). The value is the return value of a view. The system value is a dictionary containing available system values (e.g. view, context, and request). """ @@ -48,7 +48,8 @@ def __call__(self, value, system): value = self.run_after_calls(value, system) view = system['view'] - enc_class = getattr(view, '_json_encoder', _JSONEncoder) or _JSONEncoder + enc_class = getattr( + view, '_json_encoder', _JSONEncoder) or _JSONEncoder return json.dumps(value, cls=enc_class) def run_after_calls(self, value, system): @@ -63,14 +64,14 @@ def run_after_calls(self, value, system): class NefertariJsonRendererFactory(JsonRendererFactory): - """Yarhp specific json renderer which will apply + """Special json renderer which will apply all after_calls(filters) to the result. """ def run_after_calls(self, value, system): request = system.get('request') if request and hasattr(request, 'action'): - after_calls = getattr(request, 'filters', []) + after_calls = getattr(request, 'filters', {}) for call in after_calls.get(request.action, []): value = call(**dict(request=request, result=value)) diff --git a/nefertari/resource.py b/nefertari/resource.py index 8021045..95bf824 100644 --- a/nefertari/resource.py +++ b/nefertari/resource.py @@ -29,8 +29,8 @@ def add_resource_routes(config, view, member_name, collection_name, **kwargs): ``member_name`` should be the appropriate singular version of the resource given your locale and used with members of the collection. - ``collection_name`` will be used to refer to the resource collection methods - and should be a plural version of the member_name argument. + ``collection_name`` will be used to refer to the resource collection + methods and should be a plural version of the member_name argument. All keyword arguments are optional. @@ -40,7 +40,7 @@ def add_resource_routes(config, view, member_name, collection_name, **kwargs): resources or relations between resources. ``name_prefix`` - Perpends the route names that are generated with the + Prepends the route names that are generated with the name_prefix given. Combined with the path_prefix option, it's easy to generate route names and paths that represent resources that are in relations. @@ -123,17 +123,24 @@ def add_route_and_view(config, action, route_name, path, request_method, if collection_name: add_route_and_view( - config, 'update_many', name_prefix + (collection_name or member_name), + config, 'update_many', + name_prefix + (collection_name or member_name), path, 'PUT', traverse=_traverse) add_route_and_view( - config, 'delete_many', name_prefix + (collection_name or member_name), + config, 'update_many', + name_prefix + (collection_name or member_name), + path, 'PATCH', traverse=_traverse) + + add_route_and_view( + config, 'delete_many', + name_prefix + (collection_name or member_name), path, 'DELETE', traverse=_traverse) return action_route -def default_view(resource): +def get_default_view_path(resource): "Returns the dotted path to the default view class." parts = [a.member_name for a in resource.ancestors] +\ @@ -191,18 +198,19 @@ def get_ancestors(self): ancestors = property(get_ancestors) resource_map = property(lambda self: self.config.registry._resources_map) is_root = property(lambda self: not self.member_name) - is_singular = property(lambda self: not self.is_root and not self.collection_name) + is_singular = property( + lambda self: not self.is_root and not self.collection_name) def add(self, member_name, collection_name='', parent=None, uid='', **kwargs): """ :param member_name: singular name of the resource. It should be the - appropriate singular version of the resource given your locale and used - with members of the collection. + appropriate singular version of the resource given your locale + and used with members of the collection. - :param collection_name: plural name of the resource. It will be used to - refer to the resource collection methods and should be a plural version - of the ``member_name`` argument. + :param collection_name: plural name of the resource. It will be used + to refer to the resource collection methods and should be a + plural version of the ``member_name`` argument. Note: if collection_name is empty, it means resource is singular :param parent: parent resource name or object. @@ -211,7 +219,8 @@ def add(self, member_name, collection_name='', parent=None, uid='', :param kwargs: view: custom view to overwrite the default one. - the rest of the keyward arguments are passed to add_resource_routes call. + the rest of the keyward arguments are passed to + add_resource_routes call. :return: ResourceMap object """ @@ -221,7 +230,8 @@ def add(self, member_name, collection_name='', parent=None, uid='', prefix = kwargs.pop('prefix', '') - uid = (uid or ':'.join(filter(bool, [parent.uid, prefix, member_name]))) + uid = (uid or + ':'.join(filter(bool, [parent.uid, prefix, member_name]))) if uid in self.resource_map: raise ValueError('%s already exists in resource map' % uid) @@ -233,7 +243,7 @@ def add(self, member_name, collection_name='', parent=None, uid='', prefix=prefix) view = maybe_dotted( - kwargs.pop('view', None) or default_view(new_resource)) + kwargs.pop('view', None) or get_default_view_path(new_resource)) for name, val in kwargs.pop('view_args', {}).items(): setattr(view, name, val) @@ -260,7 +270,7 @@ def add(self, member_name, collection_name='', parent=None, uid='', kwargs['path_prefix'] = '/'.join(path_segs) if prefix: - kwargs['path_prefix'] += '/'+prefix + kwargs['path_prefix'] += '/' + prefix name_segs = [a.member_name for a in new_resource.ancestors] name_segs.insert(1, prefix) @@ -273,6 +283,7 @@ def add(self, member_name, collection_name='', parent=None, uid='', kwargs.setdefault('auth', root_resource.auth) kwargs.setdefault('factory', root_resource.default_factory) + _factory = maybe_dotted(kwargs['factory']) kwargs['auth'] = kwargs.get('auth', root_resource.auth) @@ -291,10 +302,12 @@ def add(self, member_name, collection_name='', parent=None, uid='', new_resource)) parent.children.append(new_resource) + view._resource = new_resource + view._factory = _factory return new_resource - def add_from(self, resource, **kwargs): + def add_from_child(self, resource, **kwargs): """ Add a resource with its all children resources to the current resource. """ @@ -302,4 +315,4 @@ def add_from(self, resource, **kwargs): new_resource = self.add( resource.member_name, resource.collection_name, **kwargs) for child in resource.children: - new_resource.add_from(child, **kwargs) + new_resource.add_from_child(child, **kwargs) diff --git a/nefertari/scripts/es.py b/nefertari/scripts/es.py index 9d561f5..1627ade 100644 --- a/nefertari/scripts/es.py +++ b/nefertari/scripts/es.py @@ -8,6 +8,7 @@ from zope.dottedname.resolve import resolve from nefertari.utils import dictset, split_strip, to_dicts +from nefertari import engine def main(argv=sys.argv, quiet=False): @@ -38,7 +39,8 @@ def __init__(self, argv, log): '--quiet', help='Quiet mode', action='store_true', default=False) parser.add_argument( - '--models', help='List of dotted paths of models to index', + '--models', + help='Comma-separeted list of model names to index', required=True) parser.add_argument( '--params', help='Url-encoded params for each model') @@ -46,9 +48,10 @@ def __init__(self, argv, log): parser.add_argument('--chunk', help='Index chunk size', type=int) parser.add_argument( '--force', - help=('Force reindex of all documents. Only documents that ' - 'are missing from index are indexed by default.'), - type=bool, default=False) + help=('Force reindexing of all documents. By default, only ' + 'documents that are missing from index are indexed.'), + action='store_true', + default=False) self.options = parser.parse_args() if not self.options.config: @@ -68,14 +71,13 @@ def __init__(self, argv, log): self.settings = dictset(registry.settings) - def run(self, quiet=False): + def run(self): from nefertari.elasticsearch import ES ES.setup(self.settings) - models_paths = split_strip(self.options.models) + model_names = split_strip(self.options.models) - for path in models_paths: - model = resolve(path) - model_name = path.split('.')[-1] + for model_name in model_names: + model = engine.get_document_cls(model_name) params = self.options.params or '' params = dict([ @@ -91,6 +93,6 @@ def run(self, quiet=False): if self.options.force: es.index(documents, chunk_size=chunk_size) else: - es.index_missing(documents, chunk_size=chunk_size) + es.index_missing_documents(documents, chunk_size=chunk_size) return 0 diff --git a/nefertari/scripts/post2api.py b/nefertari/scripts/post2api.py index 8893ec3..e6809a5 100644 --- a/nefertari/scripts/post2api.py +++ b/nefertari/scripts/post2api.py @@ -28,14 +28,14 @@ def load(inputfile, destination): def load_singular_objects(inputfile, destination): parent_route, dynamic_part = destination.split('{') parent_route = parent_route.strip('/') - id_field, singlular_field = dynamic_part.split('}') + pk_field, singlular_field = dynamic_part.split('}') singlular_field = singlular_field.strip('/') json_file = open(inputfile) json_data = json.load(json_file) objects_count = len(json_data) - query_string = '?_limit={}'.format(objects_count, id_field) + query_string = '?_limit={}'.format(objects_count) parent_objects = requests.get(parent_route + query_string).json()['data'] for parent in parent_objects: diff --git a/nefertari/tests/test_view.py b/nefertari/tests/test_view.py deleted file mode 100644 index 532240f..0000000 --- a/nefertari/tests/test_view.py +++ /dev/null @@ -1,105 +0,0 @@ -#!/usr/bin/python -# -*- coding: utf-8 -*- - -import unittest -import mock -from webtest import TestApp -from pyramid.config import Configurator - -from nefertari.view import BaseView -from nefertari.json_httpexceptions import * -from nefertari.wrappers import wrap_me - - -class TestBaseView(unittest.TestCase): - - def test_BaseView(self, *a): - - class UsersView(BaseView): - - def __init__(self, context, request): - BaseView.__init__(self, context, request) - - def show(self, id): - return u'John Doe' - - def convert_ids2objects(self, *args, **kwargs): - pass - - request = mock.MagicMock(content_type='') - request.matched_route.pattern = '/users' - view = UsersView(request.context, request) - - self.assertEqual(u'John Doe', view.show(1)) - - self.assertRaises(JHTTPMethodNotAllowed, view.index) - - with self.assertRaises(AttributeError): - view.frobnicate() - - # delete is an allowed action, but it raises since BaseView - # does not implement it. - with self.assertRaises(JHTTPMethodNotAllowed): - view.delete() - - def test_ViewMapper(self): - from nefertari.view import ViewMapper - - bc1 = mock.Mock() - bc3 = mock.Mock() - bc2 = mock.Mock() - ac1 = mock.Mock(return_value=['thing']) - - class MyView: - - def __init__(self, ctx, req): - self._before_calls = {'index': [bc1], 'show': [bc3]} - self._after_calls = {} - - @wrap_me(before=bc2) - def index(self): - return ['thing'] - - request = mock.MagicMock() - resource = mock.MagicMock(actions=['index']) - - wrapper = ViewMapper(**{'attr': 'index'})(MyView) - resp = wrapper(resource, request) - - bc1.assert_called_with(request=request) - - self.assertFalse(bc2.called) - self.assertFalse(bc3.called) - - def test_defalt_wrappers_and_wrap_me(self): - from nefertari import wrappers - - self.maxDiff = None - - def before_call(*a): - return a[2] - - def after_call(*a): - return a[2] - - class MyView(BaseView): - - @wrappers.wrap_me(before=before_call, after=after_call) - def index(self): - return [1, 2, 3] - - def convert_ids2objects(self, *args, **kwargs): - pass - - request = mock.MagicMock(content_type='') - resource = mock.MagicMock(actions=['index']) - view = MyView(resource, request) - - self.assertEqual(len(view._after_calls['index']), 3) - self.assertEqual(len(view._after_calls['show']), 2) - self.assertEqual(len(view._after_calls['delete']), 1) - self.assertEqual(len(view._after_calls['delete_many']), 1) - self.assertEqual(len(view._after_calls['update_many']), 1) - - self.assertEqual(view.index._before_calls, [before_call]) - self.assertEqual(view.index._after_calls, [after_call]) diff --git a/nefertari/tests/test_wrappers.py b/nefertari/tests/test_wrappers.py deleted file mode 100644 index 289087e..0000000 --- a/nefertari/tests/test_wrappers.py +++ /dev/null @@ -1,27 +0,0 @@ -#!/usr/bin/python -# -*- coding: utf-8 -*- -import unittest -import mock -from nefertari import wrappers - - -class WrappersTest(unittest.TestCase): - - def test_obj2dict(self): - result = mock.MagicMock() - result.to_dict.return_value = dict(a=1) - - res = wrappers.obj2dict(request=None)(result=result) - self.assertEqual(dict(a=1), res) - - result.to_dict.return_value = [dict(a=1), dict(b=2)] - self.assertEqual([dict(a=1), dict(b=2)], - wrappers.obj2dict(request=None)(result=result)) - - special = mock.MagicMock() - special.to_dict.return_value = {'special': 'dict'} - result = ['a', 'b', special] - self.assertEqual(['a', 'b', {'special': 'dict'}], - wrappers.obj2dict(request=None)(result=result)) - - self.assertEqual([], wrappers.obj2dict(request=None)(result=[])) diff --git a/nefertari/tweens.py b/nefertari/tweens.py index 03309d2..297fad9 100644 --- a/nefertari/tweens.py +++ b/nefertari/tweens.py @@ -17,7 +17,7 @@ def timing(request): try: return handler(request) finally: - delta = time.time()-start + delta = time.time() - start msg = '%s (%s) request took %s seconds' % ( request.method, request.url, delta) if delta > threshold: @@ -28,56 +28,6 @@ def timing(request): return timing -def post_tunneling(handler, registry): - """Allow other request methods to be tunneled via POST. - - This allows PUT, PATCH and DELETE requests to be tunneled via POST requests. - The method can be specified using a parameter or a header... - - The name of the parameter is '_method'; it can be a query or POST - parameter. The query parameter will be preferred if both the query and - POST parameters are present in the request. - - The name of the header is 'X-HTTP-Method-Override'. If the parameter - described above is passed, this will be ignored. - - The request method will be overwritten before it reaches application - code, such that the application will never be aware of the original - request method. Likewise, the parameter and header will be removed from - the request, and the application will never see them. - - """ - log.info('post_tunneling enabled') - - param_name = '_method' - header_name = 'X-HTTP-Method-Override' - allowed_methods = set(['PUT', 'DELETE', 'PATCH']) - disallowed_message = ( - 'Only these methods may be tunneled over POST: {0}.' - .format(sorted(list(allowed_methods)))) - - def post_tunneling(request): - if request.method == 'POST': - method = '' - - if param_name in request.GET: - method = request.GET[param_name] - elif param_name in request.POST: - method = request.POST[param_name] - elif header_name in request.headers: - method = request.headers[header_name] - - if method in allowed_methods: - request.GET.pop(param_name, None) - request.POST.pop(param_name, None) - request.headers.pop(header_name, None) - request.method = method - - return handler(request) - - return post_tunneling - - def get_tunneling(handler, registry): """ This allows all methods to be tunneled via GET for dev/debuging purposes. @@ -103,9 +53,10 @@ def get_tunneling(request): def cors(handler, registry): log.info('cors_tunneling enabled') + allow_origins_setting = registry.settings.get('cors.allow_origins', '') + allow_origins = [ - each.strip() for each in - registry.settings.get('cors.allow_origins', '').split(',')] + each.strip() for each in allow_origins_setting.split(',')] allow_credentials = registry.settings.get('cors.allow_credentials', None) def cors(request): @@ -121,7 +72,7 @@ def cors(request): return response - if not allow_origins: + if not allow_origins_setting: log.warning('cors.allow_origins is not set') else: log.info('Allow Origins = %s ' % allow_origins) @@ -129,7 +80,7 @@ def cors(request): if allow_credentials is None: log.warning('cors.allow_credentials is not set') - elif asbool(allow_credentials) and allow_origins == '*': + elif asbool(allow_credentials) and allow_origins_setting == '*': log.error('Not allowed Access-Control-Allow-Credentials ' 'to set to TRUE if origin is *') return @@ -145,7 +96,7 @@ def cache_control(handler, registry): def cache_control(request): response = handler(request) - #change only if the header cache-control is missing + # change only if the header cache-control is missing add_header = True for header in response.headerlist: if 'Cache-Control' in header: @@ -180,9 +131,9 @@ def ssl(request): def enable_selfalias(config, id_name): """ - This allows to replace id_name with "self". - i.e. /users/joe/account == /users/self/account if joe is in the session - as authorized user + This allows replacing id_name with "self". + e.g. /users/joe/account == /users/self/account if joe is in the session + as an authorized user """ def context_found_subscriber(event): diff --git a/nefertari/utility_views.py b/nefertari/utility_views.py index aa693a1..0ebab7d 100644 --- a/nefertari/utility_views.py +++ b/nefertari/utility_views.py @@ -1,13 +1,5 @@ from pyramid.view import view_config -import nefertari -from nefertari.json_httpexceptions import * -from nefertari import wrappers -from nefertari.view import BaseView - - -log = logging.getLogger(__name__) - @view_config(name='options_view', request_method='OPTIONS', route_name='options') @@ -32,124 +24,3 @@ def __call__(self): 'origin, x-requested-with, content-type' return request.response - - -class EngineView(BaseView): - def __init__(self, context, request): - super(EngineView, self).__init__(context, request) - self._params.process_int_param('_limit', 20) - - def add_self(**kwargs): - result = kwargs['result'] - request = kwargs['request'] - - try: - for each in result['data']: - each['self'] = "%s?id=%s" % ( - request.current_route_url(), each['id']) - except KeyError: - pass - - return result - - self.add_after_call('show', add_self) - # Wrap in a dict so it acts as "index" - self.add_after_call('show', wrappers.wrap_in_dict(self.request), pos=0) - - def show(self, id): - return self._model_class.get_collection(**self._params) - - def delete(self, id): - objs = self._model_class.get_collection(**self._params) - - if self.needs_confirmation(): - return objs - - count = self._model_class.count(objs) - self._model_class._delete_many(objs) - return JHTTPOk("Deleted %s %s objects" % (count, id)) - - -LOGNAME_MAP = dict( - NOTSET=logging.NOTSET, - DEBUG=logging.DEBUG, - INFO=logging.INFO, - WARNING=logging.WARNING, - ERROR=logging.ERROR, - CRITICAL=logging.CRITICAL, -) - - -class LogLevelView(BaseView): - def __init__(self, *arg, **kw): - super(LogLevelView, self).__init__(*arg, **kw) - - self.name = self.request.matchdict.get('id', 'root') - if self.name == 'root': - self.log = logging.getLogger() - else: - self.log = logging.getLogger(self.name) - - def setlevel(self, level): - log.info("SET logger '%s' to '%s'" % (self.name, level)) - self.log.setLevel(LOGNAME_MAP[level]) - - def show(self, id=None): - return dict( - logger=self.name, - level=logging.getLevelName(self.log.getEffectiveLevel()) - ) - - def update(self, id=None): - level = self._params.keys()[0].upper() - self.setlevel(level) - return JHTTPOk() - - def delete(self, id=None): - self.setlevel('INFO') - return JHTTPOk() - - -class SettingsView(BaseView): - settings = None - __orig = None - - def __init__(self, *arg, **kw): - super(SettingsView, self).__init__(*arg, **kw) - assert(self.settings) - self.__orig = self.settings.copy() - - def index(self): - return dict(self.settings) - - def show(self, id): - return self.settings[id] - - def update(self, id): - self.settings[id] = self._params['value'] - return JHTTPOk() - - def create(self): - key = self._params['key'] - value = self._params['value'] - - self.settings[key] = value - - return JHTTPCreate() - - def delete(self, id): - if 'reset' in self._params: - self.settings[id] = self.request.registry.settings[id] - else: - self.settings.pop(id, None) - - return JHTTPOk() - - def delete_many(self): - if self.needs_confirmation(): - return self.settings.keys() - - for name, val in self.settings.items(): - self.settings[name] = self.__orig[name] - - return JHTTPOk("Reset the settings to original values") diff --git a/nefertari/utils/__init__.py b/nefertari/utils/__init__.py index d8b4075..88f86df 100644 --- a/nefertari/utils/__init__.py +++ b/nefertari/utils/__init__.py @@ -1,7 +1,5 @@ from nefertari.utils.data import * from nefertari.utils.dictset import * from nefertari.utils.utils import * -from nefertari.utils.request import * -_requests = Requests _split = split_strip diff --git a/nefertari/utils/data.py b/nefertari/utils/data.py index b47a399..498720f 100644 --- a/nefertari/utils/data.py +++ b/nefertari/utils/data.py @@ -66,7 +66,7 @@ def to_dicts(collection, key=None, **kw): if key: each_dict = key(each_dict) _dicts.append(each_dict) - except AttributeError, e: + except AttributeError: _dicts.append(each) except TypeError: return collection diff --git a/nefertari/utils/dictset.py b/nefertari/utils/dictset.py index 084228a..ed9f0a0 100644 --- a/nefertari/utils/dictset.py +++ b/nefertari/utils/dictset.py @@ -15,7 +15,8 @@ def subset(self, keys): return dictset([[k, v] for k, v in self.items() if k in only]) if exclude: - return dictset([[k, v] for k, v in self.items() if k not in exclude]) + return dictset([[k, v] for k, v in self.items() + if k not in exclude]) return dictset() @@ -62,7 +63,8 @@ def asfloat(self, name, default=0.0, _set=False): def asdict(self, name, _type=None, _set=False): """ - Turn this 'a:2,b:blabla,c:True,a:'d' to {a:[2, 'd'], b:'blabla', c:True} + Turn this 'a:2,b:blabla,c:True,a:'d' to + {a:[2, 'd'], b:'blabla', c:True} """ @@ -76,13 +78,14 @@ def asdict(self, name, _type=None, _set=False): _dict = {} for item in split_strip(dict_str): key, _, val = item.partition(':') + val = _type(val) if key in _dict: - if type(_dict[key]) is list: + if isinstance(_dict[key], list): _dict[key].append(val) else: _dict[key] = [_dict[key], val] else: - _dict[key] = _type(val) + _dict[key] = val if _set: self[name] = _dict @@ -90,7 +93,7 @@ def asdict(self, name, _type=None, _set=False): return _dict def mget(self, prefix, defaults={}): - if prefix[-1] != '.': + if not prefix.endswith('.'): prefix += '.' _dict = dictset(defaults) @@ -145,10 +148,12 @@ def pop_bool_param(self, name, default=False): def process_datetime_param(self, name): if name in self: try: - self[name] = datetime.strptime(self[name], "%Y-%m-%dT%H:%M:%SZ") + self[name] = datetime.strptime( + self[name], "%Y-%m-%dT%H:%M:%SZ") except ValueError: - raise ValueError("Bad format for '%s' param. Must be ISO 8601, " - "YYYY-MM-DDThh:mm:ssZ" % name) + raise ValueError( + "Bad format for '%s' param. Must be ISO 8601, " + "YYYY-MM-DDThh:mm:ssZ" % name) return self.get(name, None) @@ -161,6 +166,7 @@ def process_float_param(self, name, default=None): elif default is not None: self[name] = default + return self.get(name, None) def process_int_param(self, name, default=None): if name in self: @@ -171,6 +177,7 @@ def process_int_param(self, name, default=None): elif default is not None: self[name] = default + return self.get(name, None) def process_dict_param(self, name, _type=None, pop=False): return self.asdict(name, _type, _set=not pop) diff --git a/nefertari/utils/request.py b/nefertari/utils/request.py deleted file mode 100644 index 0c9a53e..0000000 --- a/nefertari/utils/request.py +++ /dev/null @@ -1,129 +0,0 @@ -import logging -import requests -import urllib -from pyramid.response import Response - -from nefertari.utils.utils import json_dumps -from nefertari.json_httpexceptions import * - -log = logging.getLogger(__name__) - - -def pyramid_resp(resp, **kw): - return Response(status_code=resp.status_code, - headers=resp.headers, - body=resp.text, **kw) - - -class Requests(object): - def __init__(self, base_url=''): - self.base_url = base_url - - def prepare_url(self, path='', params={}): - url = self.base_url - - if path: - url = '%s%s' % (url, (path if path.startswith('/') else '/'+path)) - - if params: - url = '%s%s%s' % (url, '&' if '?' in url else '?', - urllib.urlencode(params)) - - return url - - def get(self, path, params={}, **kw): - url = self.prepare_url(path, params) - log.debug('%s', url) - - try: - resp = requests.get(url, **kw) - if not resp.ok: - raise exception_response(**resp.json()) - return resp.json() - except requests.ConnectionError as e: - raise JHTTPServerError('Server is down? %s' % e) - - def mget(self, path, params={}, page_size=None): - total = params['_limit'] - start = params.get('_start', 0) - params['_limit'] = page_size - page_count = total/page_size - - for ix in range(page_count): - params['_start'] = start + ix*page_size - yield self.get(path, params) - - reminder = total % page_size - if reminder: - params['_start'] = start + page_count*page_size - params['_limit'] = reminder - yield self.get(path, params) - - def post(self, path='', data={}, **kw): - url = self.prepare_url(path) - log.debug('%s, kwargs:%.512s', url, data) - try: - resp = requests.post( - url, data=json_dumps(data), - headers={'content-type': 'application/json'}, - **kw) - if not resp.ok: - raise exception_response(**resp.json()) - - return pyramid_resp(resp) - except requests.ConnectionError as e: - raise JHTTPServerError('Server is down? %s' % e) - - def mpost(self, path='', data={}, bulk_size=None, bulk_key=None): - bulk_data = data[bulk_key] - total = len(bulk_data) - page_count = total/bulk_size - - for ix in range(page_count): - data[bulk_key] = bulk_data[ix*bulk_size:(ix+1)*bulk_size] - yield self.post(path, data) - - reminder = total % bulk_size - if reminder: - st = page_count*bulk_size - data[bulk_key] = bulk_data[st:st+reminder] - yield self.post(path, data) - - def put(self, path='', data={}, **kw): - try: - url = self.prepare_url(path) - log.debug('%s, kwargs:%.512s', url, data) - - resp = requests.put( - url, data=json_dumps(data), - headers={'content-type': 'application/json'}, - **kw) - if not resp.ok: - raise exception_response(**resp.json()) - - return resp.json() - except requests.ConnectionError as e: - raise JHTTPServerError('Server is down? %s' % e) - - def head(self, path='', params={}): - try: - resp = requests.head(self.prepare_url(path, params)) - if not resp.ok: - raise exception_response(**resp.json()) - - except requests.ConnectionError as e: - raise JHTTPServerError('Server is down? %s' % e) - - def delete(self, path='', **kw): - url = self.prepare_url(path) - log.debug(url) - try: - resp = requests.delete( - url, headers={'content-type': 'application/json'}, - **kw) - if not resp.ok: - raise exception_response(**resp.json()) - - return resp.json() - except requests.ConnectionError as e: - raise JHTTPServerError('Server is down? %s' % e) diff --git a/nefertari/utils/utils.py b/nefertari/utils/utils.py index ea3c5bd..9e7d953 100644 --- a/nefertari/utils/utils.py +++ b/nefertari/utils/utils.py @@ -11,8 +11,10 @@ log = logging.getLogger(__name__) -def json_dumps(body): - return json.dumps(body, cls=_JSONEncoder) +def json_dumps(body, encoder=None): + if encoder is None: + encoder = _JSONEncoder + return json.dumps(body, cls=encoder) def split_strip(_str, on=','): @@ -82,14 +84,14 @@ def snake2camel(text): return ''.join([a.title() for a in text.split("_")]) -def maybe_dotted(modul, throw=True): - """ If ``modul`` is a dotted string pointing to the modul, - imports and returns the modul object. +def maybe_dotted(module, throw=True): + """ If ``module`` is a dotted string pointing to the module, + imports and returns the module object. """ try: - return Configurator().maybe_dotted(modul) - except ImportError, e: - err = '%s not found. %s' % (modul, e) + return Configurator().maybe_dotted(module) + except ImportError as e: + err = '%s not found. %s' % (module, e) if throw: raise ImportError(err) else: @@ -110,7 +112,7 @@ def isnumeric(value): try: float(value) return True - except ValueError: + except (ValueError, TypeError): return False diff --git a/nefertari/view.py b/nefertari/view.py index 0eba745..b37355f 100644 --- a/nefertari/view.py +++ b/nefertari/view.py @@ -6,10 +6,12 @@ from pyramid.settings import asbool from pyramid.request import Request -from nefertari.json_httpexceptions import * +from nefertari.json_httpexceptions import ( + JHTTPBadRequest, JHTTPNotFound, JHTTPMethodNotAllowed) from nefertari.utils import dictset from nefertari import wrappers from nefertari.resource import ACTIONS +from nefertari import engine log = logging.getLogger(__name__) @@ -21,7 +23,7 @@ def __init__(self, **kwargs): self.kwargs = kwargs def __call__(self, view): - #i.e index, create etc. + # i.e index, create etc. action_name = self.kwargs['attr'] def view_mapper_wrapper(context, request): @@ -29,13 +31,13 @@ def view_mapper_wrapper(context, request): matchdict.pop('action', None) matchdict.pop('traverse', None) - #instance of BaseView (or child of) + # instance of BaseView (or child of) view_obj = view(context, request) action = getattr(view_obj, action_name) request.action = action_name - # we should not run "after_calls" here, so lets save them in request - # as filters they will be ran in the renderer factory + # we should not run "after_calls" here, so lets save them in + # request as filters they will be ran in the renderer factory request.filters = view_obj._after_calls try: @@ -65,55 +67,105 @@ class BaseView(object): _json_encoder = None _model_class = None - def __init__(self, context, request, _params={}): + @staticmethod + def convert_dotted(params): + """ Convert dotted keys in :params: dictset to a nested dictset. + + E.g. {'settings.foo': 'bar'} -> {'settings': {'foo': 'bar'}} + """ + if not isinstance(params, dictset): + params = dictset(params) + + dotted = defaultdict(dict) + dotted_items = {k: v for k, v in params.items() if '.' in k} + + if dotted_items: + for key, value in dotted_items.items(): + field, subfield = key.split('.') + dotted[field].update({subfield: value}) + params = params.subset(['-' + k for k in dotted_items.keys()]) + params.update(dict(dotted)) + + return params + + def __init__(self, context, request, _query_params={}, _json_params={}): + """ Prepare data to be used across the view and run init methods. + + Each view has these dicts on data: + :_query_params: Params from a query string + :_json_params: Request JSON data. Populated only for + PUT, PATCH, POST methods + :_params: Join of _query_params and _json_params + + For method tunneling, _json_params contains the same data as + _query_params. + """ self.context = context self.request = request - - self._params = dictset(_params or request.params.mixed()) + self._query_params = dictset(_query_params or request.params.mixed()) + self._json_params = dictset(_json_params) ctype = request.content_type if request.method in ['POST', 'PUT', 'PATCH']: if ctype == 'application/json': try: - self._params.update(request.json) + self._json_params.update(request.json) except simplejson.JSONDecodeError: log.error( - "Expecting JSON. Received: '{}'. Request: {} {}".format( + "Expecting JSON. Received: '{}'. " + "Request: {} {}".format( request.body, request.method, request.url)) + self._json_params = BaseView.convert_dotted(self._json_params) + self._query_params = BaseView.convert_dotted(self._query_params) + + self._params = self._query_params.copy() + self._params.update(self._json_params) + # dict of the callables {'action':[callable1, callable2..]} - # as name implies, before calls are executed before the action is called - # after_calls are called after the action returns. + # as name implies, before calls are executed before the action is + # called after_calls are called after the action returns. self._before_calls = defaultdict(list) self._after_calls = defaultdict(list) # no accept headers, use default if '' in request.accept: request.override_renderer = self._default_renderer - elif 'application/json' in request.accept: request.override_renderer = 'nefertari_json' - elif 'text/plain' in request.accept: request.override_renderer = 'string' + self._run_init_actions() + + def _run_init_actions(self): self.setup_default_wrappers() self.convert_ids2objects() + self.set_public_limits() - if not getattr(self.request, 'user', None): + def set_public_limits(self): + """ Set public limits if auth is enabled and user is not + authenticated. + """ + root_resource = getattr(self, 'root_resource', None) + auth_enabled = root_resource is not None and root_resource.auth + if auth_enabled and not getattr(self.request, 'user', None): wrappers.set_public_limits(self) def convert_ids2objects(self): - """ Convert object IDs from `self._params` to objects if needed. + """ Convert object IDs from `self._json_params` to objects if needed. Only IDs tbat belong to relationship field of `self._model_class` are converted. """ - from nefertari.engine import is_relationship_field, relationship_cls - for field in self._params.keys(): - if not is_relationship_field(field, self._model_class): + if not self._model_class: + log.info("%s has no model defined" % self.__class__.__name__) + return + + for field in self._json_params.keys(): + if not engine.is_relationship_field(field, self._model_class): continue - model_cls = relationship_cls(field, self._model_class) + model_cls = engine.get_relationship_cls(field, self._model_class) self.id2obj(field, model_cls) def get_debug(self, package=None): @@ -124,9 +176,18 @@ def get_debug(self, package=None): return asbool(self.request.registry.settings.get(key)) def setup_default_wrappers(self): + root_resource = getattr(self, 'root_resource', None) + auth_enabled = root_resource and root_resource.auth + self._after_calls['index'] = [ wrappers.wrap_in_dict(self.request), wrappers.add_meta(self.request), + ] + if auth_enabled: + self._after_calls['index'] += [ + wrappers.apply_privacy(self.request), + ] + self._after_calls['index'] += [ wrappers.add_etag(self.request), ] @@ -134,6 +195,10 @@ def setup_default_wrappers(self): wrappers.wrap_in_dict(self.request), wrappers.add_meta(self.request), ] + if auth_enabled: + self._after_calls['show'] += [ + wrappers.apply_privacy(self.request), + ] self._after_calls['delete'] = [ wrappers.add_confirmation_url(self.request) @@ -171,53 +236,39 @@ def add_before_or_after_call(self, action, _callable, pos=None, else: callkind[action].insert(pos, _callable) - add_before_call = lambda self, *a, **k: self.add_before_or_after_call(*a, before=True, **k) - add_after_call = lambda self, *a, **k: self.add_before_or_after_call(*a, before=False, **k) + add_before_call = lambda self, *a, **k: self.add_before_or_after_call( + *a, before=True, **k) + add_after_call = lambda self, *a, **k: self.add_before_or_after_call( + *a, before=False, **k) def subrequest(self, url, params={}, method='GET'): req = Request.blank(url, cookies=self.request.cookies, content_type='application/json', method=method) - if req.method == 'GET' and params: + if method == 'GET' and params: req.body = urllib.urlencode(params) - if req.method == 'POST': + if method == 'POST': req.body = json.dumps(params) return self.request.invoke_subrequest(req) def needs_confirmation(self): - return '__confirmation' not in self._params - - def delete_many(self, **kw): - if not self._model_class: - log.error("%s _model_class in invalid: %s" % ( - self.__class__.__name__, self._model_class)) - raise JHTTPBadRequest - - objs = self._model_class.get_collection(**self._params) - - if self.needs_confirmation(): - return objs - - count = self._model_class.count(objs) - self._model_class._delete_many(objs) - return JHTTPOk("Deleted %s %s objects" % ( - count, self._model_class.__name__)) + return '__confirmation' not in self._query_params - def id2obj(self, name, model, id_field=None, setdefault=None): - if name not in self._params: + def id2obj(self, name, model, pk_field=None, setdefault=None): + if name not in self._json_params: return - if id_field is None: - id_field = model.id_field() + if pk_field is None: + pk_field = model.pk_field() def _get_object(id_): - if isinstance(id_, model): + if hasattr(id_, 'pk_field'): return id_ - obj = model.get(**{id_field: id_}) + obj = model.get(**{pk_field: id_}) if setdefault: return obj or setdefault else: @@ -225,11 +276,11 @@ def _get_object(id_): raise JHTTPBadRequest('id2obj: Object %s not found' % id_) return obj - ids = self._params[name] + ids = self._json_params[name] if isinstance(ids, list): - self._params[name] = [_get_object(_id) for _id in ids] + self._json_params[name] = [_get_object(_id) for _id in ids] else: - self._params[name] = _get_object(ids) + self._json_params[name] = _get_object(ids) def key_error_view(context, request): diff --git a/nefertari/wrappers.py b/nefertari/wrappers.py index ad1dfef..ad5423d 100644 --- a/nefertari/wrappers.py +++ b/nefertari/wrappers.py @@ -3,6 +3,8 @@ import logging +from nefertari import engine + log = logging.getLogger(__name__) @@ -69,12 +71,16 @@ def __eq__(self, other): # After calls. class obj2dict(object): + """ Convert object to dictionary. + + Sequence of objects is converted to sequence of dicts. + Conversion is performed by calling object's 'to_dict' method. + """ def __init__(self, request): self.request = request def __call__(self, **kwargs): '''converts objects in `result` into dicts''' - result = kwargs['result'] if isinstance(result, dict): return result @@ -87,7 +93,7 @@ def __call__(self, **kwargs): return result.to_dict(_keys=_fields, request=self.request) elif issequence(result): - #make sure its mutable, i.e list + # make sure its mutable, i.e list result = list(result) for ix, each in enumerate(result): result[ix] = obj2dict(self.request)( @@ -96,12 +102,95 @@ def __call__(self, **kwargs): return result +class apply_privacy(object): + """ Apply privacy rules to a JSON output. + + Passed 'result' kwarg's value may be a dictset or a collection JSON + output which contains objects' data under 'data' key as a sequence of + dictsets. + + Privacy is applied checking model's (got using '_type' key value) fields: + * _public_fields: Fields visible to non-authenticated users. + * _auth_fields: Fields visible to authenticated users. + + Admin can see all the fields. Whether user is admin, is checked by + calling 'is_admin()' method on 'self.request.user'. + + If this wrapper is called without request, no filtering is performed. + Fields visible to all types of users: 'self', '_type'. + """ + def __init__(self, request): + self.request = request + + def _filter_fields(self, data): + if '_type' not in data: + return data + try: + model_cls = engine.get_document_cls(data['_type']) + except ValueError as ex: + log.error(str(ex)) + return data + + public_fields = set(getattr(model_cls, '_public_fields', None) or []) + auth_fields = set(getattr(model_cls, '_auth_fields', None) or []) + fields = set(data.keys()) + + user = getattr(self.request, 'user', None) + if self.request: + # User authenticated + if user: + # User not admin + if not self.is_admin: + fields &= auth_fields + + # User not authenticated + else: + fields &= public_fields + + fields.add('_type') + fields.add('self') + return data.subset(fields) + + def __call__(self, **kwargs): + result = kwargs['result'] + if not isinstance(result, dict): + return result + data = result.get('data', result) + + if data: + self.is_admin = kwargs.get('is_admin') + if self.is_admin is None: + user = getattr(self.request, 'user', None) + self.is_admin = user is not None and type(user).is_admin(user) + if issequence(data) and not isinstance(data, dict): + kwargs = {'is_admin': self.is_admin} + data = [apply_privacy(self.request)(result=d, **kwargs) + for d in data] + else: + data = self._filter_fields(data) + + if 'data' in result: + result['data'] = data + else: + result = data + return result + + class wrap_in_dict(object): + """ Wraps 'result' kwarg value in dict. + + If object passed in 'result' kwarg has metadata in '_nefertari_meta' + attribute, it's metadata is preserved and then applied if object + is converted to a sequence of dicts. + + Conversion of object from 'result' kwargs is performed by calling + `obj2dict` wrapper. + """ def __init__(self, request): self.request = request def __call__(self, **kwargs): - '''if result is a list then wrap it in the dict''' + """ If result is a list then wrap it in the dict. """ result = kwargs['result'] if hasattr(result, '_nefertari_meta'): @@ -121,6 +210,14 @@ def __call__(self, **kwargs): class add_meta(object): + """ Add metadata to results. + + In particular adds: + * 'count': Number of results. Equals to number of objects in + `result['data']` + * 'self': For each object in `result['data']` adds a url which points + to current object + """ def __init__(self, request): self.request = request @@ -132,33 +229,40 @@ def __call__(self, **kwargs): for each in result['data']: try: each.setdefault('self', "%s/%s" % ( - self.request.current_route_url(), + self.request.path_url, urllib.quote(str(each['id'])))) except TypeError: pass - except (TypeError, KeyError): - pass finally: return result class add_confirmation_url(object): + """ Add confirmation url to confirm some action. + + Confirmation url is generated using `self.request.url`, `s__confirmation` + query param and a method name in `_m` param. + """ def __init__(self, request): self.request = request def __call__(self, **kwargs): - from nefertari.engine import BaseDocument result = kwargs['result'] q_or_a = '&' if self.request.params else '?' return dict( method=self.request.method, - count=BaseDocument.count(result), + count=engine.BaseDocument.count(result), confirmation_url=self.request.url+'%s__confirmation&_m=%s' % ( q_or_a, self.request.method)) class add_etag(object): + """ Add ETAG header to response. + + Etag is generated md5-encoding '_version' + 'id' of each object + in a sequence of objects returned. + """ def __init__(self, request): self.request = request @@ -176,7 +280,7 @@ def etag(data): for each in result['data']: etag_src += etag(each) - except (TypeError, KeyError) as e: + except (TypeError, KeyError): pass finally: @@ -195,7 +299,7 @@ def __call__(self, **kwargs): try: result._nefertari_meta['total'] = min( self.total, result._nefertari_meta['total']) - except (AttributeError, TypeError) as e: + except (AttributeError, TypeError): pass return result @@ -205,9 +309,9 @@ def set_public_limits(view): 'public_max_limit', 100)) try: - _limit = int(view._params.get('_limit', 20)) - _page = int(view._params.get('_page', 0)) - _start = int(view._params.get('_start', 0)) + _limit = int(view._query_params.get('_limit', 20)) + _page = int(view._query_params.get('_page', 0)) + _start = int(view._query_params.get('_start', 0)) view.add_after_call('index', set_total(view.request, total=public_max), pos=0) @@ -215,6 +319,6 @@ def set_public_limits(view): from nefertari.json_httpexceptions import JHTTPBadRequest raise JHTTPBadRequest("Bad _limit/_page param") - _start = _start or _page*_limit + _start = _start or _page * _limit if _start + _limit > public_max: - view._params['_limit'] = max((public_max - _start), 0) + view._query_params['_limit'] = max((public_max - _start), 0) diff --git a/requirements.dev b/requirements.dev index 3c26160..901f2a4 100644 --- a/requirements.dev +++ b/requirements.dev @@ -1,5 +1,5 @@ -pytest==2.6.4 -pytest-cov==1.8.1 +pytest +pytest-cov mock webtest Sphinx diff --git a/setup.py b/setup.py index 3a85049..2c312c5 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,9 @@ 'requests', 'simplejson', 'elasticsearch', - 'blinker' + 'blinker', + 'zope.dottedname', + 'cryptacular', ] setup( diff --git a/tests/test_acl.py b/tests/test_acl.py new file mode 100644 index 0000000..81c81ca --- /dev/null +++ b/tests/test_acl.py @@ -0,0 +1,125 @@ +import pytest +from mock import Mock +from pyramid.security import ALL_PERMISSIONS, Allow, Everyone, Authenticated + +from nefertari import acl + + +class TestACLsUnit(object): + + def test_baseacl_init(self): + acl_obj = acl.BaseACL(request='foo') + assert acl_obj.request == 'foo' + assert acl_obj.__acl__ == [(Allow, 'g:admin', ALL_PERMISSIONS)] + assert acl_obj.__context_acl__ == [ + (Allow, 'g:admin', ALL_PERMISSIONS)] + + def test_baseacl_acl_getter(self): + acl_obj = acl.BaseACL(request='foo') + assert acl_obj.acl is acl_obj.__acl__ + assert acl_obj.acl == [(Allow, 'g:admin', ALL_PERMISSIONS)] + + def test_baseacl_acl_setter(self): + acl_obj = acl.BaseACL(request='foo') + assert acl_obj.acl == [(Allow, 'g:admin', ALL_PERMISSIONS)] + ace = (Allow, Everyone, ['index', 'show']) + with pytest.raises(AssertionError): + acl_obj.acl = [ace] + acl_obj.acl = ace + assert acl_obj.acl == [(Allow, 'g:admin', ALL_PERMISSIONS), ace] + + def test_baseacl_context_acl(self): + acl_obj = acl.BaseACL(request='foo') + assert acl_obj.context_acl(None) is acl_obj.__context_acl__ + + def test_baseacl_getitem_no_context_cls(self): + acl_obj = acl.BaseACL(request='foo') + assert acl_obj.__context_class__ is None + with pytest.raises(AssertionError): + acl_obj.__getitem__('foo') + + def test_baseacl_getitem(self): + acl_obj = acl.BaseACL(request='foo') + clx_cls = Mock() + clx_cls.pk_field.return_value = 'storyname' + acl_obj.__context_class__ = clx_cls + obj = acl_obj.__getitem__('foo') + clx_cls.pk_field.assert_called_once_with() + clx_cls.get.assert_called_once_with( + __raise=True, storyname='foo') + assert obj.__acl__ == acl_obj.__context_acl__ + assert obj.__parent__ == acl_obj + assert obj.__name__ == 'foo' + + def test_rootacl(self): + acl_obj = acl.RootACL(request='foo') + assert acl_obj.__acl__ == [(Allow, 'g:admin', ALL_PERMISSIONS)] + assert acl_obj.request == 'foo' + + def test_adminacl(self): + acl_obj = acl.AdminACL(request='foo') + assert isinstance(acl_obj, acl.BaseACL) + assert acl_obj['foo'] == 1 + assert acl_obj['qweoo'] == 1 + + def test_guestacl_acl(self): + acl_obj = acl.GuestACL(request='foo') + assert acl_obj.acl == [ + (Allow, 'g:admin', ALL_PERMISSIONS), + (Allow, Everyone, ['index', 'show']) + ] + + def test_guestacl_context_acl(self): + acl_obj = acl.GuestACL(request='foo') + assert acl_obj.context_acl('asdasd') == [ + (Allow, 'g:admin', ALL_PERMISSIONS), + (Allow, Everyone, ['index', 'show']), + ] + + def test_authenticatedreadacl_acl(self): + acl_obj = acl.AuthenticatedReadACL(request='foo') + assert acl_obj.acl == [ + (Allow, 'g:admin', ALL_PERMISSIONS), + (Allow, Authenticated, 'index') + ] + + def test_authenticatedreadacl_context_acl(self): + acl_obj = acl.AuthenticatedReadACL(request='foo') + assert acl_obj.context_acl('asdasd') == [ + (Allow, 'g:admin', ALL_PERMISSIONS), + (Allow, Authenticated, 'show'), + ] + + +class TestSelfParamMixin(object): + + def test_resolve_self_key_wrong_key(self): + obj = acl.SelfParamMixin() + assert obj.param_value == 'self' + assert obj.resolve_self_key('') == '' + assert obj.resolve_self_key('foo') == 'foo' + + def test_resolve_self_key_user_not_logged_in(self): + obj = acl.SelfParamMixin() + obj.request = Mock(user=None) + assert obj.resolve_self_key('self') == 'self' + + def test_resolve_self_key_no_model_Cls(self): + obj = acl.SelfParamMixin() + obj.__context_class__ = None + obj.request = Mock(user=1) + assert obj.resolve_self_key('self') == 'self' + + def test_resolve_self_key_user_wrong_class(self): + obj = acl.SelfParamMixin() + obj.__context_class__ = dict + obj.request = Mock(user='a') + assert obj.resolve_self_key('self') == 'self' + + def test_resolve_self_key(self): + obj = acl.SelfParamMixin() + obj.__context_class__ = Mock + user = Mock(username='user12') + user.pk_field.return_value = 'username' + obj.request = Mock(user=user) + assert obj.resolve_self_key('self') == 'user12' diff --git a/tests/test_authentication/__init__.py b/tests/test_authentication/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_authentication/fixtures.py b/tests/test_authentication/fixtures.py new file mode 100644 index 0000000..7797f91 --- /dev/null +++ b/tests/test_authentication/fixtures.py @@ -0,0 +1,11 @@ +import pytest + + +@pytest.fixture(scope='module') +def engine_mock(request): + import nefertari + from mock import Mock + + nefertari.engine = Mock() + nefertari.engine.BaseDocument = object + return nefertari.engine diff --git a/tests/test_authentication/test_models.py b/tests/test_authentication/test_models.py new file mode 100644 index 0000000..5c41090 --- /dev/null +++ b/tests/test_authentication/test_models.py @@ -0,0 +1,256 @@ +import pytest +from mock import Mock, patch + +from .fixtures import engine_mock +from nefertari.json_httpexceptions import JHTTPBadRequest + + +class TestModelHelpers(object): + + def test_lower_strip(self, engine_mock): + from nefertari.authentication import models + assert models.lower_strip('Foo ') == 'foo' + assert models.lower_strip(None) == '' + + def test_encrypt_password(self, engine_mock): + from nefertari.authentication import models + encrypted = models.encrypt_password('foo') + assert models.crypt.match(encrypted) + assert encrypted != 'foo' + assert encrypted == models.encrypt_password(encrypted) + + @patch('nefertari.authentication.models.uuid.uuid4') + def test_create_apikey_token(self, mock_uuid, engine_mock): + from nefertari.authentication import models + mock_uuid.return_value = Mock(hex='foo-bar') + assert models.create_apikey_token() == 'foobar' + + +mixin_path = 'nefertari.authentication.models.AuthModelDefaultMixin.' + + +class TestAuthModelDefaultMixin(object): + def test_is_admin(self, engine_mock): + from nefertari.authentication import models + user = Mock(groups=['user']) + assert not models.AuthModelDefaultMixin.is_admin(user) + user = Mock(groups=['user', 'admin']) + assert models.AuthModelDefaultMixin.is_admin(user) + + @patch(mixin_path + 'get_resource') + def test_get_token_credentials(self, mock_res, engine_mock): + from nefertari.authentication import models + user = Mock() + user.api_key.token = 'foo-token' + mock_res.return_value = user + token = models.AuthModelDefaultMixin.get_token_credentials('user1', 1) + assert token == 'foo-token' + mock_res.assert_called_once_with(username='user1') + + @patch(mixin_path + 'get_resource') + def test_get_token_credentials_user_not_found(self, mock_res, engine_mock): + from nefertari.authentication import models + mock_res.return_value = None + token = models.AuthModelDefaultMixin.get_token_credentials('user1', 1) + assert token is None + mock_res.assert_called_once_with(username='user1') + + @patch('nefertari.authentication.models.forget') + @patch(mixin_path + 'get_resource') + def test_get_token_credentials_query_error( + self, mock_res, mock_forg, engine_mock): + from nefertari.authentication import models + mock_res.side_effect = Exception + token = models.AuthModelDefaultMixin.get_token_credentials('user1', 1) + assert token is None + mock_res.assert_called_once_with(username='user1') + mock_forg.assert_called_once_with(1) + + @patch(mixin_path + 'get_resource') + def test_get_groups_by_token(self, mock_res, engine_mock): + from nefertari.authentication import models + user = Mock(groups=['admin', 'user']) + user.api_key.token = 'token' + mock_res.return_value = user + groups = models.AuthModelDefaultMixin.get_groups_by_token( + 'user1', 'token', 1) + assert groups == ['g:admin', 'g:user'] + mock_res.assert_called_once_with(username='user1') + + @patch(mixin_path + 'get_resource') + def test_get_groups_by_token_user_not_found(self, mock_res, engine_mock): + from nefertari.authentication import models + mock_res.return_value = None + groups = models.AuthModelDefaultMixin.get_groups_by_token( + 'user1', 'token', 1) + assert groups is None + mock_res.assert_called_once_with(username='user1') + + @patch(mixin_path + 'get_resource') + def test_get_groups_by_token_wrong_token(self, mock_res, engine_mock): + from nefertari.authentication import models + user = Mock(groups=['admin', 'user']) + user.api_key.token = 'dasdasd' + mock_res.return_value = user + groups = models.AuthModelDefaultMixin.get_groups_by_token( + 'user1', 'token', 1) + assert groups is None + mock_res.assert_called_once_with(username='user1') + + @patch('nefertari.authentication.models.forget') + @patch(mixin_path + 'get_resource') + def test_get_groups_by_token_query_error( + self, mock_res, mock_forg, engine_mock): + from nefertari.authentication import models + mock_res.side_effect = Exception + groups = models.AuthModelDefaultMixin.get_groups_by_token( + 'user1', 'token', 1) + assert groups is None + mock_res.assert_called_once_with(username='user1') + mock_forg.assert_called_once_with(1) + + @patch(mixin_path + 'get_resource') + def test_authenticate_by_password(self, mock_res, engine_mock): + from nefertari.authentication import models + user = Mock(password=models.crypt.encode('foo')) + mock_res.return_value = user + success, usr = models.AuthModelDefaultMixin.authenticate_by_password( + {'login': 'user1', 'password': 'foo'}) + assert success + assert user == usr + mock_res.assert_called_once_with(username='user1') + models.AuthModelDefaultMixin.authenticate_by_password( + {'login': 'user1@example.com', 'password': 'foo'}) + mock_res.assert_called_with(email='user1@example.com') + + @patch(mixin_path + 'get_resource') + def test_authenticate_by_password_not_found(self, mock_res, engine_mock): + from nefertari.authentication import models + mock_res.return_value = None + success, usr = models.AuthModelDefaultMixin.authenticate_by_password( + {'login': 'user1', 'password': 'foo'}) + assert not success + assert usr is None + mock_res.assert_called_once_with(username='user1') + + @patch(mixin_path + 'get_resource') + def test_authenticate_by_password_pasword_not_matching( + self, mock_res, engine_mock): + from nefertari.authentication import models + user = Mock(password=models.crypt.encode('foo')) + mock_res.return_value = user + success, usr = models.AuthModelDefaultMixin.authenticate_by_password( + {'login': 'user1', 'password': 'asdasdasd'}) + assert not success + assert user == usr + mock_res.assert_called_once_with(username='user1') + + @patch(mixin_path + 'get_resource') + def test_authenticate_by_password_exception(self, mock_res, engine_mock): + from nefertari.authentication import models + mock_res.side_effect = Exception + success, usr = models.AuthModelDefaultMixin.authenticate_by_password( + {'login': 'user1', 'password': 'asdasdasd'}) + assert not success + assert usr is None + mock_res.assert_called_once_with(username='user1') + + @patch(mixin_path + 'pk_field') + @patch(mixin_path + 'get_resource') + def test_get_groups_by_userid(self, mock_res, mock_field, engine_mock): + from nefertari.authentication import models + mock_field.return_value = 'idid' + user = Mock(groups=['admin', 'user']) + mock_res.return_value = user + groups = models.AuthModelDefaultMixin.get_groups_by_userid( + 'user1', 1) + assert groups == ['g:admin', 'g:user'] + mock_res.assert_called_once_with(idid='user1') + + @patch(mixin_path + 'pk_field') + @patch(mixin_path + 'get_resource') + def test_get_groups_by_userid_user_not_found( + self, mock_res, mock_field, engine_mock): + from nefertari.authentication import models + mock_field.return_value = 'idid' + mock_res.return_value = None + groups = models.AuthModelDefaultMixin.get_groups_by_userid( + 'user1', 1) + assert groups is None + mock_res.assert_called_once_with(idid='user1') + + @patch('nefertari.authentication.models.forget') + @patch(mixin_path + 'pk_field') + @patch(mixin_path + 'get_resource') + def test_get_groups_by_userid_query_error( + self, mock_res, mock_field, mock_forg, engine_mock): + from nefertari.authentication import models + mock_field.return_value = 'idid' + mock_res.side_effect = Exception + groups = models.AuthModelDefaultMixin.get_groups_by_userid( + 'user1', 1) + assert groups is None + mock_res.assert_called_once_with(idid='user1') + mock_forg.assert_called_once_with(1) + + @patch(mixin_path + 'get_or_create') + def test_create_account(self, mock_get, engine_mock): + from nefertari.authentication import models + models.AuthModelDefaultMixin.create_account( + {'username': 1, 'password': 2, 'email': 3, 'foo': 4}) + mock_get.assert_called_once_with( + email=3, + defaults={'username': 1, 'password': 2, 'email': 3}) + + @patch(mixin_path + 'get_or_create') + def test_create_account_bad_request(self, mock_get, engine_mock): + from nefertari.authentication import models + mock_get.side_effect = JHTTPBadRequest + with pytest.raises(JHTTPBadRequest) as ex: + models.AuthModelDefaultMixin.create_account({'email': 3}) + assert str(ex.value) == 'Failed to create account.' + mock_get.assert_called_once_with(email=3, defaults={'email': 3}) + + @patch('nefertari.authentication.models.authenticated_userid') + @patch(mixin_path + 'pk_field') + @patch(mixin_path + 'get_resource') + def test_get_authuser_by_userid( + self, mock_res, mock_id, mock_auth, engine_mock): + from nefertari.authentication import models + mock_auth.return_value = 123 + mock_id.return_value = 'idid' + models.AuthModelDefaultMixin.get_authuser_by_userid(1) + mock_auth.assert_called_once_with(1) + mock_res.assert_called_once_with(idid=123) + + @patch('nefertari.authentication.models.authenticated_userid') + @patch(mixin_path + 'pk_field') + @patch(mixin_path + 'get_resource') + def test_get_authuser_by_userid_not_authenticated( + self, mock_res, mock_id, mock_auth, engine_mock): + from nefertari.authentication import models + mock_auth.return_value = None + mock_id.return_value = 'idid' + models.AuthModelDefaultMixin.get_authuser_by_userid(1) + mock_auth.assert_called_once_with(1) + assert not mock_res.called + + @patch('nefertari.authentication.models.authenticated_userid') + @patch(mixin_path + 'get_resource') + def test_get_authuser_by_name( + self, mock_res, mock_auth, engine_mock): + from nefertari.authentication import models + mock_auth.return_value = 'user1' + models.AuthModelDefaultMixin.get_authuser_by_name(1) + mock_auth.assert_called_once_with(1) + mock_res.assert_called_once_with(username='user1') + + @patch('nefertari.authentication.models.authenticated_userid') + @patch(mixin_path + 'get_resource') + def test_get_authuser_by_name_not_authenticated( + self, mock_res, mock_auth, engine_mock): + from nefertari.authentication import models + mock_auth.return_value = None + models.AuthModelDefaultMixin.get_authuser_by_name(1) + mock_auth.assert_called_once_with(1) + assert not mock_res.called diff --git a/tests/test_authentication/test_policies.py b/tests/test_authentication/test_policies.py new file mode 100644 index 0000000..dc273c0 --- /dev/null +++ b/tests/test_authentication/test_policies.py @@ -0,0 +1,108 @@ +from mock import Mock, patch + +from nefertari import authentication as auth +from .fixtures import engine_mock + + +@patch('nefertari.authentication.policies.create_apikey_model') +class TestApiKeyAuthenticationPolicy(object): + + def test_init(self, mock_apikey, engine_mock): + user_model = Mock() + policy = auth.policies.ApiKeyAuthenticationPolicy( + user_model=user_model, check='foo', + credentials_callback='bar') + assert not engine_mock.get_document_cls.called + mock_apikey.assert_called_once_with(user_model) + assert policy.check == 'foo' + assert policy.credentials_callback == 'bar' + + def test_init_string_user_model(self, mock_apikey, engine_mock): + policy = auth.policies.ApiKeyAuthenticationPolicy( + user_model='User1', check='foo', + credentials_callback='bar') + engine_mock.get_document_cls.assert_called_once_with('User1') + mock_apikey.assert_called_once_with(engine_mock.get_document_cls()) + assert policy.check == 'foo' + assert policy.credentials_callback == 'bar' + + def test_remember(self, mock_apikey, engine_mock): + policy = auth.policies.ApiKeyAuthenticationPolicy( + user_model='User1', check='foo', + credentials_callback='bar') + policy.credentials_callback = lambda uname, req: 'token' + headers = policy.remember(request=None, username='user1') + assert headers == [('WWW-Authenticate', 'ApiKey user1:token')] + + def test_forget(self, mock_apikey, engine_mock): + policy = auth.policies.ApiKeyAuthenticationPolicy( + user_model='User1', check='foo', + credentials_callback='bar') + policy.realm = 'Foo' + headers = policy.forget(request=None) + assert headers == [('WWW-Authenticate', 'ApiKey realm="Foo"')] + + def test_unauthenticated_userid(self, mock_apikey, engine_mock): + policy = auth.policies.ApiKeyAuthenticationPolicy( + user_model='User1', check='foo', + credentials_callback='bar') + policy._get_credentials = Mock() + policy._get_credentials.return_value = ('user1', 'token') + val = policy.unauthenticated_userid(request=1) + policy._get_credentials.assert_called_once_with(1) + assert val == 'user1' + + def test_callback_no_creds(self, mock_apikey, engine_mock): + policy = auth.policies.ApiKeyAuthenticationPolicy( + user_model='User1', check='foo', + credentials_callback='bar') + policy._get_credentials = Mock(return_value=None) + policy.check = Mock() + policy.callback('user1', 1) + policy._get_credentials.assert_called_once_with(1) + assert not policy.check.called + + def test_callback(self, mock_apikey, engine_mock): + policy = auth.policies.ApiKeyAuthenticationPolicy( + user_model='User1', check='foo', + credentials_callback='bar') + policy._get_credentials = Mock(return_value=('user1', 'token')) + policy.check = Mock() + policy.callback('user1', 1) + policy._get_credentials.assert_called_once_with(1) + policy.check.assert_called_once_with('user1', 'token', 1) + + def test_get_credentials_no_header(self, mock_apikey, engine_mock): + policy = auth.policies.ApiKeyAuthenticationPolicy( + user_model='User1', check='foo', + credentials_callback='bar') + request = Mock(headers={}) + assert policy._get_credentials(request) is None + + def test_get_credentials_wrong_header(self, mock_apikey, engine_mock): + policy = auth.policies.ApiKeyAuthenticationPolicy( + user_model='User1', check='foo', + credentials_callback='bar') + request = Mock(headers={'Authorization': 'foo'}) + assert policy._get_credentials(request) is None + + def test_get_credentials_not_apikey_header(self, mock_apikey, engine_mock): + policy = auth.policies.ApiKeyAuthenticationPolicy( + user_model='User1', check='foo', + credentials_callback='bar') + request = Mock(headers={'Authorization': 'foo bar'}) + assert policy._get_credentials(request) is None + + def test_get_credentials_not_full_token(self, mock_apikey, engine_mock): + policy = auth.policies.ApiKeyAuthenticationPolicy( + user_model='User1', check='foo', + credentials_callback='bar') + request = Mock(headers={'Authorization': 'ApiKey user1'}) + assert policy._get_credentials(request) is None + + def test_get_credentials(self, mock_apikey, engine_mock): + policy = auth.policies.ApiKeyAuthenticationPolicy( + user_model='User1', check='foo', + credentials_callback='bar') + request = Mock(headers={'Authorization': 'ApiKey user1:token'}) + assert policy._get_credentials(request) == ('user1', 'token') diff --git a/tests/test_elasticsearch.py b/tests/test_elasticsearch.py new file mode 100644 index 0000000..e6f678b --- /dev/null +++ b/tests/test_elasticsearch.py @@ -0,0 +1,687 @@ +import logging + +import pytest +from mock import Mock, patch, call +from elasticsearch.exceptions import TransportError + +from nefertari import elasticsearch as es +from nefertari.json_httpexceptions import JHTTPBadRequest, JHTTPNotFound +from nefertari.utils import dictset + + +class TestESHttpConnection(object): + + @patch('nefertari.elasticsearch.log') + def test_perform_request_debug(self, mock_log): + mock_log.level = logging.DEBUG + conn = es.ESHttpConnection() + conn.pool = Mock() + conn.pool.urlopen.return_value = Mock(data='foo', status=200) + conn.perform_request('POST', 'http://localhost:9200') + mock_log.debug.assert_called_once_with( + "('POST', 'http://localhost:9200')") + conn.perform_request('POST', 'http://localhost:9200'*200) + assert mock_log.debug.call_count == 2 + + def test_perform_request_exception(self): + conn = es.ESHttpConnection() + conn.pool = Mock() + conn.pool.urlopen.side_effect = TransportError('N/A', '') + with pytest.raises(JHTTPBadRequest): + conn.perform_request('POST', 'http://localhost:9200') + + @patch('nefertari.elasticsearch.log') + def test_perform_request_no_index(self, mock_log): + mock_log.level = logging.DEBUG + mock_log.debug.side_effect = TransportError(404, '') + conn = es.ESHttpConnection() + with pytest.raises(es.IndexNotFoundException): + conn.perform_request('POST', 'http://localhost:9200') + + +class TestHelperFunctions(object): + @patch('nefertari.elasticsearch.ES') + def test_includeme(self, mock_es): + config = Mock() + config.registry.settings = {'foo': 'bar'} + es.includeme(config) + mock_es.setup.assert_called_once_with({'foo': 'bar'}) + + def test_apply_sort(self): + assert es.apply_sort('+foo,-bar ,zoo') == 'foo:asc,bar:desc,zoo:asc' + + def test_apply_sort_empty(self): + assert es.apply_sort('') == '' + + def test_build_terms(self): + terms = es.build_terms('foo', [1, 2, 3]) + assert terms == 'foo:1 OR foo:2 OR foo:3' + + def test_build_terms_custom_operator(self): + terms = es.build_terms('foo', [1, 2, 3], operator='AND') + assert terms == 'foo:1 AND foo:2 AND foo:3' + + def test_build_qs(self): + qs = es.build_qs(dictset({'foo': 1, 'bar': '_all', 'zoo': 2})) + assert qs == 'foo:1 AND zoo:2' + + def test_build_list(self): + qs = es.build_qs(dictset({'foo': [1, 2], 'zoo': 3})) + assert qs == 'foo:1 OR foo:2 AND zoo:3' + + def test_build_dunder_key(self): + qs = es.build_qs(dictset({'foo': [1, 2], '__zoo__': 3})) + assert qs == 'foo:1 OR foo:2' + + def test_build_raw_terms(self): + qs = es.build_qs(dictset({'foo': [1, 2]}), _raw_terms=' AND qoo:1') + assert qs == 'foo:1 OR foo:2 AND qoo:1' + + def test_build_operator(self): + qs = es.build_qs(dictset({'foo': 1, 'qoo': 2}), operator='OR') + assert qs == 'qoo:2 OR foo:1' + + def test_es_docs(self): + assert issubclass(es._ESDocs, list) + docs = es._ESDocs() + assert docs._total == 0 + assert docs._start == 0 + + @patch('nefertari.elasticsearch.ES') + def test_bulk_body(self, mock_es): + es._bulk_body('foo') + mock_es.api.bulk.assert_called_once_with(body='foo') + + +class TestES(object): + + @patch('nefertari.elasticsearch.ES.settings') + def test_init(self, mock_set): + obj = es.ES(source='Foo') + assert obj.index_name == mock_set.index_name + assert obj.doc_type == 'foo' + assert obj.chunk_size == 100 + obj = es.ES(source='Foo', index_name='a', chunk_size=2) + assert obj.index_name == 'a' + assert obj.doc_type == 'foo' + assert obj.chunk_size == 2 + + def test_src2type(self): + assert es.ES.src2type('FooO') == 'fooo' + + @patch('nefertari.elasticsearch.engine') + @patch('nefertari.elasticsearch.elasticsearch') + def test_setup(self, mock_es, mock_engine): + settings = dictset({ + 'elasticsearch.hosts': '127.0.0.1:8080,127.0.0.2:8090', + 'elasticsearch.sniff': 'true', + }) + es.ES.setup(settings) + mock_es.Elasticsearch.assert_called_once_with( + hosts=[{'host': '127.0.0.1', 'port': '8080'}, + {'host': '127.0.0.2', 'port': '8090'}], + serializer=mock_engine.ESJSONSerializer(), + connection_class=es.ESHttpConnection, + sniff_on_start=True, + sniff_on_connection_fail=True + ) + assert es.ES.api == mock_es.Elasticsearch() + + @patch('nefertari.elasticsearch.engine') + @patch('nefertari.elasticsearch.elasticsearch') + def test_setup_no_settings(self, mock_es, mock_engine): + settings = dictset({}) + with pytest.raises(Exception) as ex: + es.ES.setup(settings) + assert 'Bad or missing settings for elasticsearch' in str(ex.value) + assert not mock_es.Elasticsearch.called + + def test_process_chunks(self): + obj = es.ES('Foo', 'foondex') + operation = Mock() + documents = [1, 2, 3, 4, 5] + obj.process_chunks(documents, operation, chunk_size=100) + operation.assert_called_once_with([1, 2, 3, 4, 5]) + + def test_process_chunks_multiple(self): + obj = es.ES('Foo', 'foondex') + operation = Mock() + documents = [1, 2, 3, 4, 5] + obj.process_chunks(documents, operation, chunk_size=3) + operation.assert_has_calls([call([1, 2, 3]), call([4, 5])]) + + def test_process_chunks_no_docs(self): + obj = es.ES('Foo', 'foondex') + operation = Mock() + obj.process_chunks([], operation, chunk_size=3) + assert not operation.called + + def test_prep_bulk_documents_not_dict(self): + obj = es.ES('Foo', 'foondex') + with pytest.raises(ValueError) as ex: + obj.prep_bulk_documents('', 'q') + assert str(ex.value) == 'Document type must be `dict` not a `str`' + + def test_prep_bulk_documents(self): + obj = es.ES('Foo', 'foondex') + docs = [ + {'_type': 'Story', 'id': 'story1'}, + {'_type': 'Story', 'id': 'story2'}, + ] + prepared = obj.prep_bulk_documents('myaction', docs) + assert len(prepared) == 2 + doc1meta, doc1 = prepared[0] + assert doc1meta.keys() == ['myaction'] + assert doc1meta['myaction'].keys() == [ + 'action', '_type', '_id', '_index'] + assert doc1 == {'_type': 'Story', 'id': 'story1'} + assert doc1meta['myaction']['action'] == 'myaction' + assert doc1meta['myaction']['_index'] == 'foondex' + assert doc1meta['myaction']['_type'] == 'story' + assert doc1meta['myaction']['_id'] == 'story1' + + def test_prep_bulk_documents_no_type(self): + obj = es.ES('Foo', 'foondex') + docs = [ + {'id': 'story2'}, + ] + prepared = obj.prep_bulk_documents('myaction', docs) + assert len(prepared) == 1 + doc2meta, doc2 = prepared[0] + assert doc2meta.keys() == ['myaction'] + assert doc2meta['myaction'].keys() == [ + 'action', '_type', '_id', '_index'] + assert doc2 == {'id': 'story2'} + assert doc2meta['myaction']['action'] == 'myaction' + assert doc2meta['myaction']['_index'] == 'foondex' + assert doc2meta['myaction']['_type'] == 'foo' + assert doc2meta['myaction']['_id'] == 'story2' + + def test_bulk_no_docs(self): + obj = es.ES('Foo', 'foondex') + assert obj._bulk('myaction', []) is None + + @patch('nefertari.elasticsearch.ES.prep_bulk_documents') + @patch('nefertari.elasticsearch.ES.process_chunks') + def test_bulk(self, mock_proc, mock_prep): + obj = es.ES('Foo', 'foondex', chunk_size=1) + docs = [ + [{'delete': {'action': 'delete', '_id': 'story1'}}, + {'_type': 'Story', 'id': 'story1', 'timestamp': 1}], + [{'index': {'action': 'index', '_id': 'story2'}}, + {'_type': 'Story', 'id': 'story2', 'timestamp': 2}], + ] + mock_prep.return_value = docs + obj._bulk('myaction', docs) + mock_prep.assert_called_once_with('myaction', docs) + mock_proc.assert_called_once_with( + documents=[ + {'delete': {'action': 'delete', '_id': 'story1'}}, + {'index': {'action': 'index', '_id': 'story2'}, + '_timestamp': 2}, + {'_type': 'Story', 'id': 'story2', 'timestamp': 2}, + ], + operation=es._bulk_body, + chunk_size=2 + ) + + @patch('nefertari.elasticsearch.ES.prep_bulk_documents') + @patch('nefertari.elasticsearch.ES.process_chunks') + def test_bulk_no_prepared_docs(self, mock_proc, mock_prep): + obj = es.ES('Foo', 'foondex', chunk_size=1) + mock_prep.return_value = [] + obj._bulk('myaction', ['a'], chunk_size=4) + mock_prep.assert_called_once_with('myaction', ['a']) + assert not mock_proc.called + + @patch('nefertari.elasticsearch.ES._bulk') + def test_index(self, mock_bulk): + obj = es.ES('Foo', 'foondex') + obj.index(['a'], chunk_size=4) + mock_bulk.assert_called_once_with('index', ['a'], 4) + + @patch('nefertari.elasticsearch.ES._bulk') + def test_delete(self, mock_bulk): + obj = es.ES('Foo', 'foondex') + obj.delete(ids=[1, 2]) + mock_bulk.assert_called_once_with( + 'delete', [{'id': 1, '_type': 'foo'}, {'id': 2, '_type': 'foo'}]) + + @patch('nefertari.elasticsearch.ES._bulk') + def test_delete_single_obj(self, mock_bulk): + obj = es.ES('Foo', 'foondex') + obj.delete(ids=1) + mock_bulk.assert_called_once_with( + 'delete', [{'id': 1, '_type': 'foo'}]) + + @patch('nefertari.elasticsearch.ES._bulk') + @patch('nefertari.elasticsearch.ES.api.mget') + def test_index_missing_documents(self, mock_mget, mock_bulk): + obj = es.ES('Foo', 'foondex') + documents = [ + {'id': 1, 'name': 'foo'}, + {'id': 2, 'name': 'bar'}, + {'id': 3, 'name': 'baz'}, + ] + mock_mget.return_value = {'docs': [ + {'_id': '1', 'name': 'foo', 'found': False}, + {'_id': '2', 'name': 'bar', 'found': True}, + {'_id': '3', 'name': 'baz'}, + ]} + obj.index_missing_documents(documents, 10) + mock_mget.assert_called_once_with( + index='foondex', + doc_type='foo', + fields=['_id'], + body={'ids': [1, 2, 3]} + ) + mock_bulk.assert_called_once_with( + 'index', [{'id': 1, 'name': 'foo'}, {'id': 3, 'name': 'baz'}], 10) + + @patch('nefertari.elasticsearch.ES._bulk') + @patch('nefertari.elasticsearch.ES.api.mget') + def test_index_missing_documents_no_index(self, mock_mget, mock_bulk): + obj = es.ES('Foo', 'foondex') + documents = [ + {'id': 1, 'name': 'foo'}, + ] + mock_mget.side_effect = es.IndexNotFoundException() + obj.index_missing_documents(documents, 10) + mock_mget.assert_called_once_with( + index='foondex', + doc_type='foo', + fields=['_id'], + body={'ids': [1]} + ) + mock_bulk.assert_called_once_with( + 'index', [{'id': 1, 'name': 'foo'}], 10) + + @patch('nefertari.elasticsearch.ES._bulk') + @patch('nefertari.elasticsearch.ES.api.mget') + def test_index_missing_documents_no_docs_passed(self, mock_mget, mock_bulk): + obj = es.ES('Foo', 'foondex') + assert obj.index_missing_documents([], 10) is None + assert not mock_mget.called + assert not mock_bulk.called + + @patch('nefertari.elasticsearch.ES._bulk') + @patch('nefertari.elasticsearch.ES.api.mget') + def test_index_missing_documents_all_docs_found(self, mock_mget, mock_bulk): + obj = es.ES('Foo', 'foondex') + documents = [ + {'id': 1, 'name': 'foo'}, + ] + mock_mget.return_value = {'docs': [ + {'_id': '1', 'name': 'foo', 'found': True}, + ]} + obj.index_missing_documents(documents, 10) + mock_mget.assert_called_once_with( + index='foondex', + doc_type='foo', + fields=['_id'], + body={'ids': [1]} + ) + assert not mock_bulk.called + + def test_get_by_ids_no_ids(self): + obj = es.ES('Foo', 'foondex') + docs = obj.get_by_ids([]) + assert isinstance(docs, es._ESDocs) + assert len(docs) == 0 + + @patch('nefertari.elasticsearch.ES.api.mget') + def test_get_by_ids(self, mock_mget): + obj = es.ES('Foo', 'foondex') + documents = [{'_id': 1, '_type': 'Story'}] + mock_mget.return_value = { + 'docs': [{ + '_type': 'foo', + '_id': 1, + '_source': {'_id': 1, '_type': 'Story', 'name': 'bar'}, + 'fields': {'name': 'bar'} + }] + } + docs = obj.get_by_ids(documents, _page=0) + mock_mget.assert_called_once_with( + body={'docs': [{'_index': 'foondex', '_type': 'story', '_id': 1}]} + ) + assert len(docs) == 1 + assert docs[0]._id == 1 + assert docs[0].name == 'bar' + assert docs[0]._type == 'Story' + assert docs._nefertari_meta['total'] == 1 + assert docs._nefertari_meta['start'] == 0 + assert docs._nefertari_meta['fields'] == [] + + @patch('nefertari.elasticsearch.ES.api.mget') + def test_get_by_ids_fields(self, mock_mget): + obj = es.ES('Foo', 'foondex') + documents = [{'_id': 1, '_type': 'Story'}] + mock_mget.return_value = { + 'docs': [{ + '_type': 'foo', + '_id': 1, + '_source': {'_id': 1, '_type': 'Story', 'name': 'bar'}, + 'fields': {'name': 'bar'} + }] + } + docs = obj.get_by_ids(documents, _limit=1, _fields=['name']) + mock_mget.assert_called_once_with( + body={'docs': [{'_index': 'foondex', '_type': 'story', '_id': 1}]}, + fields=['name'] + ) + assert len(docs) == 1 + assert not hasattr(docs[0], '_id') + assert not hasattr(docs[0], '_type') + assert docs[0].name == 'bar' + assert docs._nefertari_meta['total'] == 1 + assert docs._nefertari_meta['start'] == 0 + assert docs._nefertari_meta['fields'] == ['name'] + + @patch('nefertari.elasticsearch.ES.api.mget') + def test_get_by_ids_no_index_raise(self, mock_mget): + obj = es.ES('Foo', 'foondex') + documents = [{'_id': 1, '_type': 'Story'}] + mock_mget.side_effect = es.IndexNotFoundException() + with pytest.raises(JHTTPNotFound) as ex: + obj.get_by_ids(documents, __raise_on_empty=True) + assert 'resource not found (Index does not exist)' in str(ex.value) + + @patch('nefertari.elasticsearch.ES.api.mget') + def test_get_by_ids_no_index_not_raise(self, mock_mget): + obj = es.ES('Foo', 'foondex') + documents = [{'_id': 1, '_type': 'Story'}] + mock_mget.side_effect = es.IndexNotFoundException() + try: + docs = obj.get_by_ids(documents, __raise_on_empty=False) + except JHTTPNotFound: + raise Exception('Unexpected error') + assert len(docs) == 0 + + @patch('nefertari.elasticsearch.ES.api.mget') + def test_get_by_ids_not_found_raise(self, mock_mget): + obj = es.ES('Foo', 'foondex') + documents = [{'_id': 1, '_type': 'Story'}] + mock_mget.return_value = {'docs': [{'_type': 'foo', '_id': 1}]} + with pytest.raises(JHTTPNotFound): + obj.get_by_ids(documents, __raise_on_empty=True) + + @patch('nefertari.elasticsearch.ES.api.mget') + def test_get_by_ids_not_found_not_raise(self, mock_mget): + obj = es.ES('Foo', 'foondex') + documents = [{'_id': 1, '_type': 'Story'}] + mock_mget.return_value = {'docs': [{'_type': 'foo', '_id': 1}]} + try: + docs = obj.get_by_ids(documents, __raise_on_empty=False) + except JHTTPNotFound: + raise Exception('Unexpected error') + assert len(docs) == 0 + + def test_build_search_params_no_body(self): + obj = es.ES('Foo', 'foondex') + params = obj.build_search_params( + {'foo': 1, 'zoo': 2, '_raw_terms': ' AND q:5', '_limit': 10} + ) + assert params.keys() == ['body', 'doc_type', 'from_', 'size', 'index'] + assert params['body'] == { + 'query': {'query_string': {'query': 'foo:1 AND zoo:2 AND q:5'}}} + assert params['index'] == 'foondex' + assert params['doc_type'] == 'foo' + + def test_build_search_params_no_body_no_qs(self): + obj = es.ES('Foo', 'foondex') + params = obj.build_search_params({'_limit': 10}) + assert params.keys() == ['body', 'doc_type', 'from_', 'size', 'index'] + assert params['body'] == {'query': {'match_all': {}}} + assert params['index'] == 'foondex' + assert params['doc_type'] == 'foo' + + def test_build_search_params_no_limit(self): + obj = es.ES('Foo', 'foondex') + with pytest.raises(JHTTPBadRequest) as ex: + obj.build_search_params({'foo': 1}) + assert str(ex.value) == 'Missing _limit' + + def test_build_search_params_sort(self): + obj = es.ES('Foo', 'foondex') + params = obj.build_search_params({ + 'foo': 1, '_sort': '+a,-b,c', '_limit': 10}) + assert params.keys() == [ + 'body', 'doc_type', 'index', 'sort', 'from_', 'size'] + assert params['body'] == { + 'query': {'query_string': {'query': 'foo:1'}}} + assert params['index'] == 'foondex' + assert params['doc_type'] == 'foo' + assert params['sort'] == 'a:asc,b:desc,c:asc' + + def test_build_search_params_fields(self): + obj = es.ES('Foo', 'foondex') + params = obj.build_search_params({ + 'foo': 1, '_fields': ['a'], '_limit': 10}) + assert params.keys() == [ + 'body', 'doc_type', 'index', 'fields', 'from_', 'size'] + assert params['body'] == { + 'query': {'query_string': {'query': 'foo:1'}}} + assert params['index'] == 'foondex' + assert params['doc_type'] == 'foo' + assert params['fields'] == ['a'] + + def test_build_search_params_search_fields(self): + obj = es.ES('Foo', 'foondex') + params = obj.build_search_params({ + 'foo': 1, '_search_fields': 'a,b', '_limit': 10}) + assert params.keys() == ['body', 'doc_type', 'from_', 'size', 'index'] + assert params['body'] == {'query': {'query_string': { + 'fields': ['b^1', 'a^2'], + 'query': 'foo:1'}}} + assert params['index'] == 'foondex' + assert params['doc_type'] == 'foo' + + @patch('nefertari.elasticsearch.ES.api.count') + def test_do_count(self, mock_count): + obj = es.ES('Foo', 'foondex') + mock_count.return_value = {'count': 123} + val = obj.do_count( + {'foo': 1, 'size': 2, 'from_': 0, 'sort': 'foo:asc'}) + assert val == 123 + mock_count.assert_called_once_with(foo=1) + + @patch('nefertari.elasticsearch.ES.api.count') + def test_do_count_no_index(self, mock_count): + obj = es.ES('Foo', 'foondex') + mock_count.side_effect = es.IndexNotFoundException() + val = obj.do_count( + {'foo': 1, 'size': 2, 'from_': 0, 'sort': 'foo:asc'}) + assert val == 0 + mock_count.assert_called_once_with(foo=1) + + @patch('nefertari.elasticsearch.ES.build_search_params') + @patch('nefertari.elasticsearch.ES.do_count') + def test_get_collection_count_without_body(self, mock_count, mock_build): + obj = es.ES('Foo', 'foondex') + mock_build.return_value = {'foo': 'bar'} + obj.get_collection(_count=True, foo=1) + mock_count.assert_called_once_with({'foo': 'bar'}) + mock_build.assert_called_once_with({'_count': True, 'foo': 1}) + + @patch('nefertari.elasticsearch.ES.build_search_params') + @patch('nefertari.elasticsearch.ES.do_count') + def test_get_collection_count_with_body(self, mock_count, mock_build): + obj = es.ES('Foo', 'foondex') + obj.get_collection(_count=True, foo=1, body={'foo': 'bar'}) + mock_count.assert_called_once_with( + {'body': {'foo': 'bar'}, '_count': True, 'foo': 1}) + assert not mock_build.called + + @patch('nefertari.elasticsearch.ES.api.search') + def test_get_collection_fields(self, mock_search): + obj = es.ES('Foo', 'foondex') + mock_search.return_value = { + 'hits': { + 'hits': [{'fields': {'foo': 'bar', 'id': 1}, '_score': 2}], + 'total': 4, + }, + 'took': 2.8, + } + docs = obj.get_collection( + fields=['foo'], body={'foo': 'bar'}, from_=0) + mock_search.assert_called_once_with(body={'foo': 'bar'}, from_=0) + assert len(docs) == 1 + assert docs[0].id == 1 + assert docs[0]._score == 2 + assert docs[0].foo == 'bar' + assert docs._nefertari_meta['total'] == 4 + assert docs._nefertari_meta['start'] == 0 + assert docs._nefertari_meta['fields'] == ['foo'] + assert docs._nefertari_meta['took'] == 2.8 + + @patch('nefertari.elasticsearch.ES.api.search') + def test_get_collection_source(self, mock_search): + obj = es.ES('Foo', 'foondex') + mock_search.return_value = { + 'hits': { + 'hits': [{'_source': {'foo': 'bar', 'id': 1}, '_score': 2}], + 'total': 4, + }, + 'took': 2.8, + } + docs = obj.get_collection(body={'foo': 'bar'}, from_=0) + mock_search.assert_called_once_with(body={'foo': 'bar'}, from_=0) + assert len(docs) == 1 + assert docs[0].id == 1 + assert docs[0]._score == 2 + assert docs[0].foo == 'bar' + assert docs._nefertari_meta['total'] == 4 + assert docs._nefertari_meta['start'] == 0 + assert docs._nefertari_meta['fields'] == '' + assert docs._nefertari_meta['took'] == 2.8 + + @patch('nefertari.elasticsearch.ES.api.search') + def test_get_collection_no_index_raise(self, mock_search): + obj = es.ES('Foo', 'foondex') + mock_search.side_effect = es.IndexNotFoundException() + with pytest.raises(JHTTPNotFound) as ex: + obj.get_collection( + body={'foo': 'bar'}, __raise_on_empty=True, + from_=0) + assert 'resource not found (Index does not exist)' in str(ex.value) + + @patch('nefertari.elasticsearch.ES.api.search') + def test_get_collection_no_index_not_raise(self, mock_search): + obj = es.ES('Foo', 'foondex') + mock_search.side_effect = es.IndexNotFoundException() + try: + docs = obj.get_collection( + body={'foo': 'bar'}, __raise_on_empty=False, + from_=0) + except JHTTPNotFound: + raise Exception('Unexpected error') + assert len(docs) == 0 + + @patch('nefertari.elasticsearch.ES.api.search') + def test_get_collection_not_found_raise(self, mock_search): + obj = es.ES('Foo', 'foondex') + mock_search.return_value = { + 'hits': { + 'hits': [], + 'total': 4, + }, + 'took': 2.8, + } + with pytest.raises(JHTTPNotFound): + obj.get_collection( + body={'foo': 'bar'}, __raise_on_empty=True, + from_=0) + + @patch('nefertari.elasticsearch.ES.api.search') + def test_get_collection_not_found_not_raise(self, mock_search): + obj = es.ES('Foo', 'foondex') + mock_search.return_value = { + 'hits': { + 'hits': [], + 'total': 4, + }, + 'took': 2.8, + } + try: + docs = obj.get_collection( + body={'foo': 'bar'}, __raise_on_empty=False, + from_=0) + except JHTTPNotFound: + raise Exception('Unexpected error') + assert len(docs) == 0 + + @patch('nefertari.elasticsearch.ES.api.get_source') + def test_get_resource(self, mock_get): + obj = es.ES('Foo', 'foondex') + mock_get.return_value = {'foo': 'bar', 'id': 4, '_type': 'Story'} + story = obj.get_resource(name='foo') + assert story.id == 4 + assert story.foo == 'bar' + mock_get.assert_called_once_with( + name='foo', index='foondex', doc_type='foo', ignore=404) + + @patch('nefertari.elasticsearch.ES.api.get_source') + def test_get_resource_no_index_raise(self, mock_get): + obj = es.ES('Foo', 'foondex') + mock_get.side_effect = es.IndexNotFoundException() + with pytest.raises(JHTTPNotFound) as ex: + obj.get_resource(name='foo') + assert 'resource not found (Index does not exist)' in str(ex.value) + + @patch('nefertari.elasticsearch.ES.api.get_source') + def test_get_resource_no_index_not_raise(self, mock_get): + obj = es.ES('Foo', 'foondex') + mock_get.side_effect = es.IndexNotFoundException() + try: + obj.get_resource(name='foo', __raise_on_empty=False) + except JHTTPNotFound: + raise Exception('Unexpected error') + + @patch('nefertari.elasticsearch.ES.api.get_source') + def test_get_resource_not_found_raise(self, mock_get): + obj = es.ES('Foo', 'foondex') + mock_get.return_value = {} + with pytest.raises(JHTTPNotFound): + obj.get_resource(name='foo') + + @patch('nefertari.elasticsearch.ES.api.get_source') + def test_get_resource_not_found_not_raise(self, mock_get): + obj = es.ES('Foo', 'foondex') + mock_get.return_value = {} + try: + obj.get_resource(name='foo', __raise_on_empty=False) + except JHTTPNotFound: + raise Exception('Unexpected error') + + @patch('nefertari.elasticsearch.ES.get_resource') + def test_get(self, mock_get): + obj = es.ES('Foo', 'foondex') + obj.get(__raise=True, foo=1) + mock_get.assert_called_once_with(__raise_on_empty=True, foo=1) + + @patch('nefertari.elasticsearch.ES.settings') + @patch('nefertari.elasticsearch.ES.index') + def test_index_refs(self, mock_ind, mock_settings): + class Foo(object): + _index_enabled = True + + docs = [Foo()] + db_obj = Mock() + db_obj.get_reference_documents.return_value = [(Foo, docs)] + mock_settings.index_name = 'foo' + es.ES.index_refs(db_obj) + mock_ind.assert_called_once_with(docs) + + @patch('nefertari.elasticsearch.ES.settings') + @patch('nefertari.elasticsearch.ES.index') + def test_index_refs_index_disabled(self, mock_ind, mock_settings): + class Foo(object): + _index_enabled = False + + docs = [Foo()] + db_obj = Mock() + db_obj.get_reference_documents.return_value = [(Foo, docs)] + mock_settings.index_name = 'foo' + es.ES.index_refs(db_obj) + assert not mock_ind.called diff --git a/tests/test_engine.py b/tests/test_engine.py new file mode 100644 index 0000000..b35de78 --- /dev/null +++ b/tests/test_engine.py @@ -0,0 +1,25 @@ +from mock import Mock, patch + + +class TestEngine(object): + @patch('nefertari.engine.resolve') + def test_includeme(self, mock_resolve): + module = Mock() + config = Mock() + config.registry.settings = {'nefertari.engine': 'foo'} + module.log = 1 + module.__testvar__ = 3 + module.another_var = 4 + mock_resolve.return_value = module + from nefertari import engine + assert not hasattr(engine, 'log') + assert not hasattr(engine, '__testvar__') + assert not hasattr(engine, 'another_var') + + engine.includeme(config) + + config.include.assert_called_once_with('foo') + mock_resolve.assert_called_once_with('foo') + assert not hasattr(engine, 'log') + assert not hasattr(engine, '__testvar__') + assert hasattr(engine, 'another_var') diff --git a/tests/test_json_httpexceptions.py b/tests/test_json_httpexceptions.py new file mode 100644 index 0000000..e7cd9ae --- /dev/null +++ b/tests/test_json_httpexceptions.py @@ -0,0 +1,159 @@ +import json + +import pytest +from mock import Mock, patch + +from nefertari import json_httpexceptions as jsonex +from nefertari.renderers import _JSONEncoder + + +class TestJSONHTTPExceptionsModule(object): + + def test_includeme(self): + config = Mock() + jsonex.includeme(config) + config.add_view.assert_called_once_with( + view=jsonex.httperrors, + context=jsonex.http_exc.HTTPError) + + @patch.object(jsonex, 'traceback') + def test_add_stack(self, mock_trace): + mock_trace.format_stack.return_value = ['foo', 'bar'] + assert jsonex.add_stack() == 'foobar' + + def test_create_json_response(self): + request = Mock( + url='http://example.com', + client_addr='127.0.0.1', + remote_addr='127.0.0.2') + obj = Mock( + status_int=401, + location='http://example.com/api') + obj2 = jsonex.create_json_response( + obj, request, encoder=_JSONEncoder, + status_code=402, explanation='success', + message='foo', title='bar') + assert obj2.content_type == 'application/json' + assert isinstance(obj2.body, basestring) + body = json.loads(obj2.body) + assert body.keys() == [ + 'remote_addr', 'status_code', 'explanation', 'title', + 'message', 'id', 'timestamp', 'request_url', 'client_addr' + ] + assert body['remote_addr'] == '127.0.0.2' + assert body['client_addr'] == '127.0.0.1' + assert body['status_code'] == 402 + assert body['explanation'] == 'success' + assert body['title'] == 'bar' + assert body['message'] == 'foo' + assert body['id'] == 'api' + assert body['request_url'] == 'http://example.com' + + @patch.object(jsonex, 'add_stack') + def test_create_json_response_obj_properties(self, mock_stack): + mock_stack.return_value = 'foo' + obj = Mock( + status_int=401, + location='http://example.com/api', + status_code=402, explanation='success', + message='foo', title='bar') + obj2 = jsonex.create_json_response( + obj, None, encoder=_JSONEncoder) + body = json.loads(obj2.body) + assert body['status_code'] == 402 + assert body['explanation'] == 'success' + assert body['title'] == 'bar' + assert body['message'] == 'foo' + assert body['id'] == 'api' + + @patch.object(jsonex, 'add_stack') + def test_create_json_response_stack_calls(self, mock_stack): + mock_stack.return_value = 'foo' + obj = Mock(status_int=401, location='http://example.com/api') + jsonex.create_json_response(obj, None, encoder=_JSONEncoder) + assert mock_stack.call_count == 0 + + obj = Mock(status_int=500, location='http://example.com/api') + jsonex.create_json_response(obj, None, encoder=_JSONEncoder) + mock_stack.assert_called_with() + assert mock_stack.call_count == 1 + + obj = Mock(status_int=401, location='http://example.com/api') + jsonex.create_json_response( + obj, None, encoder=_JSONEncoder, show_stack=True) + mock_stack.assert_called_with() + assert mock_stack.call_count == 2 + + obj = Mock(status_int=401, location='http://example.com/api') + jsonex.create_json_response( + obj, None, encoder=_JSONEncoder, log_it=True) + mock_stack.assert_called_with() + assert mock_stack.call_count == 3 + + def test_exception_response(self): + jsonex.STATUS_MAP[12345] = lambda x: x + 3 + assert jsonex.exception_response(12345, x=1) == 4 + with pytest.raises(KeyError): + jsonex.exception_response(3123123123123123) + jsonex.STATUS_MAP.pop(12345, None) + + def test_status_map(self): + assert list(sorted(jsonex.STATUS_MAP.keys())) == [ + 200, 201, 202, 203, 204, 205, 206, + 300, 301, 302, 303, 304, 305, 307, + 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, + 411, 412, 413, 414, 415, 416, 417, 422, 423, 424, + 500, 501, 502, 503, 504, 505, 507 + ] + for code_exc in jsonex.STATUS_MAP.values(): + assert hasattr(jsonex, code_exc.__name__) + + @patch.object(jsonex, 'create_json_response') + def test_httperrors(self, mock_create): + jsonex.httperrors({'foo': 'bar'}, 1) + mock_create.assert_called_once_with({'foo': 'bar'}, request=1) + + @patch.object(jsonex, 'create_json_response') + def test_jhttpcreated(self, mock_create): + resp = jsonex.JHTTPCreated( + resource={'foo': 'bar'}, + location='http://example.com/1', + encoder=1) + mock_create.assert_called_once_with( + resp, data={'foo': 'bar', 'self': 'http://example.com/1'}, + encoder=1) + + @patch.object(jsonex, 'apply_privacy') + @patch.object(jsonex, 'create_json_response') + def test_jhttpcreated_privacy_applied(self, mock_create, mock_priv): + wrapper = Mock() + mock_priv.return_value = wrapper + wrapper.return_value = {'foo': 'bar', 'self': 'http://example.com/1'} + request = Mock() + request.registry._root_resources = {'foo': Mock(auth=True)} + resp = jsonex.JHTTPCreated( + resource={'foo': 'bar', 'zoo': 1}, + location='http://example.com/1', + encoder=1, + request=request) + mock_create.assert_called_once_with( + resp, data={'foo': 'bar', 'self': 'http://example.com/1'}, + encoder=1) + mock_priv.assert_called_once_with(request=request) + wrapper.assert_called_once_with( + result={'self': 'http://example.com/1', 'foo': 'bar', 'zoo': 1}) + + @patch.object(jsonex, 'apply_privacy') + @patch.object(jsonex, 'create_json_response') + def test_jhttpcreated_auth_disabled(self, mock_create, mock_priv): + request = Mock() + request.registry._root_resources = {'foo': Mock(auth=False)} + resp = jsonex.JHTTPCreated( + resource={'foo': 'bar', 'zoo': 1}, + location='http://example.com/1', + encoder=1, + request=request) + mock_create.assert_called_once_with( + resp, data={'foo': 'bar', 'zoo': 1, 'self': 'http://example.com/1'}, + encoder=1) + assert not mock_priv.called diff --git a/nefertari/tests/test_pyramid_integration.py b/tests/test_pyramid_integration.py similarity index 100% rename from nefertari/tests/test_pyramid_integration.py rename to tests/test_pyramid_integration.py diff --git a/nefertari/tests/test_renderers.py b/tests/test_renderers.py similarity index 55% rename from nefertari/tests/test_renderers.py rename to tests/test_renderers.py index 9244d3a..43d3eea 100644 --- a/nefertari/tests/test_renderers.py +++ b/tests/test_renderers.py @@ -84,3 +84,48 @@ def test_JsonRendererFactory(self): {'request': request, 'view': view})) self.assertDictContainsSubset(self._get_dummy_expected(), result) self.assertEqual('application/json', request.response.content_type) + + @mock.patch('nefertari.renderers.wrappers') + def test_JsonRendererFactory_run_after_calls(self, mock_wrap): + from nefertari.renderers import JsonRendererFactory + factory = JsonRendererFactory({ + 'name': 'json', + 'package': None, + 'registry': None + }) + request = mock.Mock(action='create') + factory.run_after_calls(1, {'request': request}) + assert not mock_wrap.wrap_in_dict.called + + request = mock.Mock(action='show') + factory.run_after_calls(1, {'request': request}) + mock_wrap.wrap_in_dict.assert_called_once_with(request) + mock_wrap.wrap_in_dict().assert_called_once_with(result=1) + + def test_NefertariJsonRendererFactory_run_after_calls(self): + from nefertari.renderers import NefertariJsonRendererFactory + factory = NefertariJsonRendererFactory(None) + filters = { + 'super_action': [lambda request, result: result + ' processed'], + } + request = mock.Mock(action='super_action', filters=filters) + processed = factory.run_after_calls('foo', {'request': request}) + assert processed == 'foo processed' + + def test_NefertariJsonRendererFactory_run_after_calls_no_filters(self): + from nefertari.renderers import NefertariJsonRendererFactory + factory = NefertariJsonRendererFactory(None) + request = mock.Mock(action='action', filters={}) + processed = factory.run_after_calls('foo', {'request': request}) + assert processed == 'foo' + + def test_NefertariJsonRendererFactory_run_after_calls_unknown_action(self): + from nefertari.renderers import NefertariJsonRendererFactory + factory = NefertariJsonRendererFactory(None) + filter = { + 'super_action': [lambda request, result: result + ' processed'], + } + request = mock.Mock(action='simple_action', filters=filter) + request = mock.Mock(action='action', filters={}) + processed = factory.run_after_calls('foo', {'request': request}) + assert processed == 'foo' diff --git a/nefertari/tests/test_resource.py b/tests/test_resource.py similarity index 75% rename from nefertari/tests/test_resource.py rename to tests/test_resource.py index 802a1bb..0923044 100644 --- a/nefertari/tests/test_resource.py +++ b/tests/test_resource.py @@ -71,6 +71,11 @@ def __getattr__(self, attr): class TestResourceGeneration(Test): + def test_get_resource_map(self): + from nefertari.resource import get_resource_map + request = mock.Mock() + assert get_resource_map(request) == request.registry._resources_map + def test_basic_resources(self): from nefertari.resource import add_resource_routes add_resource_routes(self.config, DummyCrudView, 'message', 'messages') @@ -138,6 +143,22 @@ def test_resources_with_name_prefix(self): route_path('special_message', testing.DummyRequest(), id=1) ) + def test_resources_with_name_prefix_from_config(self): + from nefertari.resource import add_resource_routes + self.config.route_prefix = 'api' + add_resource_routes( + self.config, + DummyCrudView, + 'message', + 'messages', + name_prefix='foo_' + ) + + self.assertEqual( + '/api/messages/1', + route_path('api_foo_message', testing.DummyRequest(), id=1) + ) + class TestResourceRecognition(Test): def setUp(self): @@ -154,8 +175,8 @@ def setUp(self): self.app = TestApp(self.config.make_wsgi_app()) self.collection_path = '/messages' self.collection_name = 'messages' - self.member_path = '/messages/{id}' - self.member_name = 'message' + self.member_path = '/messages/{id}' + self.member_name = 'message' def test_get_collection(self): self.assertEqual(self.app.get('/messages').body, 'index') @@ -206,8 +227,8 @@ def test_delete_member(self): class TestResource(Test): - def test_default_view(self, *a): - from nefertari.resource import Resource, default_view + def test_get_default_view_path(self, *a): + from nefertari.resource import Resource, get_default_view_path m = Resource( self.config, @@ -216,15 +237,38 @@ def test_default_view(self, *a): ) self.assertEqual( - "nefertari.tests.views.group_members:GroupMembersView", - default_view(m) + "test_resource.views.group_members:GroupMembersView", + get_default_view_path(m) ) - #singular + # singular m = Resource(self.config, member_name='group_member') self.assertEqual( - "nefertari.tests.views.group_member:GroupMemberView", - default_view(m) + "test_resource.views.group_member:GroupMemberView", + get_default_view_path(m) + ) + + def test_get_default_view_path_resource_prefix(self, *a): + from nefertari.resource import Resource, get_default_view_path + + m = Resource( + self.config, + member_name='group_member', + collection_name='group_members' + ) + m.prefix = 'foo' + + self.assertEqual( + "test_resource.views.foo_group_members:FooGroupMembersView", + get_default_view_path(m) + ) + + # singular + m = Resource(self.config, member_name='group_member') + m.prefix = 'foo' + self.assertEqual( + "test_resource.views.foo_group_member:FooGroupMemberView", + get_default_view_path(m) ) def test_singular_resource(self, *a): @@ -255,35 +299,17 @@ def test_singular_resource(self, *a): grandpa_id=1, id=2) ) - self.assertEqual( - app.put('/grandpas').body, - app.post('/grandpas', params=dict(_method='PUT')).body - ) + self.assertEqual(app.put('/grandpas').body, '"update_many"') - self.assertEqual( - app.delete('/grandpas/1').body, - app.post('/grandpas/1', params=dict(_method='DELETE')).body - ) + self.assertEqual(app.delete('/grandpas/1').body, '"delete"') - self.assertEqual( - app.put('/thing').body, - app.post('/thing', params=dict(_method='PUT')).body - ) + self.assertEqual(app.put('/thing').body, '"update"') - self.assertEqual( - app.delete('/thing').body, - app.post('/thing', params=dict(_method='DELETE')).body - ) + self.assertEqual(app.delete('/thing').body, '"delete"') - self.assertEqual( - app.put('/grandpas/1/wife').body, - app.post('/grandpas/1/wife', params=dict(_method='PUT')).body - ) + self.assertEqual(app.put('/grandpas/1/wife').body, '"update"') - self.assertEqual( - app.delete('/grandpas/1/wife').body, - app.post('/grandpas/1/wife', params=dict(_method='DELETE')).body - ) + self.assertEqual(app.delete('/grandpas/1/wife').body, '"delete"') self.assertEqual('"show"', app.get('/grandpas/1').body) self.assertEqual('"show"', app.get('/grandpas/1/wife').body) @@ -304,48 +330,41 @@ def test_renderer_override(self, *args): config.begin() app = TestApp(config.make_wsgi_app()) - #no headers, user renderer==string.returns string + # no headers, user renderer==string.returns string self.assertEqual('"index"', app.get('/things').body) - #header is sting, renderer is string. returns string - self.assertEqual( - 'index', app.get('/things', - headers={'ACCEPT': 'text/plain'}).body) + # header is sting, renderer is string. returns string + self.assertEqual('index', app.get('/things', + headers={'ACCEPT': 'text/plain'}).body) - #header is json, renderer is string. returns json - self.assertEqual( - '"index"', app.get('/things', - headers={'ACCEPT': 'application/json'}).body) + # header is json, renderer is string. returns json + self.assertEqual('"index"', app.get('/things', + headers={'ACCEPT': 'application/json'}).body) - #no header. returns json + # no header. returns json self.assertEqual('"index"', app.get('/2things').body) - #header==json, renderer==json, returns json - self.assertEqual( - '"index"', app.get('/2things', - headers={'ACCEPT': 'application/json'}).body) + # header==json, renderer==json, returns json + self.assertEqual('"index"', app.get('/2things', + headers={'ACCEPT': 'application/json'}).body) - #header==text, renderer==json, returns string - self.assertEqual( - "index", app.get('/2things', - headers={'ACCEPT': 'text/plain'}).body) + # header==text, renderer==json, returns string + self.assertEqual("index", app.get('/2things', + headers={'ACCEPT': 'text/plain'}).body) # no header, no renderer. uses default_renderer, returns # View._default_renderer==nefertari_json self.assertEqual('"index"', app.get('/3things').body) - self.assertEqual( - '"index"', app.get('/3things', - headers={'ACCEPT': 'application/json'}).body) + self.assertEqual('"index"', app.get('/3things', + headers={'ACCEPT': 'application/json'}).body) - self.assertEqual( - 'index', app.get('/3things', - headers={'ACCEPT': 'text/plain'}).body) + self.assertEqual('index', app.get('/3things', + headers={'ACCEPT': 'text/plain'}).body) - #bad accept.defaults to json - self.assertEqual( - '"index"', app.get('/3things', - headers={'ACCEPT': 'text/blablabla'}).body) + # bad accept.defaults to json + self.assertEqual('"index"', app.get('/3things', + headers={'ACCEPT': 'text/blablabla'}).body) def test_nonBaseView_default_renderer(self, *a): config = _create_config() @@ -364,13 +383,54 @@ def test_nested_resources(self, *a): aa = root.add('a', 'as', view=get_test_view_class('A')) bb = aa.add('b', 'bs', view=get_test_view_class('B')) cc = bb.add('c', 'cs', view=get_test_view_class('C')) - dd = cc.add('d', 'ds', view=get_test_view_class('D')) + cc.add('d', 'ds', view=get_test_view_class('D')) config.begin() app = TestApp(config.make_wsgi_app()) app.get('/as/1/bs/2/cs/3/ds/4') + def test_add_resource_prefix(self, *a): + config = _create_config() + root = config.get_root_resource() + resource = root.add( + 'message', 'messages', + view=get_test_view_class('A'), + prefix='api') + assert resource.uid == 'api:message' + + config.begin() + + self.assertEqual( + '/api/messages', + route_path('api:messages', testing.DummyRequest()) + ) + + def test_add_resource_view_args(self, *a): + config = _create_config() + root = config.get_root_resource() + view = get_test_view_class('A') + assert not hasattr(view, 'foo') + root.add('message', 'messages', view=view, + view_args={'foo': 'bar'}) + assert view.foo == 'bar' + + def test_nested_resource_id_name(self, *a): + config = _create_config() + root = config.get_root_resource() + + aa = root.add( + 'a', 'as', view=get_test_view_class('A'), + id_name='super_id') + aa.add('b', 'bs', view=get_test_view_class('B')) + + config.begin() + + self.assertEqual( + '/as/1/bs', + route_path('a:bs', testing.DummyRequest(), super_id=1) + ) + # @mock.patch('nefertari.resource.add_tunneling') class TestMockedResource(Test): @@ -479,7 +539,7 @@ def test_add_resource_routes_with_parent_param(self, *arg): m_add_resource_routes = arg[0] m = Resource(self.config) - g = m.add('grandpa', 'grandpas', view=View) + m.add('grandpa', 'grandpas', view=View) m.add('parent', 'parents', parent='grandpa', view=View) m_add_resource_routes.assert_called_with( @@ -531,13 +591,13 @@ def test_add_resource_routes_from(self, *args): gm = root.add('grandma', 'grandmas', view=View) pa = gm.add('parent', 'parents', view=View) boy = pa.add('boy', 'boys', view=View) - grchild = boy.add('child', 'children', view=View) + boy.add('child', 'children', view=View) girl = pa.add('girl', 'girls', view=View) self.assertEqual(len(root.resource_map), 5) gp = root.add('grandpa', 'grandpas', view=View) - gp.add_from(pa, view=View) + gp.add_from_child(pa, view=View) self.assertEqual( pa.children[0], diff --git a/tests/test_tweens.py b/tests/test_tweens.py new file mode 100644 index 0000000..d7ba8cd --- /dev/null +++ b/tests/test_tweens.py @@ -0,0 +1,252 @@ +from mock import Mock, patch + +from nefertari import tweens + + +def mock_timer(): + mock_timer.time = 0 + + def time_func(): + mock_timer.time += 1 + return mock_timer.time + return time_func + + +class DummyConfigurator(object): + def __init__(self): + self.subscribed = [] + + def add_subscriber(self, wrapped, ifaces): + self.subscribed.append((wrapped, ifaces)) + + +class TestTweens(object): + + @patch('nefertari.tweens.time') + @patch('nefertari.tweens.log') + def test_request_timing(self, mock_log, mock_time): + mock_time.time = mock_timer() + request = Mock(method='GET', url='http://example.com') + registry = Mock() + registry.settings = {'request_timing.slow_request_threshold': 1000} + handler = lambda request: request + timing = tweens.request_timing(handler, registry) + timing(request) + mock_log.debug.assert_called_once_with( + 'GET (http://example.com) request took 1 seconds') + assert not mock_log.warning.called + + @patch('nefertari.tweens.time') + @patch('nefertari.tweens.log') + def test_request_timing_slow_request(self, mock_log, mock_time): + mock_time.time = mock_timer() + request = Mock(method='GET', url='http://example.com') + registry = Mock() + registry.settings = {'request_timing.slow_request_threshold': 0} + handler = lambda request: request + timing = tweens.request_timing(handler, registry) + timing(request) + mock_log.warning.assert_called_once_with( + 'GET (http://example.com) request took 1 seconds') + assert not mock_log.debug.called + + def test_get_tunneling(self): + class GET(dict): + def mixed(self): + return self + + request = Mock(GET=GET({'_m': 'POST', 'foo': 'bar'}), method='GET') + get_tunneling = tweens.get_tunneling(lambda x: x, None) + get_tunneling(request) + assert request.GET == {"foo": "bar"} + assert request.method == 'POST' + assert request.content_type == 'application/json' + assert request.body == '{"foo": "bar"}' + + def test_get_tunneling_not_allowed_method(self): + class GET(dict): + def mixed(self): + return self + + request = Mock( + GET=GET({'_m': 'DELETE', 'foo': 'bar'}), method='GET', + body=None, content_type=None) + get_tunneling = tweens.get_tunneling(lambda x: x, None) + get_tunneling(request) + assert request.GET == {"foo": "bar"} + assert request.method == 'DELETE' + assert request.content_type is None + assert request.body is None + + def test_cors_no_origins_no_creds(self): + registry = Mock(settings={ + 'cors.allow_origins': '', + 'cors.allow_credentials': None, + }) + handler = lambda x: Mock(headerlist=[]) + request = Mock( + headers={'Origin': '127.0.0.1:8080'}, + host_url='127.0.0.1:8080') + response = tweens.cors(handler, registry)(request) + assert response.headerlist == [] + + def test_cors_disallow_creds(self): + registry = Mock(settings={ + 'cors.allow_origins': '', + 'cors.allow_credentials': False, + }) + handler = lambda x: Mock(headerlist=[]) + request = Mock( + headers={'Origin': '127.0.0.1:8080'}, + host_url='127.0.0.1:8080') + response = tweens.cors(handler, registry)(request) + assert response.headerlist == [ + ('Access-Control-Allow-Credentials', False)] + + def test_cors_allow_creds_and_origin(self): + registry = Mock(settings={ + 'cors.allow_origins': '127.0.0.1:8080,127.0.0.1:8090', + 'cors.allow_credentials': True, + }) + handler = lambda x: Mock(headerlist=[]) + request = Mock( + headers={'Origin': '127.0.0.1:8080'}, + host_url='127.0.0.1:8080') + response = tweens.cors(handler, registry)(request) + assert response.headerlist == [ + ('Access-Control-Allow-Origin', '127.0.0.1:8080'), + ('Access-Control-Allow-Credentials', True)] + + def test_cors_wrong_origin(self): + registry = Mock(settings={ + 'cors.allow_origins': '127.0.0.1:8080,127.0.0.1:8090', + 'cors.allow_credentials': None, + }) + handler = lambda x: Mock(headerlist=[]) + request = Mock( + headers={'Origin': '127.0.0.1:8000'}, + host_url='127.0.0.1:8000') + response = tweens.cors(handler, registry)(request) + assert response.headerlist == [] + + def test_cors_source_or_host_url(self): + registry = Mock(settings={ + 'cors.allow_origins': '127.0.0.1:8080,127.0.0.1:8090', + 'cors.allow_credentials': None, + }) + handler = lambda x: Mock(headerlist=[]) + request = Mock( + headers={'Origin': '127.0.0.1:8080'}, + host_url='') + response = tweens.cors(handler, registry)(request) + assert response.headerlist == [ + ('Access-Control-Allow-Origin', '127.0.0.1:8080')] + + request = Mock( + headers={}, + host_url='127.0.0.1:8080') + response = tweens.cors(handler, registry)(request) + assert response.headerlist == [ + ('Access-Control-Allow-Origin', '127.0.0.1:8080')] + + def test_cors_allow_origins_star(self): + registry = Mock(settings={ + 'cors.allow_origins': '*', + 'cors.allow_credentials': True, + }) + handler = lambda x: Mock(headerlist=[]) + cors = tweens.cors(handler, registry) + assert cors is None + + def test_cache_control_header_not_set(self): + handler = lambda x: Mock(headerlist=[('Cache-Control', '')]) + response = tweens.cache_control(handler, None)(None) + assert not response.cache_expires.called + + def test_cache_control_header_set(self): + handler = lambda x: Mock(headerlist=[]) + response = tweens.cache_control(handler, None)(None) + response.cache_expires.assert_called_once_with(0) + + def test_ssl_url_scheme(self): + request = Mock( + scheme=None, + environ={'HTTP_X_URL_SCHEME': 'Foo'} + ) + tweens.ssl(lambda x: x, None)(request) + assert request.environ['wsgi.url_scheme'] == 'foo' + assert request.scheme == 'foo' + + def test_ssl_forwarded_proto(self): + request = Mock( + scheme=None, + environ={'HTTP_X_FORWARDED_PROTO': 'Foo'} + ) + tweens.ssl(lambda x: x, None)(request) + assert request.environ['wsgi.url_scheme'] == 'foo' + assert request.scheme == 'foo' + + def test_ssl_no_scheme(self): + request = Mock(scheme=None, environ={}) + tweens.ssl(lambda x: x, None)(request) + assert request.environ == {} + assert request.scheme is None + + def test_enable_selfalias(self): + from pyramid.events import ContextFound + config = DummyConfigurator() + assert config.subscribed == [] + tweens.enable_selfalias(config, 'foo') + assert len(config.subscribed) == 1 + assert callable(config.subscribed[0][0]) + assert config.subscribed[0][1] is ContextFound + + def test_context_found_subscriber_alias_enabled(self): + config = DummyConfigurator() + tweens.enable_selfalias(config, 'foo') + context_found_subscriber = config.subscribed[0][0] + request = Mock( + user=Mock(username='user12'), + matchdict={'foo': 'self'}) + context_found_subscriber(Mock(request=request)) + assert request.matchdict['foo'] == 'user12' + + def test_context_found_subscriber_no_matchdict(self): + config = DummyConfigurator() + tweens.enable_selfalias(config, 'foo') + context_found_subscriber = config.subscribed[0][0] + request = Mock( + user=Mock(username='user12'), + matchdict=None) + context_found_subscriber(Mock(request=request)) + assert request.matchdict is None + + def test_context_found_subscriber_not_self(self): + config = DummyConfigurator() + tweens.enable_selfalias(config, 'foo') + context_found_subscriber = config.subscribed[0][0] + request = Mock( + user=Mock(username='user12'), + matchdict={'foo': '1'}) + context_found_subscriber(Mock(request=request)) + assert request.matchdict['foo'] == '1' + + def test_context_found_subscriber_not_authenticated(self): + config = DummyConfigurator() + tweens.enable_selfalias(config, 'foo') + context_found_subscriber = config.subscribed[0][0] + request = Mock( + user=None, + matchdict={'foo': 'self'}) + context_found_subscriber(Mock(request=request)) + assert request.matchdict['foo'] == 'self' + + def test_context_found_subscriber_wrong_id_name(self): + config = DummyConfigurator() + tweens.enable_selfalias(config, 'foo') + context_found_subscriber = config.subscribed[0][0] + request = Mock( + user=Mock(username='user12'), + matchdict={'qoo': 'self'}) + context_found_subscriber(Mock(request=request)) + assert request.matchdict['qoo'] == 'self' diff --git a/tests/test_utility_views.py b/tests/test_utility_views.py new file mode 100644 index 0000000..5b83d70 --- /dev/null +++ b/tests/test_utility_views.py @@ -0,0 +1,43 @@ +from mock import Mock + +from nefertari import utility_views as uviews + + +class TestOptionsView(object): + header_str = 'HEAD, TRACE, GET, PATCH, PUT, POST, OPTIONS, DELETE' + + def test_call_methods_header(self): + response = Mock(headers={}) + request = Mock( + headers={'Access-Control-Request-Method': ''}, + response=response) + resp = uviews.OptionsView(request=request)() + assert resp is response + assert response.headers == { + 'Allow': self.header_str, + 'Access-Control-Allow-Methods': self.header_str, + } + + def test_call_headers_header(self): + response = Mock(headers={}) + request = Mock( + headers={'Access-Control-Request-Headers': ''}, + response=response) + resp = uviews.OptionsView(request=request)() + assert resp is response + assert response.headers == { + 'Allow': self.header_str, + 'Access-Control-Allow-Headers': ( + 'origin, x-requested-with, content-type'), + } + + def test_call_no_headers(self): + response = Mock(headers={}) + request = Mock( + headers={}, + response=response) + resp = uviews.OptionsView(request=request)() + assert resp is response + assert response.headers == { + 'Allow': self.header_str, + } diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_utils/test_data.py b/tests/test_utils/test_data.py new file mode 100644 index 0000000..6eb1bd6 --- /dev/null +++ b/tests/test_utils/test_data.py @@ -0,0 +1,120 @@ +from nefertari.utils import data as dutils + + +class DummyModel(dict): + def to_dict(self, *args, **kwargs): + return self + + +class TestDataUtils(object): + + def test_data_proxy_not_model(self): + proxy = dutils.DataProxy({'foo': 'bar'}) + data = proxy.to_dict() + assert data == {'_type': 'DataProxy', 'foo': 'bar'} + + def test_data_proxy_not_model_keys(self): + proxy = dutils.DataProxy({'foo': 'bar', 'id': 1}) + data = proxy.to_dict(_keys=['foo']) + assert data == {'_type': 'DataProxy', 'foo': 'bar'} + + def test_data_proxy_model(self): + obj = DummyModel({'foo1': 'bar1'}) + proxy = dutils.DataProxy({'foo': obj}) + data = proxy.to_dict() + assert data == {'_type': 'DataProxy', 'foo': {'foo1': 'bar1'}} + + def test_data_proxy_model_keys(self): + obj = DummyModel({'foo1': 'bar1'}) + proxy = dutils.DataProxy({'foo': obj, 'id': 1}) + data = proxy.to_dict(_keys=['foo']) + assert data == {'_type': 'DataProxy', 'foo': {'foo1': 'bar1'}} + + def test_data_proxy_model_no_depth(self): + obj = DummyModel({'foo1': 'bar1'}) + proxy = dutils.DataProxy({'foo': obj}) + data = proxy.to_dict(__depth=0) + assert data == {'_type': 'DataProxy', 'foo': obj} + + def test_data_proxy_model_sequence(self): + obj = DummyModel({'foo1': 'bar1'}) + proxy = dutils.DataProxy({'foo': [obj]}) + data = proxy.to_dict() + assert data == {'_type': 'DataProxy', 'foo': [{'foo1': 'bar1'}]} + + def test_dict2obj_regular_value(self): + obj = dutils.dict2obj({'foo': 'bar', 'baz': 1}) + assert isinstance(obj, dutils.DataProxy) + assert obj.foo == 'bar' + assert obj.baz == 1 + + def test_dict2obj_dict_value(self): + obj = dutils.dict2obj({'foo': {'baz': 1}}) + assert isinstance(obj, dutils.DataProxy) + assert isinstance(obj.foo, dutils.DataProxy) + assert obj.foo.baz == 1 + + def test_dict2obj_list_value(self): + obj = dutils.dict2obj({'foo': [{'baz': 1}]}) + assert isinstance(obj, dutils.DataProxy) + assert isinstance(obj.foo, list) + assert len(obj.foo) == 1 + assert isinstance(obj.foo[0], dutils.DataProxy) + assert obj.foo[0].baz == 1 + + def test_dict2obj_no_data(self): + assert dutils.dict2obj({}) == {} + + def test_to_objs(self): + collection = dutils.to_objs([{'foo': 'bar'}]) + assert len(collection) == 1 + assert isinstance(collection[0], dutils.DataProxy) + assert collection[0].foo == 'bar' + + def test_to_dicts_regular_case(self): + collection = [DummyModel({'foo': 'bar'})] + dicts = dutils.to_dicts(collection) + assert dicts == [{'foo': 'bar'}] + + def test_to_dicts_with_key(self): + collection = [DummyModel({'foo': 'bar', 'id': '1'})] + dicts = dutils.to_dicts(collection, key=lambda d: {'super': d['foo']}) + assert dicts == [{'super': 'bar'}] + + def test_to_dicts_attr_error(self): + obj = DummyModel({'foo': 'bar'}) + dicts = dutils.to_dicts([obj, {'a': 'b'}]) + assert dicts == [obj, {'a': 'b'}] + + def test_to_dicts_type_error(self): + def key(d): + raise TypeError() + obj = DummyModel({'foo': 'bar'}) + dicts = dutils.to_dicts([obj], key=key) + assert dicts == [obj] + + def test_obj2dict_dict(self): + assert dutils.obj2dict({'foo': 'bar'}) == {'foo': 'bar'} + + def test_obj2dict_list(self): + assert dutils.obj2dict([{'foo': 'bar'}]) == [{'foo': 'bar'}] + + def test_obj2dict_object(self): + class A(object): + pass + obj = A() + obj.foo = 'bar' + assert dutils.obj2dict(obj) == {'foo': 'bar'} + + def test_obj2dict_object_classkey(self): + class A(object): + pass + obj = A() + obj.foo = 'bar' + assert dutils.obj2dict(obj, classkey='kls') == { + 'foo': 'bar', 'kls': 'A'} + + def test_obj2dict_simple_types(self): + assert dutils.obj2dict(1) == 1 + assert dutils.obj2dict('foo') == 'foo' + assert dutils.obj2dict(None) is None diff --git a/tests/test_utils/test_dictset.py b/tests/test_utils/test_dictset.py new file mode 100644 index 0000000..7f2ff87 --- /dev/null +++ b/tests/test_utils/test_dictset.py @@ -0,0 +1,281 @@ +from datetime import datetime + +import pytest + +from nefertari.utils.dictset import dictset + + +class TestDictset(object): + def test_copy(self): + dset1 = dictset({'foo': 'bar'}) + dset2 = dset1.copy() + assert dset1 is not dset2 + assert dset1 == dset2 + + def test_subset_no_keys(self): + dset1 = dictset({'foo': 'bar', 'fruit': 'apple'}) + dset2 = dset1.subset([]) + assert dset1 is not dset2 + assert dset1 != dset2 + + def test_subset(self): + dset1 = dictset({'foo': 'bar', 'fruit': 'apple'}) + dset2 = dset1.subset(['foo', 'nonexisting']) + assert dict(dset2) == {'foo': 'bar'} + + def test_subset_exclude(self): + dset1 = dictset({'foo': 'bar', 'fruit': 'apple'}) + dset2 = dset1.subset(['-fruit', '-nonexisting']) + assert dict(dset2) == {'foo': 'bar'} + + def test_remove(self): + dset1 = dictset({'foo': 'bar', 'fruit': 'apple'}) + dset2 = dset1.remove(['fruit']) + assert dict(dset2) == {'foo': 'bar'} + + def test_getattr(self): + dset1 = dictset({'foo': 'bar', 'fruit': 'apple'}) + assert dset1.foo == 'bar' + assert dset1.fruit == 'apple' + + def test_setattr(self): + dset1 = dictset() + dset1.boo = 1 + assert dict(dset1) == {'boo': 1} + + def test_asbool(self): + dset1 = dictset({ + 'foo': 'true', 'fruit': 'false', + 'baz': True, 'zoo': False}) + assert dset1.asbool('foo') + assert dset1.asbool('baz') + assert isinstance(dset1.asbool('foo'), bool) + assert isinstance(dset1.asbool('baz'), bool) + assert not dset1.asbool('fruit') + assert not dset1.asbool('zoo') + + def test_asbool_set(self): + dset1 = dictset({'foo': 'true'}) + assert dset1.asbool('foo', _set=True) + assert dset1.foo + assert dset1.foo != 'true' + + def test_asbool_pop(self): + dset1 = dictset({'foo': 'true'}) + assert dset1.asbool('foo', pop=True) + assert 'foo' not in dset1 + + def test_asbool_default(self): + dset1 = dictset({'foo': 'true'}) + assert dset1.asbool('foo1', default=True) + assert 'foo1' not in dset1 + + def test_aslist(self): + dset1 = dictset({'foo': '1,2, 3'}) + assert dset1.aslist('foo') == ['1', '2', '3'] + + def test_aslist_set(self): + dset1 = dictset({'foo': '1,2,3'}) + assert dset1.aslist('foo', _set=True) == ['1', '2', '3'] + assert dset1.foo == ['1', '2', '3'] + + def test_aslist_default(self): + dset1 = dictset({'foo': '1,2,3'}) + assert dset1.aslist('foo1', default=['1']) == ['1'] + assert 'foo1' not in dset1 + + def test_asint(self): + assert dictset({'foo': '1'}).asint('foo') == 1 + + def test_asint_set(self): + dset = dictset({'foo': '1'}) + assert dset.asint('foo', _set=True) == 1 + assert dset.foo == 1 + + def test_asint_default(self): + dset = dictset({'foo': '1'}) + assert dset.asint('foo1', default=2) == 2 + assert 'foo1' not in dset + + def test_asfloat(self): + assert dictset({'foo': '1.0'}).asfloat('foo') == 1.0 + + def test_asfloat_set(self): + dset = dictset({'foo': '1.0'}) + assert dset.asfloat('foo', _set=True) == 1.0 + assert dset.foo == 1.0 + + def test_asfloat_default(self): + dset = dictset({'foo': '1.0'}) + assert dset.asfloat('foo1', default=2.0) == 2.0 + assert 'foo1' not in dset + + def test_asdict(self): + dset = dictset({'foo': "a:2,b:blabla,c:True,a:'d',a:1"}) + assert dset.asdict('foo') == { + 'a': ['2', "'d'", '1'], 'b': 'blabla', 'c': 'True'} + + def test_asdict_type(self): + dset = dictset({'foo': "a:2,a:1"}) + assert dset.asdict('foo', _type=lambda x: int(x)) == {'a': [2, 1]} + + def test_asdict_set(self): + dset = dictset({'foo': "a:2,b:blabla,c:True,a:'d',a:1"}) + assert dset.asdict('foo', _set=True) == { + 'a': ['2', "'d'", '1'], 'b': 'blabla', 'c': 'True'} + assert dset.foo == { + 'a': ['2', "'d'", '1'], 'b': 'blabla', 'c': 'True'} + + def test_asdict_wrong_key(self): + dset = dictset({'foo': "a:2,b:blabla,c:True,a:'d',a:1"}) + assert dset.asdict('boo') == {} + + def test_mget(self): + dset = dictset({'foo.key1': '1', 'foo.key2': '2', 'boo': '3'}) + dset2 = dset.mget('foo') + assert dset2 == {'key1': '1', 'key2': '2'} + + def test_mget_defaults(self): + dset = dictset({'foo.key1': '1', 'foo.key2': '2', 'boo': '3'}) + dset2 = dset.mget('foo', defaults={'q': 1}) + assert dset2 == {'key1': '1', 'key2': '2', 'q': 1} + + def test_update(self): + dset = dictset({'boo': '3'}) + dset2 = dset.update(foo=1) + assert dset.foo == 1 + assert dset is dset2 + + def test_process_list_param_string(self): + dset = dictset({'boo': '1,2'}) + assert dset.process_list_param('boo') == ['1', '2'] + assert dset.boo == ['1', '2'] + + def test_process_list_param_setdefault(self): + dset = dictset({'boo': '1,2'}) + assert dset.process_list_param('foo', setdefault=[1]) == [1] + assert dset.foo == [1] + + def test_process_list_param_default(self): + dset = dictset({'boo': '1,2'}) + assert dset.process_list_param('foo', default=[1]) == [1] + assert 'foo' not in dset + + def test_process_list_param_pop(self): + dset = dictset({'boo': '1,2'}) + assert dset.process_list_param('boo', pop=True) == ['1', '2'] + assert 'boo' not in dset + + def test_process_list_param_pop_default(self): + dset = dictset({'boo': '1,2'}) + assert dset.process_list_param('foo', default=[1], pop=True) == [1] + assert 'foo' not in dset + + def test_process_list_param_type(self): + dset = dictset({'boo': '1,2'}) + assert dset.process_list_param('boo', _type=lambda x: int(x)) == [1, 2] + + def test_process_bool_param(self): + dset = dictset({'boo': 'true', 'foo': 'false'}) + assert dset.process_bool_param('boo') + assert not dset.process_bool_param('foo') + assert dset.boo + assert not dset.foo + assert isinstance(dset.boo, bool) + assert isinstance(dset.foo, bool) + + def test_process_bool_param_default(self): + dset = dictset({'boo': 'true'}) + assert dset.process_bool_param('foo', default=True) + assert dset.foo + assert isinstance(dset.foo, bool) + + def test_pop_bool_param(self): + dset = dictset({'boo': 'true'}) + param = dset.pop_bool_param('boo') + assert param + assert isinstance(param, bool) + assert 'boo' not in dset + + def test_pop_bool_param_default(self): + dset = dictset({'boo': 'true'}) + param = dset.pop_bool_param('foo', default=True) + assert param + assert isinstance(param, bool) + assert 'foo' not in dset + + def test_process_datetime_param(self): + dset = dictset({'boo': '2014-01-02T03:04:05Z'}) + dtime = dset.process_datetime_param('boo') + assert dtime is dset.boo + assert dtime == dset.boo + assert isinstance(dtime, datetime) + assert dtime.year == 2014 + assert dtime.month == 1 + assert dtime.day == 2 + assert dtime.hour == 3 + assert dtime.minute == 4 + assert dtime.second == 5 + + def test_process_datetime_param_wrong_format(self): + dset = dictset({'boo': '2014-01-'}) + with pytest.raises(ValueError) as ex: + dset.process_datetime_param('boo') + assert 'Bad format' in str(ex.value) + + def test_process_float_param(self): + dset = dictset({'boo': '1.5'}) + assert dset.process_float_param('boo') == 1.5 + assert dset.boo == 1.5 + + def test_process_float_param_value_err(self): + dset = dictset({'boo': 'a'}) + with pytest.raises(ValueError) as ex: + dset.process_float_param('boo') + assert 'boo must be a decimal' == str(ex.value) + + def test_process_float_param_default(self): + dset = dictset({'boo': '1.5'}) + assert dset.process_float_param('foo', default=2.5) == 2.5 + assert dset.foo == 2.5 + + def test_process_int_param(self): + dset = dictset({'boo': '1'}) + assert dset.process_int_param('boo') == 1 + assert dset.boo == 1 + + def test_process_int_param_value_err(self): + dset = dictset({'boo': 'a'}) + with pytest.raises(ValueError) as ex: + dset.process_int_param('boo') + assert 'boo must be a decimal' == str(ex.value) + + def test_process_int_param_default(self): + dset = dictset({'boo': '1'}) + assert dset.process_int_param('foo', default=2) == 2 + assert dset.foo == 2 + + def test_process_dict_param(self): + dset = dictset({'boo': 'a:1'}) + assert dset.process_dict_param('boo') == {'a': '1'} + assert dset.boo == {'a': '1'} + + def test_process_dict_param_type(self): + dset = dictset({'boo': 'a:1'}) + assert dset.process_dict_param('boo', _type=lambda x: int(x)) == { + 'a': 1} + assert dset.boo == {'a': 1} + + def test_process_dict_param_pop(self): + dset = dictset({'boo': 'a:1'}) + assert dset.process_dict_param('boo', pop=True) == {'a': '1'} + assert 'boo' not in dset + + def test_pop_by_values(self): + dset = dictset({'boo': '1', 'foo': '2'}) + dset.pop_by_values('3') + assert dict(dset) == {'boo': '1', 'foo': '2'} + dset.pop_by_values('1') + assert dict(dset) == {'foo': '2'} + dset.pop_by_values('2') + assert dict(dset) == {} diff --git a/tests/test_utils/test_utils.py b/tests/test_utils/test_utils.py new file mode 100644 index 0000000..cbc39e4 --- /dev/null +++ b/tests/test_utils/test_utils.py @@ -0,0 +1,120 @@ +import pytest +from mock import patch, call + +from nefertari.utils import utils + + +class TestUtils(object): + + @patch('nefertari.utils.utils.json') + def test_json_dumps(self, mock_json): + from nefertari.renderers import _JSONEncoder + utils.json_dumps('foo') + mock_json.dumps.assert_called_once_with('foo', cls=_JSONEncoder) + + @patch('nefertari.utils.utils.json') + def test_json_dumps_encoder(self, mock_json): + utils.json_dumps('foo', 'enc') + mock_json.dumps.assert_called_once_with('foo', cls='enc') + + def test_split_strip(self): + assert utils.split_strip('1, 2,') == ['1', '2'] + assert utils.split_strip('1, 2') == ['1', '2'] + assert utils.split_strip('1;2;', on=';') == ['1', '2'] + + def test_process_limit_start_and_page(self): + with pytest.raises(ValueError) as ex: + utils.process_limit(1, 2, 3) + assert 'at the same time' in str(ex.value) + + def test_process_limit_start(self): + start, limit = utils.process_limit(start=1, page=None, limit=5) + assert start == 1 + assert limit == 5 + + def test_process_limit_page(self): + start, limit = utils.process_limit(start=None, page=2, limit=5) + assert start == 10 + assert limit == 5 + + def test_process_limit_no_start_page(self): + start, limit = utils.process_limit(start=None, page=None, limit=5) + assert start == 0 + assert limit == 5 + + def test_process_limit_lower_than_zero(self): + with pytest.raises(ValueError) as ex: + utils.process_limit(1, None, -3) + assert 'can not be < 0' in str(ex.value) + with pytest.raises(ValueError) as ex: + utils.process_limit(-1, None, 3) + assert 'can not be < 0' in str(ex.value) + + def test_extend_list_string(self): + assert utils.extend_list('foo, bar,') == ['foo', 'bar'] + + def test_extend_list_sequence_string(self): + assert utils.extend_list(['foo, bar,']) == ['foo', 'bar'] + + def test_extend_list_sequence_elements(self): + assert utils.extend_list(['foo', 'bar', '1,2']) == [ + 'foo', 'bar', '1', '2'] + + def test_process_fields_string(self): + only, exclude = utils.process_fields('a,b,-c') + assert only == ['a', 'b'] + assert exclude == ['c'] + + def test_process_fields_empty_field(self): + only, exclude = utils.process_fields(['a', 'b', '-c', '']) + assert only == ['a', 'b'] + assert exclude == ['c'] + + def test_snake2camel(self): + assert utils.snake2camel('foo_bar') == 'FooBar' + assert utils.snake2camel('foobar') == 'Foobar' + + @patch('nefertari.utils.utils.Configurator') + def test_maybe_dotted(self, mock_conf): + result = utils.maybe_dotted('foo.bar') + mock_conf.assert_called_once_with() + mock_conf().maybe_dotted.assert_called_once_with('foo.bar') + assert result == mock_conf().maybe_dotted() + + @patch('nefertari.utils.utils.Configurator') + def test_maybe_dotted_err_throw(self, mock_conf): + mock_conf.side_effect = ImportError + with pytest.raises(ImportError): + utils.maybe_dotted('foo.bar', throw=True) + + @patch('nefertari.utils.utils.Configurator') + def test_maybe_dotted_err_no_throw(self, mock_conf): + mock_conf.side_effect = ImportError + assert utils.maybe_dotted('foo.bar', throw=False) is None + + @patch('nefertari.utils.utils.os') + def test_chdir(self, mock_os): + with utils.chdir('/tmp'): + pass + mock_os.getcwd.assert_called_once_with() + mock_os.chdir.assert_has_calls([ + call('/tmp'), call(mock_os.getcwd()) + ]) + + def test_isnumeric(self): + from decimal import Decimal + assert not utils.isnumeric('asd') + assert not utils.isnumeric(dict()) + assert not utils.isnumeric([]) + assert not utils.isnumeric(()) + assert utils.isnumeric(1) + assert utils.isnumeric(2.0) + assert utils.isnumeric(Decimal(1)) + + def test_issequence(self): + assert utils.issequence(dict()) + assert utils.issequence([]) + assert utils.issequence(()) + assert not utils.issequence('asd') + assert not utils.issequence(1) + assert not utils.issequence(2.0) diff --git a/tests/test_view.py b/tests/test_view.py new file mode 100644 index 0000000..1e60206 --- /dev/null +++ b/tests/test_view.py @@ -0,0 +1,564 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +import pytest +from mock import Mock, MagicMock, patch, call, PropertyMock + +from nefertari.view import ( + BaseView, error_view, key_error_view, value_error_view) +from nefertari.json_httpexceptions import ( + JHTTPBadRequest, JHTTPNotFound, JHTTPMethodNotAllowed) +from nefertari.wrappers import wrap_me, ValidationError, ResourceNotFound + + +class TestViewMapper(object): + + def test_viewmapper(self): + from nefertari.view import ViewMapper + + bc1 = Mock() + bc3 = Mock() + bc2 = Mock() + + class MyView(object): + def __init__(self, ctx, req): + self._before_calls = {'index': [bc1], 'show': [bc3]} + self._after_calls = {'show': [bc2]} + + @wrap_me(before=bc2) + def index(self): + return ['thing'] + + request = MagicMock() + resource = MagicMock(actions=['index']) + + wrapper = ViewMapper(**{'attr': 'index'})(MyView) + result = wrapper(resource, request) + + assert request.filters == {'show': [bc2]} + assert request.action == 'index' + assert result == ['thing'] + + bc1.assert_called_with(request=request) + assert not bc2.called + assert not bc3.called + + def test_viewmapper_bad_request(self): + from nefertari.view import ViewMapper + + bc1 = Mock(side_effect=ValidationError) + + class MyView(object): + def __init__(self, ctx, req): + self._before_calls = {'index': [bc1]} + self._after_calls = {} + + def index(self): + return ['thing'] + + request = Mock() + resource = Mock(actions=['index']) + wrapper = ViewMapper(**{'attr': 'index'})(MyView) + with pytest.raises(JHTTPBadRequest): + wrapper(resource, request) + + def test_viewmapper_not_found(self): + from nefertari.view import ViewMapper + + bc1 = Mock(side_effect=ResourceNotFound) + + class MyView(object): + def __init__(self, ctx, req): + self._before_calls = {'index': [bc1]} + self._after_calls = {} + + def index(self): + return ['thing'] + + request = Mock() + resource = Mock(actions=['index']) + wrapper = ViewMapper(**{'attr': 'index'})(MyView) + with pytest.raises(JHTTPNotFound): + wrapper(resource, request) + + +class TestBaseView(object): + + def test_baseview(self, *a): + + class UsersView(BaseView): + + def __init__(self, context, request): + BaseView.__init__(self, context, request) + + def show(self, id): + return u'John Doe' + + def convert_ids2objects(self, *args, **kwargs): + pass + + request = MagicMock(content_type='') + request.matched_route.pattern = '/users' + view = UsersView(request.context, request) + + assert u'John Doe' == view.show(1) + + with pytest.raises(JHTTPMethodNotAllowed): + view.index() + + with pytest.raises(AttributeError): + view.frobnicate() + + # delete is an allowed action, but it raises since BaseView + # does not implement it. + with pytest.raises(JHTTPMethodNotAllowed): + view.delete() + + def test_convert_dotted(self): + converted = BaseView.convert_dotted({ + 'settings.foo': 'bar', + 'option': 'value' + }) + assert converted['settings'] == {'foo': 'bar'} + assert converted['option'] == 'value' + assert 'settings.foo' not in converted + + def test_convert_dotted_no_dotted(self): + converted = BaseView.convert_dotted({ + 'option': 'value' + }) + assert converted == {'option': 'value'} + + @patch('nefertari.view.BaseView._run_init_actions') + def test_init(self, run): + request = Mock( + content_type='application/json', + json={'param1.foo': 'val1', 'param3': 'val3'}, + method='POST', + accept=[''], + ) + request.params.mixed.return_value = {'param2.foo': 'val2'} + view = BaseView(context={'foo': 'bar'}, request=request) + run.assert_called_once_with() + assert request.override_renderer == 'nefertari_json' + assert list(sorted(view._params.keys())) == [ + 'param1', 'param2', 'param3'] + assert view._params['param1'] == {'foo': 'val1'} + assert view._params['param2'] == {'foo': 'val2'} + assert view._params['param3'] == 'val3' + assert view.request == request + assert view.context == {'foo': 'bar'} + assert view._before_calls == {} + assert view._after_calls == {} + + @patch('nefertari.view.BaseView._run_init_actions') + def test_init_json_accept_header(self, run): + request = Mock( + content_type='application/json', + json={'param1.foo': 'val1', 'param3': 'val3'}, + method='POST', + accept=['application/json'], + ) + request.params.mixed.return_value = {'param2.foo': 'val2'} + BaseView(context={'foo': 'bar'}, request=request) + assert request.override_renderer == 'nefertari_json' + + @patch('nefertari.view.BaseView._run_init_actions') + def test_init_text_ct_and_accept(self, run): + request = Mock( + content_type='text/plain', + json={'param1.foo': 'val1', 'param3': 'val3'}, + method='POST', + accept=['text/plain'], + ) + request.params.mixed.return_value = {'param2.foo': 'val2'} + view = BaseView(context={'foo': 'bar'}, request=request) + assert request.override_renderer == 'string' + assert view._params.keys() == ['param2'] + + @patch('nefertari.view.BaseView._run_init_actions') + def test_init_json_error(self, run): + import simplejson + request = Mock( + content_type='application/json', + method='POST', + accept=['application/json'], + ) + type(request).json = PropertyMock( + side_effect=simplejson.JSONDecodeError( + 'foo', 'asdasdasdasd', pos=1)) + request.params.mixed.return_value = {'param2.foo': 'val2'} + view = BaseView(context={'foo': 'bar'}, request=request) + assert request.override_renderer == 'nefertari_json' + assert view._params.keys() == ['param2'] + + @patch('nefertari.view.BaseView.setup_default_wrappers') + @patch('nefertari.view.BaseView.convert_ids2objects') + @patch('nefertari.view.BaseView.set_public_limits') + def test_run_init_actions(self, limit, conv, setpub): + request = Mock( + content_type='text/plain', + json={'param1.foo': 'val1', 'param3': 'val3'}, + method='POST', + accept=['text/plain'], + ) + request.params.mixed.return_value = {'param2.foo': 'val2'} + BaseView(context={'foo': 'bar'}, request=request) + limit.assert_called_once_with() + conv.assert_called_once_with() + setpub.assert_called_once_with() + + @patch('nefertari.view.wrappers') + @patch('nefertari.view.BaseView._run_init_actions') + def test_set_public_limits_no_root(self, run, wrap): + request = Mock(content_type='', method='', accept=['']) + view = BaseView( + context={}, request=request, _query_params={'foo': 'bar'}) + view.root_resource = None + view.set_public_limits() + assert not wrap.set_public_limits.called + + @patch('nefertari.view.wrappers') + @patch('nefertari.view.BaseView._run_init_actions') + def test_set_public_limits_no_auth(self, run, wrap): + request = Mock(content_type='', method='', accept=['']) + view = BaseView( + context={}, request=request, _query_params={'foo': 'bar'}) + view.root_resource = Mock(auth=False) + view.set_public_limits() + assert not wrap.set_public_limits.called + + @patch('nefertari.view.wrappers') + @patch('nefertari.view.BaseView._run_init_actions') + def test_set_public_limits_user_authenticated(self, run, wrap): + request = Mock(content_type='', method='', accept=[''], user='foo') + view = BaseView( + context={}, request=request, _query_params={'foo': 'bar'}) + view.root_resource = Mock(auth=True) + view.set_public_limits() + assert not wrap.set_public_limits.called + + @patch('nefertari.view.wrappers') + @patch('nefertari.view.BaseView._run_init_actions') + def test_set_public_limits_applied(self, run, wrap): + request = Mock(content_type='', method='', accept=[''], user=None) + view = BaseView( + context={}, request=request, _query_params={'foo': 'bar'}) + view.root_resource = Mock(auth=True) + view.set_public_limits() + wrap.set_public_limits.assert_called_once_with(view) + + @patch('nefertari.view.engine') + @patch('nefertari.view.BaseView.id2obj') + @patch('nefertari.view.BaseView._run_init_actions') + def test_convert_ids2objects_non_relational(self, run, id2obj, eng): + request = Mock(content_type='', method='', accept=[''], user=None) + view = BaseView( + context={}, request=request, _query_params={'foo1': 'bar'}, + _json_params={'foo': 'bar'}) + view._model_class = 'Model1' + eng.is_relationship_field.return_value = False + view.convert_ids2objects() + eng.is_relationship_field.assert_called_once_with('foo', 'Model1') + assert not id2obj.called + + @patch('nefertari.view.engine') + @patch('nefertari.view.BaseView.id2obj') + @patch('nefertari.view.BaseView._run_init_actions') + def test_convert_ids2objects_relational(self, run, id2obj, eng): + request = Mock(content_type='', method='', accept=[''], user=None) + view = BaseView( + context={}, request=request, _query_params={'foo1': 'bar'}, + _json_params={'foo': 'bar'}) + view._model_class = 'Model1' + eng.is_relationship_field.return_value = True + view.convert_ids2objects() + eng.get_relationship_cls.assert_called_once_with('foo', 'Model1') + id2obj.assert_called_once_with('foo', eng.get_relationship_cls()) + + @patch('nefertari.view.BaseView._run_init_actions') + def test_get_debug(self, run): + request = Mock(content_type='', method='', accept=[''], user=None) + request.registry.settings = {'super.debug': 'true'} + view = BaseView( + context={}, request=request, _query_params={'foo': 'bar'}) + assert view.get_debug(package='super') + + @patch('nefertari.view.BaseView._run_init_actions') + def test_get_debug_no_package(self, run): + request = Mock(content_type='', method='', accept=[''], user=None) + request.registry.settings = {'debug': 'false'} + view = BaseView( + context={}, request=request, _query_params={'foo': 'bar'}) + assert not view.get_debug() + + @patch('nefertari.view.wrappers') + @patch('nefertari.view.BaseView._run_init_actions') + def test_setup_default_wrappers_with_auth(self, run, wrap): + request = Mock(content_type='', method='', accept=[''], user=None) + view = BaseView( + context={}, request=request, _query_params={'foo': 'bar'}) + view.root_resource = Mock(auth=True) + view.setup_default_wrappers() + assert len(view._after_calls['index']) == 4 + assert len(view._after_calls['show']) == 3 + assert len(view._after_calls['delete']) == 1 + assert len(view._after_calls['delete_many']) == 1 + assert len(view._after_calls['update_many']) == 1 + assert wrap.apply_privacy.call_count == 2 + + @patch('nefertari.view.wrappers') + @patch('nefertari.view.BaseView._run_init_actions') + def test_setup_default_wrappers_no_auth(self, run, wrap): + request = Mock(content_type='', method='', accept=[''], user=None) + view = BaseView( + context={}, request=request, _query_params={'foo': 'bar'}) + view.root_resource = Mock(auth=None) + view.setup_default_wrappers() + assert len(view._after_calls['index']) == 3 + assert len(view._after_calls['show']) == 2 + assert len(view._after_calls['delete']) == 1 + assert len(view._after_calls['delete_many']) == 1 + assert len(view._after_calls['update_many']) == 1 + assert not wrap.apply_privacy.called + + def test_defalt_wrappers_and_wrap_me(self): + from nefertari import wrappers + + self.maxDiff = None + + def before_call(*a): + return a[2] + + def after_call(*a): + return a[2] + + class MyView(BaseView): + + @wrappers.wrap_me(before=before_call, after=after_call) + def index(self): + return [1, 2, 3] + + def convert_ids2objects(self, *args, **kwargs): + pass + + request = MagicMock(content_type='') + resource = MagicMock(actions=['index']) + view = MyView(resource, request) + + assert len(view._after_calls['index']) == 3 + assert len(view._after_calls['show']) == 2 + assert len(view._after_calls['delete']) == 1 + assert len(view._after_calls['delete_many']) == 1 + assert len(view._after_calls['update_many']) == 1 + + assert view.index._before_calls == [before_call] + assert view.index._after_calls == [after_call] + + @patch('nefertari.view.BaseView._run_init_actions') + def test_not_allowed_action(self, run): + request = Mock(content_type='', method='', accept=[''], user=None) + view = BaseView( + context={}, request=request, _query_params={'foo': 'bar'}) + with pytest.raises(JHTTPMethodNotAllowed): + view.not_allowed_action() + + @patch('nefertari.view.BaseView._run_init_actions') + def test_add_before_or_after_before(self, run): + request = Mock(content_type='', method='', accept=[''], user=None) + view = BaseView( + context={}, request=request, _query_params={'foo': 'bar'}) + callable_ = lambda x: x + view.add_before_or_after_call( + action='foo', _callable=callable_, pos=None, before=True) + assert callable_ in view._before_calls['foo'] + + @patch('nefertari.view.BaseView._run_init_actions') + def test_add_before_or_after_after(self, run): + request = Mock(content_type='', method='', accept=[''], user=None) + view = BaseView( + context={}, request=request, _query_params={'foo': 'bar'}) + callable_ = lambda x: x + view.add_before_or_after_call( + action='foo', _callable=callable_, pos=None, before=False) + assert callable_ in view._after_calls['foo'] + + @patch('nefertari.view.BaseView._run_init_actions') + def test_add_before_or_after_position(self, run): + request = Mock(content_type='', method='', accept=[''], user=None) + view = BaseView( + context={}, request=request, _query_params={'foo': 'bar'}) + callable1 = lambda x: x + callable2 = lambda x: x + x + view.add_before_or_after_call( + action='foo', _callable=callable1, pos=None, + before=False) + assert callable1 is view._after_calls['foo'][0] + view.add_before_or_after_call( + action='foo', _callable=callable2, pos=0, + before=False) + assert callable2 is view._after_calls['foo'][0] + assert callable1 is view._after_calls['foo'][1] + + @patch('nefertari.view.BaseView._run_init_actions') + def test_add_before_or_after_not_callable(self, run): + request = Mock(content_type='', method='', accept=[''], user=None) + view = BaseView( + context={}, request=request, _query_params={'foo': 'bar'}) + with pytest.raises(ValueError) as ex: + view.add_before_or_after_call( + action='foo', _callable='asdasd', pos=None, + before=False) + assert str(ex.value) == 'asdasd is not a callable' + + @patch('nefertari.view.urllib') + @patch('nefertari.view.Request') + @patch('nefertari.view.BaseView._run_init_actions') + def test_subrequest_get(self, run, req, ulib): + request = Mock( + content_type='', method='', accept=[''], user=None, + cookies=['1']) + view = BaseView( + context={}, request=request, _query_params={'foo': 'bar'}) + view.subrequest(url='http://', params={'par': 'val'}, method='GET') + req.blank.assert_called_once_with( + 'http://', cookies=['1'], content_type='application/json', + method='GET') + view.request.invoke_subrequest.assert_called_once_with(req.blank()) + ulib.urlencode.assert_called_once_with({'par': 'val'}) + + @patch('nefertari.view.json') + @patch('nefertari.view.Request') + @patch('nefertari.view.BaseView._run_init_actions') + def test_subrequest_post(self, run, req, json): + request = Mock( + content_type='', method='', accept=[''], user=None, + cookies=['1']) + view = BaseView( + context={}, request=request, _query_params={'foo': 'bar'}) + view.subrequest(url='http://', params={'par': 'val'}, method='POST') + req.blank.assert_called_once_with( + 'http://', cookies=['1'], content_type='application/json', + method='POST') + view.request.invoke_subrequest.assert_called_once_with(req.blank()) + json.dumps.assert_called_once_with({'par': 'val'}) + + @patch('nefertari.view.BaseView._run_init_actions') + def test_needs_confirmation(self, run): + request = Mock(content_type='', method='', accept=[''], user=None) + view = BaseView( + context={}, request=request, _query_params={'foo': 'bar'}) + view._query_params['__confirmation'] = '' + assert not view.needs_confirmation() + view._query_params.pop('__confirmation') + assert view.needs_confirmation() + + @patch('nefertari.view.BaseView._run_init_actions') + def test_id2obj(self, run): + model = Mock() + model.pk_field.return_value = 'idname' + model.get.return_value = 'foo' + request = Mock(content_type='', method='', accept=[''], user=None) + view = BaseView( + context={}, request=request, _json_params={'foo': 'bar'}, + _query_params={'foo1': 'bar1'}) + view._json_params['user'] = '1' + view.id2obj(name='user', model=model) + assert view._json_params['user'] == 'foo' + model.pk_field.assert_called_once_with() + model.get.assert_called_once_with(idname='1') + + @patch('nefertari.view.BaseView._run_init_actions') + def test_id2obj_list(self, run): + model = Mock() + model.pk_field.return_value = 'idname' + model.get.return_value = 'foo' + request = Mock(content_type='', method='', accept=[''], user=None) + view = BaseView( + context={}, request=request, _json_params={'foo': 'bar'}, + _query_params={'foo1': 'bar1'}) + view._json_params['user'] = ['1'] + view.id2obj(name='user', model=model) + assert view._json_params['user'] == ['foo'] + model.pk_field.assert_called_once_with() + model.get.assert_called_once_with(idname='1') + + @patch('nefertari.view.BaseView._run_init_actions') + def test_id2obj_not_in_params(self, run): + model = Mock() + request = Mock(content_type='', method='', accept=[''], user=None) + view = BaseView( + context={}, request=request, _json_params={'foo': 'bar'}, + _query_params={'foo1': 'bar1'}) + view.id2obj(name='asdasdasd', model=model) + assert not model.pk_field.called + assert not model.get.called + + @patch('nefertari.view.BaseView._run_init_actions') + def test_id2obj_setdefault(self, run): + model = Mock() + model.pk_field.return_value = 'idname' + model.get.return_value = None + request = Mock(content_type='', method='', accept=[''], user=None) + view = BaseView( + context={}, request=request, _json_params={'foo': 'bar'}, + _query_params={'foo1': 'bar1'}) + view._json_params['user'] = '1' + view.id2obj(name='user', model=model, setdefault=123) + assert view._json_params['user'] == 123 + model.pk_field.assert_called_once_with() + model.get.assert_called_once_with(idname='1') + + @patch('nefertari.view.BaseView._run_init_actions') + def test_id2obj_already_object(self, run): + id_ = Mock() + model = Mock() + model.pk_field.return_value = 'idname' + model.get.return_value = None + request = Mock(content_type='', method='', accept=[''], user=None) + view = BaseView( + context={}, request=request, _json_params={'foo': 'bar'}, + _query_params={'foo1': 'bar1'}) + view._json_params['user'] = id_ + view.id2obj(name='user', model=model, setdefault=123) + assert view._json_params['user'] == id_ + model.pk_field.assert_called_once_with() + assert not model.get.called + + @patch('nefertari.view.BaseView._run_init_actions') + def test_id2obj_not_found(self, run): + model = Mock() + model.pk_field.return_value = 'idname' + model.get.return_value = None + request = Mock(content_type='', method='', accept=[''], user=None) + view = BaseView( + context={}, request=request, _json_params={'foo': 'bar'}, + _query_params={'foo1': 'bar1'}) + view._json_params['user'] = '1' + with pytest.raises(JHTTPBadRequest) as ex: + view.id2obj(name='user', model=model) + assert str(ex.value) == 'id2obj: Object 1 not found' + + +class TestViewHelpers(object): + def test_key_error_view(self): + resp = key_error_view(Mock(message='foo'), None) + assert str(resp.message) == "Bad or missing param 'foo'" + + def test_value_error_view(self): + resp = value_error_view(Mock(message='foo'), None) + assert str(resp.message) == "Bad or missing value 'foo'" + + def test_error_view(self): + resp = error_view(Mock(message='foo'), None) + assert str(resp.message) == "foo" + + def test_includeme(self): + from nefertari.view import includeme + config = Mock() + includeme(config) + calls = [ + call(key_error_view, context=KeyError), + call(value_error_view, context=ValueError), + call(error_view, context=Exception) + ] + config.add_view.assert_has_calls(calls, any_order=True) diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py new file mode 100644 index 0000000..a80e89b --- /dev/null +++ b/tests/test_wrappers.py @@ -0,0 +1,432 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +import unittest + +import pytest +from mock import Mock, patch +from pyramid.testing import DummyRequest + +from nefertari import wrappers +from nefertari.utils import dictset + + +class TestWrappers(unittest.TestCase): + model_test_data = dictset({ + '_type': 'foo', + 'self': 'http://example.com/1', + 'name': 'User1', + 'desc': 'User 1 data', + 'id': 1, + 'other_field': 123 + }) + + def test_issequence(self): + class A(object): + def __init__(self, *args): + for arg in args: + setattr(self, arg, lambda x: x) + + assert not wrappers.issequence(A('strip')) + assert not wrappers.issequence(A('foo')) + assert wrappers.issequence(A('__getitem__')) + assert wrappers.issequence(A('__iter__')) + assert wrappers.issequence(A('__iter__', 'foo')) + assert wrappers.issequence(A('__getitem__', 'foo')) + + def test_wrap_me_init(self): + wrap = wrappers.wrap_me(before='foo', after=['bar']) + assert wrap.before == ['foo'] + assert wrap.after == ['bar'] + + wrap = wrappers.wrap_me(after=['bar']) + assert wrap.before == [] + assert wrap.after == ['bar'] + + wrap = wrappers.wrap_me() + assert wrap.before == [] + assert wrap.after == [] + + def test_wrap_me_call(self): + meth = lambda x: x + wrap = wrappers.wrap_me(before=['foo'], after=['bar']) + assert not hasattr(meth, '_before_calls') + assert not hasattr(meth, '_after_calls') + wrap(meth) + assert meth._before_calls == ['foo'] + assert meth._after_calls == ['bar'] + + wrap(meth) + assert meth._before_calls == ['foo', 'foo'] + assert meth._after_calls == ['bar', 'bar'] + + def test_callable_base(self): + class A(wrappers.callable_base): + pass + + obj1 = A(id=1) + obj2 = A(id=2) + assert obj1 == obj2 + assert obj1.kwargs == {'id': 1} + assert obj2.kwargs == {'id': 2} + + def test_obj2dict_dict_result(self): + assert wrappers.obj2dict(None)(result={'a': 'b'}) == {'a': 'b'} + + def test_obj2dict_regular(self): + result = Mock() + result.to_dict.return_value = {'a': 1} + assert wrappers.obj2dict(None)(result=result) == {'a': 1} + + def test_obj2dict_list_from_todict(self): + result = Mock() + result.to_dict.return_value = [dict(a=1), dict(b=2)] + self.assertEqual( + [dict(a=1), dict(b=2)], + wrappers.obj2dict(request=None)(result=result)) + + def test_obj2dict_nested(self): + special = Mock() + special.to_dict.return_value = {'special': 'dict'} + result = [special] + self.assertEqual( + [{'special': 'dict'}], + wrappers.obj2dict(request=None)(result=result)) + + def test_obj2dict_other_type(self): + self.assertEqual( + 'foo', + wrappers.obj2dict(request=None)(result='foo')) + + def test_add_meta(self): + result = {'data': [{'id': 4}]} + request = DummyRequest(path='http://example.com', environ={}) + result = wrappers.add_meta(request=request)(result=result) + assert result['count'] == 1 + assert result['data'][0]['self'] == 'http://example.com/4' + + environ = {'QUERY_STRING': '_limit=100'} + request = DummyRequest(path='http://example.com?_limit=100', + environ=environ) + assert request.path == 'http://example.com?_limit=100' + result = wrappers.add_meta(request=request)(result=result) + assert result['count'] == 1 + assert result['data'][0]['self'] == 'http://example.com/4' + + @patch('nefertari.wrappers.urllib') + def test_add_meta_type_error(self, mock_lib): + mock_lib.quote.side_effect = TypeError + result = {'data': [{'id': 4}]} + request = DummyRequest(path='http://example.com', environ={}) + result = wrappers.add_meta(request=request)(result=result) + assert result['count'] == 1 + assert result['data'][0] == {'id': 4} + + def test_apply_privacy_no_data(self): + assert wrappers.apply_privacy(None)(result={}) == {} + + @patch('nefertari.wrappers.engine') + def test_apply_privacy_item_non_auth(self, mock_eng): + document_cls = Mock( + _public_fields=['name', 'desc'], + _auth_fields=['id']) + mock_eng.get_document_cls.return_value = document_cls + request = Mock(user=None) + filtered = wrappers.apply_privacy(request)(result=self.model_test_data) + assert list(sorted(filtered.keys())) == [ + '_type', 'desc', 'name', 'self'] + + @patch('nefertari.wrappers.engine') + def test_apply_privacy_item_no_request(self, mock_eng): + document_cls = Mock( + _public_fields=['name', 'desc'], + _auth_fields=['id']) + mock_eng.get_document_cls.return_value = document_cls + filtered = wrappers.apply_privacy(None)(result=self.model_test_data) + assert list(sorted(filtered.keys())) == [ + '_type', 'desc', 'id', 'name', 'other_field', 'self'] + + @patch('nefertari.wrappers.engine') + def test_apply_privacy_item_auth(self, mock_eng): + document_cls = Mock( + _public_fields=['name', 'desc'], + _auth_fields=['id']) + mock_eng.get_document_cls.return_value = document_cls + request = Mock(user=Mock()) + filtered = wrappers.apply_privacy(request)( + result=self.model_test_data, is_admin=False) + assert list(sorted(filtered.keys())) == [ + '_type', 'id', 'self'] + + @patch('nefertari.wrappers.engine') + def test_apply_privacy_item_auth_calculated(self, mock_eng): + document_cls = Mock( + _public_fields=['name', 'desc'], + _auth_fields=['id']) + mock_eng.get_document_cls.return_value = document_cls + + class User(object): + @classmethod + def is_admin(self, obj): + return False + + request = Mock(user=User()) + filtered = wrappers.apply_privacy(request)(result=self.model_test_data) + assert list(sorted(filtered.keys())) == [ + '_type', 'id', 'self'] + + @patch('nefertari.wrappers.engine') + def test_apply_privacy_item_admin(self, mock_eng): + document_cls = Mock( + _public_fields=['name', 'desc'], + _auth_fields=['id']) + mock_eng.get_document_cls.return_value = document_cls + request = Mock(user=Mock()) + filtered = wrappers.apply_privacy(request)( + result=self.model_test_data, is_admin=True) + assert list(sorted(filtered.keys())) == [ + '_type', 'desc', 'id', 'name', 'other_field', 'self'] + filtered['_type'] == 'foo' + filtered['desc'] == 'User 1 data' + filtered['id'] == 1 + filtered['name'] == 'User1' + filtered['other_field'] == 123 + filtered['self'] == 'http://example.com/1' + mock_eng.get_document_cls.assert_called_once_with('foo') + + @patch('nefertari.wrappers.engine') + def test_apply_privacy_no_type(self, mock_eng): + data = self.model_test_data.copy() + data.pop('_type') + request = Mock(user=Mock()) + filtered = wrappers.apply_privacy(request)( + result=data, is_admin=True) + assert list(sorted(filtered.keys())) == [ + 'desc', 'id', 'name', 'other_field', 'self'] + filtered['desc'] == 'User 1 data' + filtered['id'] == 1 + filtered['name'] == 'User1' + filtered['other_field'] == 123 + filtered['self'] == 'http://example.com/1' + assert not mock_eng.get_document_cls.called + + @patch('nefertari.wrappers.engine') + def test_apply_privacy_not_dict(self, mock_eng): + request = Mock(user=Mock()) + filtered = wrappers.apply_privacy(request)( + result='foo', is_admin=True) + assert filtered == 'foo' + + @patch('nefertari.wrappers.engine') + def test_apply_privacy_item_admin_calculated(self, mock_eng): + document_cls = Mock( + _public_fields=['name', 'desc'], + _auth_fields=['id']) + mock_eng.get_document_cls.return_value = document_cls + + class User(object): + @classmethod + def is_admin(self, obj): + return True + + request = Mock(user=User()) + filtered = wrappers.apply_privacy(request)(result=self.model_test_data) + assert list(sorted(filtered.keys())) == [ + '_type', 'desc', 'id', 'name', 'other_field', 'self'] + filtered['_type'] == 'foo' + filtered['desc'] == 'User 1 data' + filtered['id'] == 1 + filtered['name'] == 'User1' + filtered['other_field'] == 123 + filtered['self'] == 'http://example.com/1' + mock_eng.get_document_cls.assert_called_once_with('foo') + + @patch('nefertari.wrappers.engine') + def test_apply_privacy_item_no_document_cls(self, mock_eng): + mock_eng.get_document_cls.side_effect = ValueError + request = Mock(user=Mock()) + filtered = wrappers.apply_privacy(request)( + result=self.model_test_data, is_admin=True) + assert list(sorted(filtered.keys())) == [ + '_type', 'desc', 'id', 'name', 'other_field', 'self'] + + @patch('nefertari.wrappers.engine') + def test_apply_privacy_item_no_fields(self, mock_eng): + document_cls = Mock( + _public_fields=['name', 'desc'], + _auth_fields=[]) + mock_eng.get_document_cls.return_value = document_cls + request = Mock(user=Mock()) + filtered = wrappers.apply_privacy(request)( + result=self.model_test_data, is_admin=False) + assert list(sorted(filtered.keys())) == ['_type', 'self'] + + @patch('nefertari.wrappers.engine') + def test_apply_privacy_collection(self, mock_eng): + document_cls = Mock( + _public_fields=['name', 'desc'], + _auth_fields=['id']) + mock_eng.get_document_cls.return_value = document_cls + request = Mock(user=Mock()) + result = { + 'total': 1, + 'count': 1, + 'data': [self.model_test_data] + } + filtered = wrappers.apply_privacy(request)( + result=result, is_admin=False) + assert list(sorted(filtered.keys())) == ['count', 'data', 'total'] + assert len(filtered['data']) == 1 + data = filtered['data'][0] + assert list(sorted(data.keys())) == ['_type', 'id', 'self'] + + @patch('nefertari.wrappers.obj2dict') + def test_wrap_in_dict_no_meta_dict(self, mock_obj): + result = Mock(spec=[]) + mock_obj.return_value = lambda **kw: {'foo': 'bar'} + processed = wrappers.wrap_in_dict(123)(result=result, qoo=1) + mock_obj.assert_called_once_with(123) + assert processed == {'foo': 'bar'} + + @patch('nefertari.wrappers.obj2dict') + def test_wrap_in_dict_meta_dict(self, mock_obj): + mock_obj.return_value = lambda **kw: {'foo': 'bar'} + result = Mock(_nefertari_meta={'meta': 'metameta'}) + processed = wrappers.wrap_in_dict(123)(result=result, qoo=1) + mock_obj.assert_called_once_with(123) + assert processed == {'foo': 'bar'} + + @patch('nefertari.wrappers.obj2dict') + def test_wrap_in_dict_no_meta_list(self, mock_obj): + result = Mock(spec=[]) + mock_obj.return_value = lambda **kw: [{'foo': 'bar'}] + processed = wrappers.wrap_in_dict(123)(result=result, qoo=1) + mock_obj.assert_called_once_with(123) + assert processed == {'data': [{'foo': 'bar'}]} + + @patch('nefertari.wrappers.obj2dict') + def test_wrap_in_dict_meta_list(self, mock_obj): + mock_obj.return_value = lambda **kw: [{'foo': 'bar'}] + result = Mock() + result._nefertari_meta = {'meta': 'metameta'} + processed = wrappers.wrap_in_dict(123)(result=result, qoo=1) + mock_obj.assert_called_once_with(123) + assert processed == {'data': [{'foo': 'bar'}], 'meta': 'metameta'} + + @patch('nefertari.wrappers.engine') + def test_add_confirmation_url(self, mock_eng): + mock_eng.BaseDocument.count.return_value = 12321 + request = Mock( + url='http://example.com/api?foo=bar', + params={'foo': 'bar'}, + method='GET' + ) + result = wrappers.add_confirmation_url(request)(result=3) + assert result['method'] == 'GET' + assert result['count'] == 12321 + assert result['confirmation_url'] == ( + 'http://example.com/api?foo=bar&__confirmation&_m=GET') + + @patch('nefertari.wrappers.engine') + def test_add_confirmation_url_no_request_params(self, mock_eng): + mock_eng.BaseDocument.count.return_value = 12321 + request = Mock( + url='http://example.com/api', + params=None, + method='GET' + ) + result = wrappers.add_confirmation_url(request)(result=3) + assert result['method'] == 'GET' + assert result['count'] == 12321 + assert result['confirmation_url'] == ( + 'http://example.com/api?__confirmation&_m=GET') + + def test_add_etag_no_data(self): + wrapper = wrappers.add_etag(Mock()) + wrapper.request.response.etag = None + wrapper(result={'data': []}) + assert wrapper.request.response.etag is None + wrapper(result={}) + assert wrapper.request.response.etag is None + + def test_add_etag(self): + wrapper = wrappers.add_etag(Mock()) + wrapper.request.response.etag = None + wrapper(result={'data': [ + {'id': 1, '_version': 1}, + {'id': 2, '_version': 1}, + ]}) + expected1 = '20d135f0f28185b84a4cf7aa51f29500' + assert wrapper.request.response.etag == expected1 + + # Etag is the same when data isn't changed + wrapper(result={'data': [ + {'id': 1, '_version': 1}, + {'id': 2, '_version': 1}, + ]}) + assert isinstance(wrapper.request.response.etag, basestring) + assert wrapper.request.response.etag == expected1 + + # New object added + wrapper(result={'data': [ + {'id': 1, '_version': 1}, + {'id': 2, '_version': 1}, + {'id': 3, '_version': 1}, + ]}) + assert isinstance(wrapper.request.response.etag, basestring) + assert wrapper.request.response.etag != expected1 + + # Existing object's version changed + wrapper(result={'data': [ + {'id': 1, '_version': 1}, + {'id': 2, '_version': 2}, + ]}) + assert isinstance(wrapper.request.response.etag, basestring) + assert wrapper.request.response.etag != expected1 + + def test_set_total(self): + result = Mock(_nefertari_meta={'total': 5}) + processed = wrappers.set_total(None, 2)(result=result) + assert processed._nefertari_meta['total'] == 2 + + result = Mock(_nefertari_meta={'total': 1}) + processed = wrappers.set_total(None, 2)(result=result) + assert processed._nefertari_meta['total'] == 1 + + def test_set_total_no_meta(self): + result = Mock(spec=[]) + processed = wrappers.set_total(None, 2)(result=result) + assert not hasattr(processed, '_nefertari_meta') + + @patch('nefertari.wrappers.set_total') + def test_set_public_limits(self, mock_set): + request = Mock() + request.registry.settings = {'public_max_limit': 123} + view = Mock( + request=request, + _query_params={'_limit': 100, '_page': 1, '_start': 90}) + wrappers.set_public_limits(view) + mock_set.assert_called_once_with(view.request, total=123) + view.add_after_call.assert_called_once_with( + 'index', mock_set(), pos=0) + assert view._query_params['_limit'] == 33 + + @patch('nefertari.wrappers.set_total') + def test_set_public_limits_no_params(self, mock_set): + request = Mock() + request.registry.settings = {} + view = Mock(request=request, _query_params={}) + wrappers.set_public_limits(view) + mock_set.assert_called_once_with(view.request, total=100) + view.add_after_call.assert_called_once_with( + 'index', mock_set(), pos=0) + assert '_limit' not in view._query_params + + @patch('nefertari.wrappers.set_total') + def test_set_public_limits_value_err(self, mock_set): + from nefertari.json_httpexceptions import JHTTPBadRequest + request = Mock() + request.registry.settings = {} + view = Mock(request=request, _query_params={}) + mock_set.side_effect = ValueError + with pytest.raises(JHTTPBadRequest): + wrappers.set_public_limits(view) diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..09e6c6b --- /dev/null +++ b/tox.ini @@ -0,0 +1,15 @@ +[tox] +envlist = py27 + +[testenv] +setenv = + PYTHONHASHSEED=0 +deps = -rrequirements.dev +commands = py.test {posargs:--cov nefertari tests} + +[testenv:flake8] +deps = + flake8==2.3.0 + pep8==1.6.2 +commands = + flake8 nefertari