diff --git a/.github/test-pre-script.sh b/.github/test-pre-script.sh new file mode 100644 index 0000000..1ccf6b5 --- /dev/null +++ b/.github/test-pre-script.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. + +# Call script with sudo to install the required native libraries for installing the plugin's dependencies +sudo bash -xe "$(dirname "$0")"/../install-libs.sh \ No newline at end of file diff --git a/.github/workflows/integration_test.yaml b/.github/workflows/integration_test.yaml new file mode 100644 index 0000000..c9d539f --- /dev/null +++ b/.github/workflows/integration_test.yaml @@ -0,0 +1,18 @@ +name: Integration tests + +on: + pull_request: + +jobs: + integration-tests: + runs-on: ubuntu-latest + name: Integration Tests + steps: + - uses: actions/checkout@v3 + - name: Install required native libraries + run: sudo bash -xe install-libs.sh + - name: Install tox + run: python3 -m pip install tox + - name: Run integration tests + run: | + tox -e integration diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 0000000..85e6b9e --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,11 @@ +name: Tests + +on: + pull_request: + +jobs: + unit-tests: + uses: canonical/operator-workflows/.github/workflows/test.yaml@main + secrets: inherit + with: + pre-run-script: .github/test-pre-script.sh diff --git a/README.md b/README.md new file mode 100644 index 0000000..af449cf --- /dev/null +++ b/README.md @@ -0,0 +1,138 @@ +# Flask-Multipass-SAML-Groups + +This package provides an identity provider for [Flask-Multipass](https://github.com/indico/flask-multipass), +which allows you to use SAML groups. It is designed to be used +as a plugin for [Indico](https://github.com/indico/indico). + +> **Warning** +> The current code base has not been extensively tested and should be considered experimental. + + +## Motivation + +The current SAML identity provider in Flask-Multipass does not support groups (see [issue](https://github.com/indico/flask-multipass/issues/66)), +but groups are a very useful feature for Indico. This plugin provides a solution to this problem. + + +## Installation + +### Package installation +You need to install the package on the same virtual environment as your Indico instance. +You might use the following commands to switch to the Indico environment + +```bash +su - indico +source ~/.venv/bin/activate +``` + +Some of the dependencies, like [xmlsec](https://xmlsec.readthedocs.io/en/stable/install.html), +require native libraries to be installed on the system. To install these libraries on an +Ubuntu system, you can use the `install-packages.sh` file: + +```bash +sudo bash install-libs.sh +``` + +You can then install this package either via local source: + +```bash +git clone https://github.com/canonical/flask-multipass-saml-groups.git +cd flask-multipass-saml-groups +python setup.py install +``` + +or with pip: + +```bash +pip install git+https://github.com/canonical/flask-multipass-saml-groups.git +``` + + +### Indico setup + +In your Indico setup, you should see that the plugin is now available: + +```bash +indico setup list-plugins +``` + +In order to activate the plugin, you must add it to the list of active plugins in your Indico configuration file: + +```python +PLUGINS = { ..., 'saml_groups' } +``` + +Beyond that, the plugin uses its own database tables to persist the groups. Therefore you need to run + +```bash +indico db --all-plugins upgrade +``` +See [here](https://docs.getindico.io/en/latest/installation/plugins/) for more information on installing +Indico plugins. + + +### Identity provider configuration +The configuration is almost identical to the SAML identity provider in Flask-Multipass, +but you should use the type `saml_groups` instead of `saml`. The identity provider must be used +together with the SAML auth Provider, in order to receive the SAML groups in the authentication +data. + +The following is an example section in `indico.conf`: +```python + +_my_saml_config = { + 'sp': { + 'entityId': 'https://events.example.com', + 'x509cert': '', + 'privateKey': '', + }, + 'idp': { + 'entityId': 'https://login.example.com', + 'x509cert': 'YmFzZTY0IGVuY29kZWQgY2VydAo', + 'singleSignOnService': { + 'url': 'https://login.example.com/saml/', + 'binding': 'urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect' + }, + 'singleLogoutService': { + 'url': 'https://login.example.com/+logout', + 'binding': 'urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect' + } + }, + 'security': { + 'nameIdEncrypted': False, + 'authnRequestsSigned': False, + 'logoutRequestSigned': False, + 'logoutResponseSigned': False, + 'signMetadata': False, + 'wantMessagesSigned': False, + 'wantAssertionsSigned': False, + 'wantNameId' : False, + 'wantNameIdEncrypted': False, + 'wantAssertionsEncrypted': False, + 'allowSingleLabelDomains': False, + 'signatureAlgorithm': 'http://www.w3.org/2001/04/xmldsig-more#rsa-sha256', + 'digestAlgorithm': 'http://www.w3.org/2001/04/xmlenc#sha256' + }, +} + +MULTIPASS_AUTH_PROVIDERS = { + 'ubuntu': { + 'type': 'saml', + 'title': 'SAML SSO', + 'saml_config': _my_saml_config, + }, +} +IDENTITY_PROVIDERS = { +"ubuntu": { + "type": "saml_groups", + "trusted_email": True, + "mapping": { + "user_name": "username", + "first_name": "fullname", + "last_name": "", + "email": "email", + }, + "identifier_field": "openid", + } +} +``` \ No newline at end of file diff --git a/flask_multipass_saml_groups/__init__.py b/flask_multipass_saml_groups/__init__.py new file mode 100644 index 0000000..2c8134c --- /dev/null +++ b/flask_multipass_saml_groups/__init__.py @@ -0,0 +1,4 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. + +"""The package containing the SAML Groups plugin.""" diff --git a/flask_multipass_saml_groups/group_provider/__init__.py b/flask_multipass_saml_groups/group_provider/__init__.py new file mode 100644 index 0000000..3e8a8e0 --- /dev/null +++ b/flask_multipass_saml_groups/group_provider/__init__.py @@ -0,0 +1,4 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. + +"""The package containing the group providers for the SAML Groups plugin.""" diff --git a/flask_multipass_saml_groups/group_provider/base.py b/flask_multipass_saml_groups/group_provider/base.py new file mode 100644 index 0000000..540519d --- /dev/null +++ b/flask_multipass_saml_groups/group_provider/base.py @@ -0,0 +1,85 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. +"""Defines the interface for a group provider.""" + +from abc import ABCMeta, abstractmethod +from typing import Iterable, Optional + +from flask_multipass import Group, IdentityProvider + + +class GroupProvider(metaclass=ABCMeta): + """A group provider is responsible for managing groups and their members. + + Attrs: + group_class (type): The class to use for groups. + """ + + group_class = Group + + def __init__(self, identity_provider: IdentityProvider): + """Initialize the group provider. + + Args: + identity_provider: The associated identity provider. Usually required because the group + needs to know the identity provider. + """ + + @abstractmethod + def add_group(self, name: str) -> None: # pragma: no cover + """Add a group. + + Args: + name: The name of the group. + """ + + @abstractmethod + def get_group(self, name: str) -> Optional[Group]: # pragma: no cover + """Get a group. + + Args: + name: The name of the group. + + Returns: + The group or None if it does not exist. + """ + return None + + @abstractmethod + def get_groups(self) -> Iterable[Group]: # pragma: no cover + """Get all groups. + + Returns: + An iterable of all groups. + """ + return [] + + @abstractmethod + def get_user_groups(self, identifier: str) -> Iterable[Group]: # pragma: no cover + """Get all groups a user is a member of. + + Args: + identifier: The unique user identifier used by the provider. + + Returns: + iterable: An iterable of groups the user is a member of. + """ + return [] + + @abstractmethod + def add_group_member(self, identifier: str, group_name: str) -> None: # pragma: no cover + """Add a user to a group. + + Args: + identifier: The unique user identifier used by the provider. + group_name: The name of the group. + """ + + @abstractmethod + def remove_group_member(self, identifier: str, group_name: str) -> None: # pragma: no cover + """Remove a user from a group. + + Args: + identifier: The unique user identifier used by the provider. + group_name: The name of the group. + """ diff --git a/flask_multipass_saml_groups/group_provider/sql.py b/flask_multipass_saml_groups/group_provider/sql.py new file mode 100644 index 0000000..0f12672 --- /dev/null +++ b/flask_multipass_saml_groups/group_provider/sql.py @@ -0,0 +1,177 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. + +"""A group provider that persists groups and their members in a SQL database provided by Indico.""" + +from typing import Iterable, Iterator, Optional + +from flask_multipass import Group, IdentityInfo, IdentityProvider +from indico.core.db import db + +from flask_multipass_saml_groups.group_provider.base import GroupProvider +from flask_multipass_saml_groups.models.saml_groups import SAMLGroup as DBGroup +from flask_multipass_saml_groups.models.saml_groups import SAMLUser + + +class SQLGroup(Group): + """A group whose group membership is persisted in a SQL database. + + Attrs: + supports_member_list (bool): If the group supports getting the list of members + """ + + supports_member_list = True + + def __init__(self, provider: IdentityProvider, name: str): + """Initialize the group. + + Args: + provider: The associated identity provider. + name: The unique, case-sensitive name of this group. + """ + super().__init__(provider, name) + self._provider = provider + self._name = name + + def get_members(self) -> Iterator[IdentityInfo]: + """Return the members of the group. + + Returns: + An iterator over IdentityInfo objects. + """ + db_group = DBGroup.query.filter_by(name=self._name).first() + if db_group: + return iter( + map( + lambda m: IdentityInfo(provider=self._provider, identifier=m.identifier), + db_group.members, + ) + ) + return iter([]) + + def has_member(self, identifier: str) -> bool: + """Check if a given identity is a member of the group. + + Args: + identifier: The unique user identifier used by the provider. + + Returns: + True if the user is a member of the group, False otherwise. + """ + return ( + DBGroup.query.filter_by(name=self._name) + .join(DBGroup.members) + .filter_by(identifier=identifier) + .first() + is not None + ) + + +class SQLGroupProvider(GroupProvider): + """Provide access to Groups persisted with a SQL database. + + Attrs: + group_class (class): The class to use for groups. + """ + + # pylint does not recognize the methods of db.session, which is a proxy object + # pylint: disable=no-member + + group_class = SQLGroup + + def __init__(self, identity_provider: IdentityProvider): + """Initialize the group provider. + + Args: + identity_provider: The identity provider this group provider is associated with. + """ + super().__init__(identity_provider) + self._identity_provider = identity_provider + + def add_group(self, name: str) -> None: + """Add a group. + + Args: + name: The name of the group. + """ + grp = DBGroup.query.filter_by(name=name).first() + if not grp: + db.session.add(DBGroup(name=name)) + db.session.commit() + + def get_group(self, name: str) -> Optional[SQLGroup]: + """Get a group. + + Args: + name: The name of the group. + + Returns: + The group or None if it does not exist. + """ + grp = DBGroup.query.filter_by(name=name).first() + if grp: + return SQLGroup(provider=self._identity_provider, name=grp.name) + return None + + def get_groups(self) -> Iterable[SQLGroup]: + """Get all groups. + + Returns: + An iterable of all groups. + """ + return map( + lambda g: SQLGroup(provider=self._identity_provider, name=g.name), + DBGroup.query.all(), + ) + + def get_user_groups(self, identifier: str) -> Iterable[SQLGroup]: + """Get all groups a user is a member of. + + Args: + identifier: The unique user identifier used by the provider. + + Returns: + iterable: An iterable of groups the user is a member of. + """ + user = SAMLUser.query.filter_by(identifier=identifier).first() + if user: + return map( + lambda g: SQLGroup(name=g.name, provider=self._identity_provider), + user.groups, + ) + return [] + + def add_group_member(self, identifier: str, group_name: str) -> None: + """Add a user to a group. + + Args: + identifier: The unique user identifier used by the provider. + group_name: The name of the group. + """ + user = SAMLUser.query.filter_by(identifier=identifier).first() + grp = DBGroup.query.filter_by(name=group_name).first() + + if not user: + user = SAMLUser(identifier=identifier) + db.session.add(user) + if not grp: + grp = DBGroup(name=group_name) + db.session.add(grp) + + if user not in grp.members: + grp.members.append(user) + db.session.commit() + + def remove_group_member(self, identifier: str, group_name: str) -> None: + """Remove a user from a group. + + Args: + identifier: The unique user identifier used by the provider. + group_name: The name of the group. + """ + user = SAMLUser.query.filter_by(identifier=identifier).first() + grp = DBGroup.query.filter_by(name=group_name).first() + + if grp and user in grp.members: + grp.members.remove(user) + db.session.commit() diff --git a/flask_multipass_saml_groups/migrations/20230725_1640_ae387f5fc14a_initial_migration.py b/flask_multipass_saml_groups/migrations/20230725_1640_ae387f5fc14a_initial_migration.py new file mode 100644 index 0000000..6b07a15 --- /dev/null +++ b/flask_multipass_saml_groups/migrations/20230725_1640_ae387f5fc14a_initial_migration.py @@ -0,0 +1,90 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. + +# noqa disable qa, because file is autogenerated +# flake8: noqa +# type: ignore + +"""initial migration + +Revision ID: ae387f5fc14a +Revises: +Create Date: 2023-07-25 16:40:17.110259 +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.sql.ddl import CreateSchema, DropSchema + +# revision identifiers, used by Alembic. +revision = "ae387f5fc14a" +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade(): # noqa + op.execute(CreateSchema("plugin_saml_groups")) + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "saml_groups", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.PrimaryKeyConstraint("id"), + schema="plugin_saml_groups", + ) + with op.batch_alter_table("saml_groups", schema="plugin_saml_groups") as batch_op: + batch_op.create_index(batch_op.f("ix_saml_groups_name"), ["name"], unique=True) + + op.create_table( + "saml_users", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("identifier", sa.String(), nullable=False), + sa.PrimaryKeyConstraint("id"), + schema="plugin_saml_groups", + ) + with op.batch_alter_table("saml_users", schema="plugin_saml_groups") as batch_op: + batch_op.create_index(batch_op.f("ix_saml_users_identifier"), ["identifier"], unique=True) + + op.create_table( + "saml_group_members", + sa.Column("group_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["group_id"], + ["plugin_saml_groups.saml_groups.id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["plugin_saml_groups.saml_users.id"], + ), + sa.PrimaryKeyConstraint("group_id", "user_id"), + schema="plugin_saml_groups", + ) + with op.batch_alter_table("saml_group_members", schema="plugin_saml_groups") as batch_op: + batch_op.create_index( + batch_op.f("ix_saml_group_members_group_id"), ["group_id"], unique=False + ) + batch_op.create_index( + batch_op.f("ix_saml_group_members_user_id"), ["user_id"], unique=False + ) + # ### end Alembic commands ### + + +def downgrade(): # noqa + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("saml_group_members", schema="plugin_saml_groups") as batch_op: + batch_op.drop_index(batch_op.f("ix_saml_group_members_user_id")) + batch_op.drop_index(batch_op.f("ix_saml_group_members_group_id")) + + op.drop_table("saml_group_members", schema="plugin_saml_groups") + with op.batch_alter_table("saml_users", schema="plugin_saml_groups") as batch_op: + batch_op.drop_index(batch_op.f("ix_saml_users_identifier")) + + op.drop_table("saml_users", schema="plugin_saml_groups") + with op.batch_alter_table("saml_groups", schema="plugin_saml_groups") as batch_op: + batch_op.drop_index(batch_op.f("ix_saml_groups_name")) + + op.drop_table("saml_groups", schema="plugin_saml_groups") + # ### end Alembic commands ### + op.execute(DropSchema("plugin_saml_groups")) diff --git a/flask_multipass_saml_groups/models/__init__.py b/flask_multipass_saml_groups/models/__init__.py new file mode 100644 index 0000000..c9a5ee1 --- /dev/null +++ b/flask_multipass_saml_groups/models/__init__.py @@ -0,0 +1,4 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. + +"""The package containing the SQLAlchemy models for the SAML Groups plugin.""" diff --git a/flask_multipass_saml_groups/models/saml_groups.py b/flask_multipass_saml_groups/models/saml_groups.py new file mode 100644 index 0000000..5b135ac --- /dev/null +++ b/flask_multipass_saml_groups/models/saml_groups.py @@ -0,0 +1,75 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. + +"""The database models for the SAML Groups plugin.""" + +from typing import List + +from indico.core.db import db +from sqlalchemy.orm import Mapped + +SCHEMA = "plugin_saml_groups" +group_members_table = db.Table( + "saml_group_members", + db.metadata, + db.Column( + "group_id", + db.Integer, + db.ForeignKey(f"{SCHEMA}.saml_groups.id"), + primary_key=True, + nullable=False, + index=True, + ), + db.Column( + "user_id", + db.Integer, + db.ForeignKey(f"{SCHEMA}.saml_users.id"), + primary_key=True, + nullable=False, + index=True, + ), + schema=SCHEMA, +) + + +class SAMLGroup(db.Model): # pylint: disable=too-few-public-methods + """The model containing the groups. + + Attrs: + id: The group's ID + name: The group's name + """ + + __tablename__ = "saml_groups" + __table_args__ = {"schema": SCHEMA} + + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String, nullable=False, unique=True, index=True) + + +class SAMLUser(db.Model): # pylint: disable=too-few-public-methods + """The model containing the user identifiers. + + Attrs: + id: The user's ID in the database + identifier: The user's identifier from the identity provider + groups: The groups the user is a member of + """ + + __tablename__ = "saml_users" + __table_args__ = {"schema": SCHEMA} + + id = db.Column(db.Integer, primary_key=True) + identifier = db.Column(db.String, nullable=False, unique=True, index=True) + groups: Mapped[List[SAMLGroup]] = db.relationship( + SAMLGroup, + secondary=group_members_table, + back_populates="members", + ) + + +SAMLGroup.members = db.relationship( + SAMLUser, + secondary=group_members_table, + back_populates="groups", +) diff --git a/flask_multipass_saml_groups/plugin.py b/flask_multipass_saml_groups/plugin.py new file mode 100644 index 0000000..cad2a22 --- /dev/null +++ b/flask_multipass_saml_groups/plugin.py @@ -0,0 +1,12 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. +"""Marks the package in order to be used by the Indico plugin system.""" + +from indico.core.plugins import IndicoPlugin + + +class SAMLGroupsPlugin(IndicoPlugin): # pylint: disable=too-few-public-methods + """SAML Groups Plugin. + + The plugin provides an identity provider for SAML which supports groups. + """ diff --git a/flask_multipass_saml_groups/provider.py b/flask_multipass_saml_groups/provider.py new file mode 100644 index 0000000..3c604aa --- /dev/null +++ b/flask_multipass_saml_groups/provider.py @@ -0,0 +1,147 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. +# +"""SAML Groups Identity Provider.""" +import operator +from typing import Dict, Iterable, Optional, Type + +from flask_multipass import ( + AuthInfo, + Group, + IdentityInfo, + IdentityProvider, + IdentityRetrievalFailed, + Multipass, +) + +from flask_multipass_saml_groups.group_provider.base import GroupProvider +from flask_multipass_saml_groups.group_provider.sql import SQLGroupProvider + +DEFAULT_IDENTIFIER_FIELD = "_saml_nameid_qualified" + +SAML_GRP_ATTR_NAME = "urn:oasis:names:tc:SAML:2.0:profiles:attribute:DCE:groups" + + +class SAMLGroupsIdentityProvider(IdentityProvider): + """Provides identity information using SAML and supports groups. + + Attrs: + supports_get (bool): If the provider supports getting identity information + based from an identifier + supports_groups (bool): If the provider also provides groups and membership information + supports_get_identity_groups (bool): If the provider supports getting the list of groups an + identity belongs to + group_class (class): The class to use for groups. Defaults to flask_multipass.Group but + concrete class will be used from group_provider_class + """ + + supports_get = False + supports_groups = True + supports_get_identity_groups = True + + group_class = Group + + def __init__( + self, + multipass: Multipass, + name: str, + settings: Dict, + group_provider_class: Type[GroupProvider] = SQLGroupProvider, + ): + """Initialize the identity provider. + + Args: + multipass: The Flask-Multipass instance + name: The name of this identity provider instance + settings: The settings dictionary for this identity + provider instance + group_provider_class: The class to use for the group provider. + """ + super().__init__(multipass=multipass, name=name, settings=settings) + self.id_field = self.settings.setdefault("identifier_field", DEFAULT_IDENTIFIER_FIELD) + self._group_provider = group_provider_class(identity_provider=self) + self.group_class = self._group_provider.group_class + + def get_identity_from_auth(self, auth_info: AuthInfo) -> IdentityInfo: + """Retrieve identity information after authentication. + + Args: + auth_info: An AuthInfo instance from an auth provider. + + Raise: + IdentityRetrievalFailed: If the identifier is missing or there do exist multiple + in the saml response. + + Returns: + IdentityInfo: An IdentityInfo instance containing identity information + or None if no identity was found. + + """ + identifier = auth_info.data.get(self.id_field) + if isinstance(identifier, list): + if len(identifier) != 1: + raise IdentityRetrievalFailed("Identifier has multiple elements", provider=self) + identifier = identifier[0] + if not identifier: + raise IdentityRetrievalFailed("Identifier missing in saml response", provider=self) + + identity_info = IdentityInfo(self, identifier=identifier, **auth_info.data) + + grp_names = auth_info.data.get(SAML_GRP_ATTR_NAME) + + if grp_names: + if isinstance(grp_names, str): + # If only one group is returned, it is returned as a string by saml auth provider + grp_names = [grp_names] + + user_groups = self._group_provider.get_user_groups(identifier=identifier) + for group in user_groups: + if group.name not in grp_names: + self._group_provider.remove_group_member( + group_name=group.name, identifier=identifier + ) + + for grp_name in grp_names: + self._group_provider.add_group_member(group_name=grp_name, identifier=identifier) + + return identity_info + + def get_group(self, name: str) -> Optional[Group]: + """Return a specific group. + + Args: + name: The name of the group. + + Returns: + group: An instance of group_class or None if the group does not exist. + """ + return self._group_provider.get_group(name) + + def search_groups(self, name: str, exact: bool = False) -> Iterable[Group]: + """Search groups by name. + + Args: + name: The name to search for. + exact (bool, optional): If True, the name needs to match exactly, + i.e., no substring matches are performed. + + Yields: + a matching group_class object. + + """ + compare = operator.eq if exact else operator.contains + for group in self._group_provider.get_groups(): + if compare(group.name, name): + yield group + + def get_identity_groups(self, identifier: str) -> Iterable[Group]: + """Retrieve the groups a user identity belongs to. + + Args: + identifier: The unique user identifier used by the + provider. + + Returns: + iterable: An iterable of groups + """ + return self._group_provider.get_user_groups(identifier=identifier) diff --git a/install-libs.sh b/install-libs.sh new file mode 100644 index 0000000..c8b6b3b --- /dev/null +++ b/install-libs.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. + +# Script that installs the required native libraries for installing the plugin's dependencies +# on Ubuntu 22.04. + +if [ "$EUID" -ne 0 ] + then echo "Please run as root" + exit +fi + +apt-get update +apt-get install -y pkg-config libxml2-dev libxmlsec1-dev libxmlsec1-openssl libpython3-dev gcc libpq-dev diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..25accef --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,61 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. + +[tool.bandit] +exclude_dirs = ["/venv/"] +[tool.bandit.assert_used] +skips = ["*/*test.py", "*/test_*.py", "*tests/*.py"] + +# Testing tools configuration +[tool.coverage.run] +branch = true + +# Formatting tools configuration +[tool.black] +line-length = 99 +target-version = ["py38"] + +[tool.coverage.report] +fail_under = 95 +show_missing = true + +# Linting tools configuration +[tool.flake8] +max-line-length = 99 +max-doc-length = 99 +max-complexity = 10 +exclude = [".git", "__pycache__", ".tox", "build", "dist", "*.egg_info", "venv"] +select = ["E", "W", "F", "C", "N", "R", "D", "H"] +# Ignore W503, E501 because using black creates errors with this +# Ignore D107 Missing docstring in __init__ +ignore = ["W503", "E501", "D107"] +# D100, D101, D102, D103: Ignore missing docstrings in tests +per-file-ignores = ["tests/*:D100,D101,D102,D103,D104,D205,D212,D415"] +docstring-convention = "google" +# Check for properly formatted copyright header in each file +copyright-check = "True" +copyright-author = "Canonical Ltd." +copyright-regexp = "Copyright\\s\\d{4}([-,]\\d{4})*\\s+%(author)s" + +[tool.isort] +line_length = 99 +profile = "black" + +[tool.mypy] +check_untyped_defs = true +disallow_untyped_defs = true +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "tests.*" +disallow_untyped_defs = false + +[tool.pytest.ini_options] +minversion = "6.0" +log_cli_level = "INFO" +markers = [ + "requires_secrets: mark tests that require external secrets" +] + +[tool.pylint] +disable = "wrong-import-order" diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..b987764 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +addopts = -p no:indico diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..0012b68 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +flask-multipass[saml]>=0.4.* +indico>=3.2 +flask_sqlalchemy \ No newline at end of file diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..7e9ef6f --- /dev/null +++ b/setup.cfg @@ -0,0 +1,24 @@ +[metadata] +name = Flask-Multipass-SAML-Groups +version = 0.0.1 +license = Apache-2.0 +author = launchpad.net/~canonical-is-devops +author_email = is-devops-team@canonical.com + +[options] +packages = find: +include_package_data = true +python_requires = ~=3.8 +install_requires = + flask-multipass[saml]>=0.4.3 + +[options.packages.find] +include = + flask_multipass_saml_groups + flask_multipass_saml_groups.* + +[options.entry_points] +flask_multipass.identity_providers = + saml_groups = flask_multipass_saml_groups.provider:SAMLGroupsIdentityProvider +indico.plugins = + saml_groups = flask_multipass_saml_groups.plugin:SAMLGroupsPlugin diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..b13e707 --- /dev/null +++ b/setup.py @@ -0,0 +1,6 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. + +from setuptools import setup + +setup() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..cb88756 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. diff --git a/tests/common.py b/tests/common.py new file mode 100644 index 0000000..cbafe91 --- /dev/null +++ b/tests/common.py @@ -0,0 +1,33 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. +"""Add common functions for testing.""" + +from flask import Flask +from indico.core.db import db + + +def setup_sqlite(app: Flask): + """Add sqlite to app config and setup the database. + + Args: + app: The flask app. + """ + app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite://" + with app.app_context(): + db.init_app(app) + # pylint does not recognize the methods of db.session, which is a proxy object + # pylint: disable=no-member + db.session.execute("attach ':memory:' as plugin_saml_groups;") + db.session.execute( + "CREATE TABLE plugin_saml_groups.saml_users " + "(id INTEGER PRIMARY KEY, identifier TEXT UNIQUE);" + ) + db.session.execute( + "CREATE TABLE plugin_saml_groups.saml_groups " + "(id INTEGER PRIMARY KEY, name TEXT UNIQUE);" + ) + db.session.execute( + "CREATE TABLE plugin_saml_groups.saml_group_members " + "(group_id INTEGER, user_id INTEGER);" + ) + db.session.commit() diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..cb88756 --- /dev/null +++ b/tests/integration/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. diff --git a/tests/integration/common.py b/tests/integration/common.py new file mode 100644 index 0000000..c5828e3 --- /dev/null +++ b/tests/integration/common.py @@ -0,0 +1,119 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. +"""Common functions for integration tests.""" + +import base64 +from collections import namedtuple +from typing import List + +from flask.testing import FlaskClient + +from flask_multipass_saml_groups.provider import SAML_GRP_ATTR_NAME + +SP_ENTITY_ID = "http://localhost" + +User = namedtuple("User", ["email", "identifier"]) + + +def mk_saml_response(groups: List[str], user_email: str): + """Create a SAML response with the given groups and user email. + + Args: + groups: The groups to include in the response + user_email: The user email to include in the response + + Returns: + The SAML response + """ + if groups: + group_attr = f""" + + {"".join(f'{grp}' for grp in groups)} + """ + else: + group_attr = "" + return f""" + + + https://login.saml.com + + + + + https://login.saml.com + + {user_email} + + + + + + + {SP_ENTITY_ID} + + + + + urn:oasis:names:tc:SAML:2.0:ac:classes:Password + + + + + http://openid + + + {user_email} + + + foo + + + {user_email} + + + foo bar + + {group_attr} + + +""" + + +def login(client: FlaskClient, user_email: str, groups: List[str]): + """Login a user with the given email and groups. + + Args: + client: The test client + user_email: The user email + groups: The groups to include in the SAML response + """ + resp = client.get("/login/ubuntu") + assert resp.status_code == 302 + saml_response = mk_saml_response(groups=groups, user_email=user_email) + resp = client.post( + "/multipass/saml/ubuntu/acs", + data={ + "SAMLResponse": base64.b64encode(saml_response.encode("utf-8")), + "RelayState": "/login/ubuntu", + }, + ) + assert resp.status_code == 302 diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 0000000..925baca --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,114 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. + +"""Common fixtures for integration tests.""" +from secrets import token_hex + +import onelogin +import pytest +from flask import Flask +from flask_multipass import Multipass + +from flask_multipass_saml_groups.provider import SAMLGroupsIdentityProvider +from tests.common import setup_sqlite +from tests.integration.common import SP_ENTITY_ID, User + + +@pytest.fixture(name="config") +def config_fixture(): + """Return a config dict for the flask multipass plugin.""" + saml_config = { + "sp": { + "entityId": SP_ENTITY_ID, + "x509cert": "", + "privateKey": "", + }, + "idp": { + "entityId": "https://login.saml.com", + "x509cert": "dGVzdAo=", + "singleSignOnService": { + "url": "https://login.saml.com/saml/", + "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect", + }, + "singleLogoutService": { + "url": "https://login.saml.com/+logout", + "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect", + }, + }, + "security": { + "nameIdEncrypted": False, + "authnRequestsSigned": False, + "logoutRequestSigned": False, + "logoutResponseSigned": False, + "signMetadata": False, + "wantMessagesSigned": False, + "wantAssertionsSigned": False, + "wantNameId": False, + "wantNameIdEncrypted": False, + "wantAssertionsEncrypted": False, + "allowSingleLabelDomains": False, + "signatureAlgorithm": "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256", + "digestAlgorithm": "http://www.w3.org/2001/04/xmlenc#sha256", + }, + } + multipass_auth_providers = { + "ubuntu": { + "type": "saml", + "title": "SAML SSO", + "saml_config": saml_config, + }, + } + multipass_identity_providers = { + "ubuntu": { + "type": "saml_groups", + "title": "SAML", + "mapping": { + "name": "DisplayName", + "email": "EmailAddress", + "affiliation": "HomeInstitute", + }, + } + } + multipass_provider_map = { + "ubuntu": "ubuntu", + } + return { + "MULTIPASS_AUTH_PROVIDERS": multipass_auth_providers, + "MULTIPASS_IDENTITY_PROVIDERS": multipass_identity_providers, + "MULTIPASS_PROVIDER_MAP": multipass_provider_map, + } + + +@pytest.fixture(name="app") +def app_fixture(config): + """Return a properly setup flask app.""" + app = Flask("test") + setup_sqlite(app) + app.config.update(config) + app.debug = True + app.secret_key = "fma-example" # nosec + app.add_url_rule("/", "index", lambda: "") + + return app + + +@pytest.fixture(name="multipass") +def multipass_fixture(app, monkeypatch): + """Return a properly set up flask multipass instance.""" + multipass = Multipass() + multipass.register_provider(SAMLGroupsIdentityProvider, "saml_groups") + multipass.init_app(app) + multipass.identity_handler(lambda identity: None) + monkeypatch.setattr( + onelogin.saml2.response.OneLogin_Saml2_Response, "is_valid", lambda *args, **kwargs: True + ) # disable signature validation of SAML response + + return multipass + + +@pytest.fixture(name="user") +def user_fixture(): + """Return a user email and identifier.""" + user_email = token_hex(16) + user_identifier = f"{user_email}@{SP_ENTITY_ID}" + return User(email=user_email, identifier=user_identifier) diff --git a/tests/integration/test_login.py b/tests/integration/test_login.py new file mode 100644 index 0000000..9e937fa --- /dev/null +++ b/tests/integration/test_login.py @@ -0,0 +1,98 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. + +"""Integration tests which check if the groups are properly handled when a user logins.""" +from secrets import token_hex +from typing import List + +from flask import Flask +from flask_multipass import Multipass + +from tests.integration.common import login + + +def test_login_extracts_groups_from_saml_attributes(app, multipass, user): + """ + arrange: given an app + act: call login with a user that has groups in the SAML attributes + assert: the user is logged in and the groups are assigned to the user + """ + client = app.test_client() + + grp_names = [token_hex(16), token_hex(6)] + login(client, groups=grp_names, user_email=user.email) + + _assert_user_only_in_groups(grp_names, app, multipass, user.identifier) + + +def test_relogin_removes_previous_groups(app, multipass, user): + """ + arrange: given an app and a user with groups + act: call login with a user and recall login with the same user but differing groups + assert: the user is logged in and the invalid group memberships from first login are removed + """ + client = app.test_client() + + grp_names = [token_hex(16), token_hex(6)] + other_grp_names = [token_hex(16), token_hex(6)] + + login(client, groups=grp_names, user_email=user.email) + login(client, groups=other_grp_names, user_email=user.email) + + _assert_user_only_in_groups(other_grp_names, app, multipass, user.identifier) + + +def test_login_with_no_groups(app, multipass, user): + """ + arrange: given an app + act: call login with a user that has no groups in the SAML attributes + assert: the user is logged in and no groups are assigned to the user + """ + client = app.test_client() + + login(client, groups=[], user_email=user.email) + + _assert_user_only_in_groups([], app, multipass, user.identifier) + + +def test_login_with_multiple_identical_groups(app, multipass, user): + """ + arrange: given an app + act: call login with a user that has duplicate group names in the SAML attributes + assert: the user is logged in and the duplicate group names are not counted + """ + client = app.test_client() + + grp_name = token_hex(16) + + login(client, groups=[grp_name, grp_name], user_email=user.email) + + with app.app_context(): + idp = multipass.identity_providers["ubuntu"] + + grps = list(idp.get_identity_groups(user.identifier)) + assert len(grps) == 1 + + +def _assert_user_only_in_groups( + groups: List[str], app: Flask, multipass: Multipass, user_identifier: str +): + """Assert that the user is only in the given groups. + + Args: + groups: The groups the user should be in + app: The app + multipass: The multipass instance + user_identifier: The identifier of the user which is expected to belong to the groups. + + """ + with app.app_context(): + idp = multipass.identity_providers["ubuntu"] + + grps = list(idp.get_identity_groups(user_identifier)) + assert len(grps) == len(groups) + + for grp in grps: + assert grp.name in groups + assert isinstance(grp, idp.group_class) + assert grp.has_member(user_identifier) diff --git a/tests/integration/test_provider.py b/tests/integration/test_provider.py new file mode 100644 index 0000000..68d8df3 --- /dev/null +++ b/tests/integration/test_provider.py @@ -0,0 +1,91 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. + +"""Integration tests which check if the provider works as expected.""" +from secrets import token_hex +from typing import List + +from flask_multipass import Group, IdentityInfo + +from tests.integration.common import login + + +def test_get_group(app, multipass, user): + """ + arrange: given a logged-in user which is a member of a group + act: call get_group on identity provider + assert: the group is returned and group methods work as expected + """ + client = app.test_client() + + grp_name = token_hex(16) + login(client, groups=[grp_name], user_email=user.email) + + with app.app_context(): + idp = multipass.identity_providers["ubuntu"] + _assert_group_methods_work([idp.get_group(grp_name)], user.identifier) + + +def test_get_identity_groups(app, multipass, user): + """ + arrange: given a logged-in user which is a member of multiple group + act: call get_identity_groups on identity provider + assert: the groups of the user are returned and group methods work as expected + """ + client = app.test_client() + + grp_names = [token_hex(16), token_hex(6)] + login(client, groups=grp_names, user_email=user.email) + + with app.app_context(): + idp = multipass.identity_providers["ubuntu"] + groups = list(idp.get_identity_groups(user.identifier)) + + _assert_group_methods_work(groups, user.identifier) + _assert_group_names(groups, grp_names) + + +def test_search_groups(app, multipass, user): + """ + arrange: given a logged-in user which is a member of a group + act: call search_groups on identity provider + assert: only the matched groups of the user are returned and group methods work as expected + """ + client = app.test_client() + login(client, groups=["x", "xy", "z"], user_email=user.email) + + with app.app_context(): + idp = multipass.identity_providers["ubuntu"] + groups = list(idp.search_groups("x")) + + _assert_group_methods_work(groups, user.identifier) + _assert_group_names(groups, ["x", "xy"]) + + +def _assert_group_methods_work(groups: List[Group], user_identifier: str): + """Assert that all the group methods work as expected. + + Args: + groups: The groups to check. + user_identifier: The identifier of the user which is expected to belong to the groups. + """ + for grp in groups: + members = list(grp.get_members()) + assert len(members) == 1 + member = members[0] + assert isinstance(member, IdentityInfo) + assert member.identifier == user_identifier + + assert grp.has_member(user_identifier) + + +def _assert_group_names(groups: List[Group], expected_names: List[str]): + """Assert that the group names are as expected. + + Args: + groups: The groups to check. + expected_names: The expected group names. + """ + assert len(groups) == len(expected_names) + for grp in groups: + assert grp.name in expected_names diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..cb88756 --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. diff --git a/tests/unit/group_provider/__init__.py b/tests/unit/group_provider/__init__.py new file mode 100644 index 0000000..cb88756 --- /dev/null +++ b/tests/unit/group_provider/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. diff --git a/tests/unit/group_provider/sql/__init__.py b/tests/unit/group_provider/sql/__init__.py new file mode 100644 index 0000000..cb88756 --- /dev/null +++ b/tests/unit/group_provider/sql/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. diff --git a/tests/unit/group_provider/sql/conftest.py b/tests/unit/group_provider/sql/conftest.py new file mode 100644 index 0000000..6f65961 --- /dev/null +++ b/tests/unit/group_provider/sql/conftest.py @@ -0,0 +1,16 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. +"""Common functions for testing the sql group provider.""" + +import pytest +from flask import Flask + +from tests.common import setup_sqlite + + +@pytest.fixture(name="app") +def app_fixture(): + """Create a flask app with a properly setup sqlite db.""" + app = Flask("test") + setup_sqlite(app) + return app diff --git a/tests/unit/group_provider/sql/test_group.py b/tests/unit/group_provider/sql/test_group.py new file mode 100644 index 0000000..6f35255 --- /dev/null +++ b/tests/unit/group_provider/sql/test_group.py @@ -0,0 +1,117 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. + +"""Unit tests for the sql group.""" +from secrets import token_hex + +import pytest +from flask_multipass import IdentityInfo, Multipass + +from flask_multipass_saml_groups.group_provider.sql import SQLGroup, SQLGroupProvider +from flask_multipass_saml_groups.provider import SAMLGroupsIdentityProvider + + +@pytest.fixture(name="group_name") +def group_name_fixture(): + """Return a group name""" + return token_hex(16) + + +@pytest.fixture(name="provider") +def provider_fixture(app): + """Create an identity provider""" + multipass = Multipass(app) + with app.app_context(): + yield SAMLGroupsIdentityProvider(multipass=multipass, name="saml_groups", settings={}) + + +@pytest.fixture(name="group_provider") +def group_provider_fixture(provider): + """Create a group provider""" + return SQLGroupProvider(identity_provider=provider) + + +@pytest.fixture(name="group") +def group_fixture(provider, group_name): + """Create a group object. + + The group is not created in the database, only the object is created. + """ + return SQLGroup(provider=provider, name=group_name) + + +def test_get_members(group, group_provider, group_name): + """ + arrange: given group with users + act: call get_members + assert: the users are contained in the returned result + """ + users = [token_hex(16), token_hex(6)] + group_provider.add_group_member(group_name=group_name, identifier=users[0]) + group_provider.add_group_member(group_name=group_name, identifier=users[1]) + + members = list(group.get_members()) + + assert members + assert len(members) == 2 + assert isinstance(members[0], IdentityInfo) + assert isinstance(members[1], IdentityInfo) + + assert {member.identifier for member in members} == set(users) + + +def test_get_members_returns_empty_list(group, group_provider, group_name): + """ + arrange: given no users + act: call get_members + assert: get_members returns an empty iterator + """ + group_provider.add_group(group_name) + + members = list(group.get_members()) + + assert not members + + +def test_get_members_returns_empty_list_for_non_existing_group(group): + """ + arrange: given no underlying db group + act: call get_members + assert: get_members returns an empty iterator + """ + members = list(group.get_members()) + + assert not members + + +def test_has_member(group, group_provider, group_name): + """ + arrange: given a user which belongs to a group + act: call has_member + assert: has_user returns True + """ + user_identifier = token_hex(16) + group_provider.add_group_member(identifier=user_identifier, group_name=group_name) + + assert group.has_member(user_identifier) + + +def test_has_member_returns_false(group, group_provider, group_name): + """ + arrange: given a user identifier which does not belong to a group + act: call has_member + assert: has_member returns False + """ + user_identifiers = [token_hex(16), token_hex(6)] + group_provider.add_group_member(identifier=user_identifiers[0], group_name=group_name) + assert not group.has_member(user_identifiers[1]) + + +def test_has_member_returns_false_for_non_existing_group(group): + """ + arrange: given no underlying db group + act: call has_member + assert: has_member returns False + """ + user_identifier = token_hex(16) + assert not group.has_member(user_identifier) diff --git a/tests/unit/group_provider/sql/test_provider.py b/tests/unit/group_provider/sql/test_provider.py new file mode 100644 index 0000000..3986815 --- /dev/null +++ b/tests/unit/group_provider/sql/test_provider.py @@ -0,0 +1,282 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. + +"""Unit tests for the sql group provider.""" + + +from secrets import token_hex + +import pytest +from flask_multipass import IdentityProvider, Multipass +from indico.core.db import db + +from flask_multipass_saml_groups.group_provider.sql import SQLGroup, SQLGroupProvider +from flask_multipass_saml_groups.models.saml_groups import SAMLGroup as DBGroup +from flask_multipass_saml_groups.models.saml_groups import SAMLUser + +NOT_EXISTING_USER_IDENTIFIER = "user-3" +NOT_EXISTING_GRP_NAME = "not_existing" + + +@pytest.fixture(name="group_names") +def group_names_fixture(): + """Return group names""" + return [token_hex(16), token_hex(16)] + + +@pytest.fixture(name="user_identifiers") +def user_identifiers_fixture(): + """Return user identifiers""" + return [token_hex(16), token_hex(16)] + + +@pytest.fixture(name="group_provider") +def group_provider_fixture(app, group_names, user_identifiers): + """Setup a group provider and place groups and users in the database. + + The first user is placed in the first group. + The second user belongs to no group. + The second group has no members. + """ + with app.app_context(): + multipass = Multipass(app=app) + group_provider = SQLGroupProvider( + identity_provider=IdentityProvider( + multipass=multipass, name="saml_groups", settings={} + ), + ) + user1 = SAMLUser(identifier=user_identifiers[0]) + user2 = SAMLUser(identifier=user_identifiers[1]) + + # pylint does not recognize the methods of db.session, which is a proxy object + # pylint: disable=no-member + db.session.add(user1) + db.session.add(user2) + grp1 = DBGroup(name=group_names[0]) + grp1.members.append(user1) + db.session.add(grp1) + db.session.add(DBGroup(name=group_names[1])) + db.session.commit() + + yield group_provider + + +def test_get_group(group_provider, group_names): + """ + arrange: given a GroupProvider instance + act: call get_group with a specific group name + assert: returns a SQLGroup instance with the same name + """ + grp = group_provider.get_group(group_names[0]) + + assert isinstance(grp, SQLGroup) + assert grp.name == group_names[0] + + +def test_get_group_not_found(group_provider): + """ + arrange: given a GroupProvider instance + act: call get_group with a non existing group name + assert: returns None + """ + grp = group_provider.get_group("non-existing") + + assert grp is None + + +def test_get_groups(group_provider, group_names): + """ + arrange: given a GroupProvider instance + act: call get_groups + assert: returns an iterable of all groups + """ + grps = list(group_provider.get_groups()) + + assert grps + assert len(grps) == 2 + assert {grp.name for grp in grps} == set(group_names) + + +def test_get_user_groups(group_provider, user_identifiers, group_names): + """ + arrange: given a user identifier + act: call get_user_groups + assert: returns an iterable of groups the user belongs to + """ + grps = list(group_provider.get_user_groups(user_identifiers[0])) + + assert grps + assert len(grps) == 1 + assert grps[0].name == group_names[0] + + +def test_get_user_groups_without_groups(group_provider, user_identifiers): + """ + arrange: given a user identifier for a user who belongs to no groups + act: call get_user_groups + assert: returns empty list + """ + grps = list(group_provider.get_user_groups(user_identifiers[1])) + + assert not grps + + +def test_get_user_groups_for_non_existing_user(group_provider): + """ + arrange: given a user identifier for a non existing user + act: call get_user_groups + assert: returns empty list + """ + grps = list(group_provider.get_user_groups("non-existing")) + + assert not grps + + +def test_add_group(group_provider): + """ + arrange: given a GroupProvider instance + act: call add_group with a group name + assert: the group can be retrieved calling get_group + """ + group_provider.add_group(NOT_EXISTING_GRP_NAME) + grp = group_provider.get_group(NOT_EXISTING_GRP_NAME) + + assert isinstance(grp, SQLGroup) + assert grp.name == NOT_EXISTING_GRP_NAME + + +def test_add_group_group_already_existing(group_provider, group_names): + """ + arrange: given a group that already exists + act: call add_group with that group name + assert: the group can be retrieved calling get_group + """ + group_name = group_names[0] + group_provider.add_group(group_name) + + grp = group_provider.get_group(group_name) + assert isinstance(grp, SQLGroup) + assert grp.name == group_name + + +def test_add_group_member(group_provider, user_identifiers, group_names): + """ + arrange: given a user who does not belong to a group + act: call add_group_member with that user and group + assert: the user belongs to the group + """ + user_identifier = user_identifiers[0] + group_name = group_names[1] + group_provider.add_group_member(user_identifier, group_name) + + grp = group_provider.get_group(group_name) + assert isinstance(grp, SQLGroup) + assert grp.name == group_name + members = list(grp.get_members()) + assert len(members) == 1 + assert members[0].identifier == user_identifier + + +def test_add_group_member_user_non_existing(group_provider, group_names): + """ + arrange: given a non existing user identifier + act: call add_group_member with that user and a group name + assert: the user gets created and belongs to the group + """ + group_name = group_names[1] + group_provider.add_group_member(NOT_EXISTING_USER_IDENTIFIER, group_name) + + grp = group_provider.get_group(group_name) + assert isinstance(grp, SQLGroup) + assert grp.name == group_name + members = list(grp.get_members()) + assert len(members) == 1 + assert members[0].identifier == NOT_EXISTING_USER_IDENTIFIER + + +def test_add_group_member_group_non_existing(group_provider, user_identifiers): + """ + arrange: given a non existing group + act: call add_group_member with a user identifier that group + assert: the group gets created and the user belongs to the group + """ + user_identifier = user_identifiers[0] + group_provider.add_group_member(user_identifier, NOT_EXISTING_GRP_NAME) + + grp = group_provider.get_group(NOT_EXISTING_GRP_NAME) + assert isinstance(grp, SQLGroup) + assert grp.name == NOT_EXISTING_GRP_NAME + members = list(grp.get_members()) + assert len(members) == 1 + assert members[0].identifier == user_identifier + + +def test_add_group_member_pair_already_existing(group_provider, user_identifiers, group_names): + """ + arrange: given a user and a group to which the user does not belong + act: call add_group_member twice + assert: the user belongs to the group, but is returned only once + """ + user_identifier = user_identifiers[0] + group_name = group_names[1] + + group_provider.add_group_member(user_identifier, group_name) + group_provider.add_group_member(user_identifier, group_name) + + grp = group_provider.get_group(group_name) + assert isinstance(grp, SQLGroup) + assert grp.name == group_name + members = list(grp.get_members()) + assert len(members) == 1 + assert members[0].identifier == user_identifier + + +def test_remove_group_member(group_provider, user_identifiers, group_names): + """ + arrange: given a user which belongs to a group + act: call remove_group_member + assert: the user does not belong to the group anymore + """ + user_identifier = user_identifiers[0] + group_name = group_names[0] + + group_provider.remove_group_member(user_identifier, group_name) + + grp = group_provider.get_group(group_name) + assert isinstance(grp, SQLGroup) + assert grp.name == group_name + members = list(grp.get_members()) + assert not members + + +def test_remove_group_member_user_non_existing(group_provider, user_identifiers, group_names): + """ + arrange: given a group with a user + act: call remove_group_member with a non existing user identifier and that group + assert: the member list without the user is returned + """ + user_identifier = user_identifiers[0] + group_name = group_names[0] + + group_provider.remove_group_member(NOT_EXISTING_USER_IDENTIFIER, group_name) + + grp = group_provider.get_group(group_name) + assert isinstance(grp, SQLGroup) + assert grp.name == group_name + members = list(grp.get_members()) + assert len(members) == 1 + assert members[0].identifier == user_identifier + + +def test_remove_group_member_group_non_existing(group_provider, user_identifiers): + """ + arrange: given a user and a non existing group + act: call remove_group_member with that user and the non existing group + assert: the group does not exist + """ + user_identifier = user_identifiers[0] + + group_provider.remove_group_member(user_identifier, NOT_EXISTING_GRP_NAME) + + grp = group_provider.get_group(NOT_EXISTING_GRP_NAME) + assert not grp diff --git a/tests/unit/test_provider.py b/tests/unit/test_provider.py new file mode 100644 index 0000000..f4310ec --- /dev/null +++ b/tests/unit/test_provider.py @@ -0,0 +1,305 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. + +"""Unit tests for the identity provider.""" +from copy import copy +from secrets import token_hex +from unittest.mock import Mock + +import pytest +from flask import Flask +from flask_multipass import AuthInfo, IdentityRetrievalFailed, Multipass +from werkzeug.datastructures import MultiDict + +from flask_multipass_saml_groups.provider import ( + DEFAULT_IDENTIFIER_FIELD, + SAML_GRP_ATTR_NAME, + SAMLGroupsIdentityProvider, +) +from tests.common import setup_sqlite + +USER_EMAIL = "user@example.com" +OTHER_USER_EMAIL = "other@example.com" + + +@pytest.fixture(name="group_names") +def group_names_fixture(): + """A list of group names.""" + return [token_hex(16), token_hex(16)] + + +@pytest.fixture(name="saml_attrs") +def saml_attrs_fixture(group_names): + """SAML attributes for a user. + + The user belongs to all groups. + """ + return { + "_saml_nameid": USER_EMAIL, + DEFAULT_IDENTIFIER_FIELD: f"{USER_EMAIL}@https://site", + "email": USER_EMAIL, + "fullname": "Foo bar", + "openid": "https://openid", + "userid": USER_EMAIL, + "username": "user", + SAML_GRP_ATTR_NAME: group_names, + } + + +@pytest.fixture(name="saml_attrs_other_user") +def saml_attrs_other_user_fixture(group_names): + """SAML attributes for another user. + + This user belongs only to the second group. + """ + return { + "_saml_nameid": OTHER_USER_EMAIL, + DEFAULT_IDENTIFIER_FIELD: f"{OTHER_USER_EMAIL}@https://site", + "email": OTHER_USER_EMAIL, + "fullname": "Foo bar", + "openid": "https://openid", + "userid": OTHER_USER_EMAIL, + "username": "other", + SAML_GRP_ATTR_NAME: group_names[1], # single elements are expected as str + } + + +@pytest.fixture(name="auth_info") +def auth_info_fixture(saml_attrs): + """The AuthInfo object for a user.""" + return AuthInfo(provider=Mock(), **saml_attrs) + + +@pytest.fixture(name="auth_info_other_user") +def auth_info_other_user_fixture(saml_attrs_other_user): + """The AuthInfo object for another user.""" + return AuthInfo(provider=Mock(), **saml_attrs_other_user) + + +@pytest.fixture(name="provider") +def provider_fixture(): + """Setup a SAMLGroupsIdentityProvider.""" + app = Flask("test") + multipass = Multipass(app) + + setup_sqlite(app) + with app.app_context(): + yield SAMLGroupsIdentityProvider(multipass=multipass, name="saml_groups", settings={}) + + +@pytest.fixture(name="provider_custom_field") +def provider_custom_field_fixture(): + """Setup a SAMLGroupsIdentityProvider with a custom identifier_field.""" + app = Flask("test") + multipass = Multipass(app) + + setup_sqlite(app) + + with app.app_context(): + yield SAMLGroupsIdentityProvider( + multipass=multipass, + name="saml_groups", + settings={"identifier_field": "fullname"}, + ) + + +def test_get_identity_from_auth_returns_identity_info(provider, auth_info, saml_attrs): + """ + arrange: given AuthInfo by AuthProvider + act: call get_identity_from_auth from identity provider + assert: the returned IdentityInfo object contains the expected data from AuthInfo + """ + identity_info = provider.get_identity_from_auth(auth_info) + + assert identity_info is not None + assert identity_info.provider == provider + assert identity_info.identifier == auth_info.data[DEFAULT_IDENTIFIER_FIELD] + assert identity_info.data == MultiDict(saml_attrs) + + +def test_get_identity_from_auth_returns_identity_from_custom_field( + auth_info, provider_custom_field +): + """ + arrange: given AuthInfo and provider using custom identifier_field + act: call get_identity_from_auth from SAMLGroupsIdentityProvider + assert: the returned IdentityInfo object uses the custom field as identifier + """ + identity_info = provider_custom_field.get_identity_from_auth(auth_info) + + assert identity_info is not None + assert identity_info.identifier == auth_info.data["fullname"] + + +def test_get_identity_from_auth_returns_identity_from_list(auth_info, provider_custom_field): + """ + arrange: given AuthInfo using a one element list for identifier_field + act: call get_identity_from_auth from SAMLGroupsIdentityProvider + assert: the returned IdentityInfo object uses value from the list as identifier + """ + fullname = token_hex(10) + auth_info.data["fullname"] = [fullname] + + identity_info = provider_custom_field.get_identity_from_auth(auth_info) + + assert identity_info is not None + assert identity_info.identifier == fullname + + +def test_get_identity_from_auth_raises_exc_for_multi_val_identifier( + auth_info, provider_custom_field +): + """ + arrange: given AuthInfo using identifier_field with multiple values + act: call get_identity_from_auth from SAMLGroupsIdentityProvider + assert: an exception is raised + """ + fullnames = [token_hex(10), token_hex(15)] + + auth_info.data["fullname"] = fullnames + + with pytest.raises(IdentityRetrievalFailed): + provider_custom_field.get_identity_from_auth(auth_info) + + +def test_get_identity_from_auth_raises_exc_for_no_identifier(auth_info, provider): + """ + arrange: given AuthInfo which does not provide value for identifier field + act: call get_identity_from_auth from SAMLGroupsIdentityProvider + assert: an exception is raised + """ + del auth_info.data[DEFAULT_IDENTIFIER_FIELD] + + with pytest.raises(IdentityRetrievalFailed): + provider.get_identity_from_auth(auth_info) + + +def test_get_identity_from_auth_adds_user_to_group(auth_info, provider, group_names): + """ + arrange: given AuthInfo by AuthProvider + act: call get_identity_from_auth from SAMLGroupsIdentityProvider + assert: the user is added to the groups + """ + provider.get_identity_from_auth(auth_info) + + for grp_name in group_names: + group = provider.get_group(grp_name) + members = list(group.get_members()) + assert members + assert members[0].identifier == auth_info.data[DEFAULT_IDENTIFIER_FIELD] + + +def test_get_identity_from_auth_adds_user_to_existing_group( + auth_info, auth_info_other_user, provider, group_names +): + """ + arrange: given AuthInfo of two users by AuthProvider + act: call get_identity_from_auth twice and second time with another user + assert: the user is added to the existing group + """ + provider.get_identity_from_auth(auth_info) + provider.get_identity_from_auth(auth_info_other_user) + + group = provider.get_group(group_names[0]) + members = list(group.get_members()) + assert members + assert members[0].identifier == auth_info.data[DEFAULT_IDENTIFIER_FIELD] + + group = provider.get_group(group_names[1]) + members = list(group.get_members()) + assert len(members) == 2 + expected_identifiers = { + auth_info.data[DEFAULT_IDENTIFIER_FIELD], + auth_info_other_user.data[DEFAULT_IDENTIFIER_FIELD], + } + assert members[0].identifier in expected_identifiers + assert members[1].identifier in expected_identifiers + + +def test_get_identity_from_auth_removes_user_from_group(auth_info, provider, group_names): + """ + arrange: given AuthInfo by AuthProvider + act: call get_identity_from_auth and afterwards again with a group removed + assert: the user is removed from the group + """ + provider.get_identity_from_auth(auth_info) + + group = provider.get_group(group_names[0]) + members = list(group.get_members()) + assert members + assert members[0].identifier == auth_info.data[DEFAULT_IDENTIFIER_FIELD] + + auth_info_grp_removed = copy(auth_info) + auth_info_grp_removed.data[SAML_GRP_ATTR_NAME] = group_names[1] + provider.get_identity_from_auth(auth_info_grp_removed) + + group = provider.get_group(group_names[0]) + members = list(group.get_members()) + assert not members + + group = provider.get_group(group_names[1]) + members = list(group.get_members()) + assert members + assert members[0].identifier == auth_info.data[DEFAULT_IDENTIFIER_FIELD] + + +def test_get_group_returns_specific_group(auth_info, provider, group_names): + """ + arrange: given AuthInfo by AuthProvider + act: call get_identity_from_auth and afterwards get_group with a specific group name + assert: the returned group name is the one requested + """ + provider.get_identity_from_auth(auth_info) + group = provider.get_group(group_names[0]) + + assert group.name == group_names[0] + + +def test_get_group_returns_none_if_no_auth_handled(provider, group_names): + """ + arrange: given only an SAMLGroupsIdentityProvider whose methods have never been called + act: call get_group from SAMLGroupsIdentityProvider with a specific group name + assert: the result is None + """ + group = provider.get_group(group_names[0]) + + assert group is None + + +def test_get_identity_groups(auth_info, provider, group_names): + """ + arrange: given AuthInfo by AuthProvider + act: call get_identity_from_auth and afterwards get_identity_groups + assert: the returned groups of the user are the ones expected + """ + provider.get_identity_from_auth(auth_info) + groups = list(provider.get_identity_groups(auth_info.data[DEFAULT_IDENTIFIER_FIELD])) + + assert len(groups) == 2 + assert set(g.name for g in groups) == set(group_names) + + +def test_search_groups_returns_all_matched_groups(auth_info, provider, group_names): + """ + arrange: given AuthInfo by AuthProvider + act: call get_identity_from_auth and afterwards search_groups + assert: the returned list of groups contains all the groups the user belongs to + """ + provider.get_identity_from_auth(auth_info) + groups = list(provider.search_groups(group_names[0], exact=True)) + + assert len(groups) == 1 + assert groups[0].name == group_names[0] + + +def test_search_groups_non_exact_returns_all_matched_groups(auth_info, provider, group_names): + """ + arrange: given AuthInfo by AuthProvider + act: call get_identity_from_auth and afterwards search_groups using exact=False + assert: the returned list of groups contains all the groups the user belongs to + """ + provider.get_identity_from_auth(auth_info) + groups = list(provider.search_groups(group_names[0][:-1], exact=False)) + + assert len(groups) == 1 + assert groups[0].name == group_names[0] diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..f51d8cd --- /dev/null +++ b/tox.ini @@ -0,0 +1,100 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. + +[tox] +skipsdist=True +skip_missing_interpreters = True +envlist = lint, unit, static, coverage-report + +[vars] +src_path = {toxinidir}/flask_multipass_saml_groups/ +tst_path = {toxinidir}/tests/ +all_path = {[vars]src_path} {[vars]tst_path} + +[testenv] +setenv = + PYTHONPATH = {toxinidir}:{[vars]src_path} + PYTHONBREAKPOINT=ipdb.set_trace + PY_COLORS=1 +passenv = + PYTHONPATH + +[testenv:fmt] +description = Apply coding style standards to code +deps = + black + isort +commands = + isort {[vars]all_path} + black {[vars]all_path} + +[testenv:lint] +description = Check code against coding style standards +deps = + black + codespell + flake8<6.0.0 + flake8-builtins + flake8-copyright<6.0.0 + flake8-docstrings>=1.6.0 + flake8-docstrings-complete>=1.0.3 + flake8-test-docs>=1.0 + isort + mypy + pep8-naming + pydocstyle>=2.10 + pylint + pyproject-flake8<6.0.0 + pytest + types-PyYAML + types-requests + -r{toxinidir}/requirements.txt +commands = + pydocstyle {[vars]src_path} + codespell {toxinidir} --skip {toxinidir}/.git --skip {toxinidir}/.tox \ + --skip {toxinidir}/build --skip {toxinidir}/venv \ + --skip {toxinidir}/.mypy_cache + # pflake8 wrapper supports config from pyproject.toml + pflake8 {[vars]all_path} --ignore=W503 + isort --check-only --diff {[vars]all_path} + black --check --diff {[vars]all_path} + mypy {[vars]all_path} + pylint {[vars]all_path} + + +[testenv:unit] +description = Run unit tests +deps = + coverage[toml] + pytest + -r{toxinidir}/requirements.txt +commands = + coverage run --source={[vars]src_path} --omit={[vars]src_path}/plugin.py \ + -m pytest --ignore={[vars]tst_path}integration -v --tb native -s {posargs} + coverage report + +[testenv:coverage-report] +description = Create test coverage report +deps = + coverage[toml] + pytest + -r{toxinidir}/requirements.txt +commands = + coverage report + +[testenv:static] +description = Run static analysis tests +deps = + bandit[toml] + -r{toxinidir}/requirements.txt +commands = + bandit -c {toxinidir}/pyproject.toml -r {[vars]src_path} {[vars]tst_path} + + +[testenv:integration] +description = Run integration tests +deps = + pytest + -r{toxinidir}/requirements.txt +commands = + pytest -v --tb native --ignore={[vars]tst_path}unit --log-cli-level=INFO -s {posargs}