Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Row level security #134

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions balsam/schemas/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ class UserCreate(BaseModel):
class UserOut(BaseModel):
id: int
username: str
token: str
1 change: 0 additions & 1 deletion balsam/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,5 @@ class ValidationError(HTTPException):
def __init__(self, detail: str) -> None:
super().__init__(status_code=status.HTTP_400_BAD_REQUEST, detail=detail)


settings = Settings()
__all__ = ["settings", "Settings", "ValidationError"]
6 changes: 6 additions & 0 deletions balsam/server/auth/password_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.orm import Session, exc
from sqlalchemy.exc import SQLAlchemyError

from balsam.schemas import UserCreate, UserOut
from balsam.server.models.crud import users
Expand Down Expand Up @@ -43,6 +44,9 @@ def login(

user = authenticate_user_password(db, username, password)
token, expiry = create_access_token(user)
sql = "ALTER USER {} PASSWORD '{}' VALID UNTIL '{}'".format(user.username, token, expiry.strftime("%b %d %Y"))
db.execute(sql)

return {"access_token": token, "token_type": "bearer", "expiration": expiry}


Expand All @@ -57,5 +61,7 @@ def register(user: UserCreate, db: Session = Depends(get_admin_session)) -> User
raise HTTPException(status_code=400, detail="Username already taken")

new_user = users.create_user(db, user.username, user.password)
print("REGISTERING USER ",user.username)

db.commit()
return new_user
3 changes: 1 addition & 2 deletions balsam/server/auth/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,4 @@ def user_from_token(token: str = Depends(oauth2_scheme)) -> schemas.UserOut:
except PyJWTError:
raise credentials_exception

print("user_from_token has identified the user:", username)
return schemas.UserOut(id=user_id, username=username)
return schemas.UserOut(id=user_id, token=token, username=username)
4 changes: 4 additions & 0 deletions balsam/server/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import uvicorn

from fastapi import FastAPI, HTTPException, Request, WebSocket, status
from fastapi.responses import JSONResponse
Expand Down Expand Up @@ -117,3 +118,6 @@ async def subscribe_user(websocket: WebSocket) -> None:
app.add_middleware(TimingMiddleware, router=app.router)
logger.info("Loaded balsam.server.main")
logger.info(settings.serialize_without_secrets())

if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
143 changes: 142 additions & 1 deletion balsam/server/models/alembic/versions/f8fbad8262e3_initial.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from alembic import op
from sqlalchemy import text
from sqlalchemy.dialects import postgresql
from sqlalchemy import Enum, Table, Column, Integer, LargeBinary, Text, String, ForeignKey, DateTime, Boolean, Float, Sequence, INTEGER, literal_column, select, column
from datetime import datetime

# revision identifiers, used by Alembic.
revision = "f8fbad8262e3"
Expand All @@ -19,14 +21,29 @@

def upgrade():
# ### commands auto generated by Alembic - please adjust! ###

op.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp"')
print("Added uuid-ossp extension")
op.create_table(
"users",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("username", sa.String(length=100), nullable=False),
sa.Column("hashed_password", sa.String(length=128), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("username"),

sa.Column("uid", String(40), autoincrement=False, default=literal_column('uuid_generate_v4()'), unique=True),

sa.Column("owner", String(40), default=literal_column('current_user')),

sa.Column("created", DateTime, default=datetime.now, nullable=False),
sa.Column("lastupdated", DateTime, default=datetime.now,
onupdate=datetime.now, nullable=False)
)

op.execute("ALTER TABLE \"users\" ENABLE ROW LEVEL SECURITY")
op.execute("CREATE POLICY users_policy ON \"users\" USING (users.owner::text = current_user)")

op.create_table(
"device_code_attempts",
sa.Column("client_id", postgresql.UUID(as_uuid=True)),
Expand All @@ -40,14 +57,36 @@ def upgrade():
sa.UniqueConstraint("user_code"),
sa.UniqueConstraint("device_code"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.Column("uid", String(40), autoincrement=False, default=literal_column('uuid_generate_v4()'), unique=True),

sa.Column("owner", String(40), default=literal_column('current_user')),

sa.Column("created", DateTime, default=datetime.now, nullable=False),
sa.Column("lastupdated", DateTime, default=datetime.now,
onupdate=datetime.now, nullable=False)
)

op.execute("ALTER TABLE \"device_code_attempts\" ENABLE ROW LEVEL SECURITY")
op.execute("CREATE POLICY device_code_attempts_policy ON \"device_code_attempts\" USING (device_code_attempts.owner::text = current_user)")

op.create_table(
"auth_states",
sa.Column("id", sa.Integer()),
sa.Column("state", sa.String(length=512), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("state"),
sa.Column("uid", String(40), autoincrement=False, default=literal_column('uuid_generate_v4()'), unique=True),

sa.Column("owner", String(40), default=literal_column('current_user')),

sa.Column("created", DateTime, default=datetime.now, nullable=False),
sa.Column("lastupdated", DateTime, default=datetime.now,
onupdate=datetime.now, nullable=False)
)

op.execute("ALTER TABLE \"auth_states\" ENABLE ROW LEVEL SECURITY")
op.execute("CREATE POLICY auth_states_policy ON \"auth_states\" USING (auth_states.owner::text = current_user)")

op.create_table(
"sites",
sa.Column("id", sa.Integer(), nullable=False),
Expand All @@ -66,7 +105,18 @@ def upgrade():
sa.ForeignKeyConstraint(["owner_id"], ["users.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("hostname", "path"),
sa.Column("uid", String(40), autoincrement=False, default=literal_column('uuid_generate_v4()'), unique=True),

sa.Column("owner", String(40), default=literal_column('current_user')),

sa.Column("created", DateTime, default=datetime.now, nullable=False),
sa.Column("lastupdated", DateTime, default=datetime.now,
onupdate=datetime.now, nullable=False)
)

op.execute("ALTER TABLE \"sites\" ENABLE ROW LEVEL SECURITY")
op.execute("CREATE POLICY sites_policy ON \"sites\" USING (sites.owner::text = current_user)")

op.create_index(op.f("ix_sites_owner_id"), "sites", ["owner_id"], unique=False)
op.create_table(
"apps",
Expand All @@ -80,7 +130,18 @@ def upgrade():
sa.ForeignKeyConstraint(["site_id"], ["sites.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("site_id", "class_path"),
sa.Column("uid", String(40), autoincrement=False, default=literal_column('uuid_generate_v4()'), unique=True),

sa.Column("owner", String(40), default=literal_column('current_user')),

sa.Column("created", DateTime, default=datetime.now, nullable=False),
sa.Column("lastupdated", DateTime, default=datetime.now,
onupdate=datetime.now, nullable=False)
)

op.execute("ALTER TABLE \"apps\" ENABLE ROW LEVEL SECURITY")
op.execute("CREATE POLICY apps_policy ON \"apps\" USING (apps.owner::text = current_user)")

op.create_table(
"batch_jobs",
sa.Column("id", sa.Integer(), nullable=False),
Expand All @@ -100,7 +161,18 @@ def upgrade():
sa.Column("end_time", sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(["site_id"], ["sites.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.Column("uid", String(40), autoincrement=False, default=literal_column('uuid_generate_v4()'), unique=True),

sa.Column("owner", String(40), default=literal_column('current_user')),

sa.Column("created", DateTime, default=datetime.now, nullable=False),
sa.Column("lastupdated", DateTime, default=datetime.now,
onupdate=datetime.now, nullable=False)
)

op.execute("ALTER TABLE \"batch_jobs\" ENABLE ROW LEVEL SECURITY")
op.execute("CREATE POLICY batch_jobs_policy ON \"batch_jobs\" USING (batch_jobs.owner::text = current_user)")

op.create_index(op.f("ix_batch_jobs_state"), "batch_jobs", ["state"], unique=False)
op.create_table(
"sessions",
Expand All @@ -111,7 +183,18 @@ def upgrade():
sa.ForeignKeyConstraint(["batch_job_id"], ["batch_jobs.id"], ondelete="SET NULL"),
sa.ForeignKeyConstraint(["site_id"], ["sites.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.Column("uid", String(40), autoincrement=False, default=literal_column('uuid_generate_v4()'), unique=True),

sa.Column("owner", String(40), default=literal_column('current_user')),

sa.Column("created", DateTime, default=datetime.now, nullable=False),
sa.Column("lastupdated", DateTime, default=datetime.now,
onupdate=datetime.now, nullable=False)
)

op.execute("ALTER TABLE \"sessions\" ENABLE ROW LEVEL SECURITY")
op.execute("CREATE POLICY sessions_policy ON \"sessions\" USING (sessions.owner::text = current_user)")

op.create_table(
"jobs",
sa.Column("id", sa.Integer(), nullable=False),
Expand All @@ -138,7 +221,18 @@ def upgrade():
sa.ForeignKeyConstraint(["batch_job_id"], ["batch_jobs.id"], ondelete="SET NULL"),
sa.ForeignKeyConstraint(["session_id"], ["sessions.id"], ondelete="SET NULL"),
sa.PrimaryKeyConstraint("id"),
sa.Column("uid", String(40), autoincrement=False, default=literal_column('uuid_generate_v4()'), unique=True),

sa.Column("owner", String(40), default=literal_column('current_user')),

sa.Column("created", DateTime, default=datetime.now, nullable=False),
sa.Column("lastupdated", DateTime, default=datetime.now,
onupdate=datetime.now, nullable=False)
)

op.execute("ALTER TABLE \"jobs\" ENABLE ROW LEVEL SECURITY")
op.execute("CREATE POLICY jobs_policy ON \"jobs\" USING (jobs.owner::text = current_user)")

op.create_index(op.f("ix_jobs_state"), "jobs", ["state"], unique=False)

# Correct way of creating index on tags supporting fast @> (contains) lookups:
Expand All @@ -156,7 +250,18 @@ def upgrade():
sa.ForeignKeyConstraint(["child_id"], ["jobs.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["parent_id"], ["jobs.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("parent_id", "child_id"),
sa.Column("uid", String(40), autoincrement=False, default=literal_column('uuid_generate_v4()'), unique=True),

sa.Column("owner", String(40), default=literal_column('current_user')),

sa.Column("created", DateTime, default=datetime.now, nullable=False),
sa.Column("lastupdated", DateTime, default=datetime.now,
onupdate=datetime.now, nullable=False)
)

op.execute("ALTER TABLE \"job_deps\" ENABLE ROW LEVEL SECURITY")
op.execute("CREATE POLICY job_deps_policy ON \"job_deps\" USING (job_deps.owner::text = current_user)")

op.create_index(op.f("ix_job_deps_child_id"), "job_deps", ["child_id"], unique=False)
op.create_index(op.f("ix_job_deps_parent_id"), "job_deps", ["parent_id"], unique=False)
op.create_table(
Expand All @@ -169,7 +274,18 @@ def upgrade():
sa.Column("data", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.ForeignKeyConstraint(["job_id"], ["jobs.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.Column("uid", String(40), autoincrement=False, default=literal_column('uuid_generate_v4()'), unique=True),

sa.Column("owner", String(40), default=literal_column('current_user')),

sa.Column("created", DateTime, default=datetime.now, nullable=False),
sa.Column("lastupdated", DateTime, default=datetime.now,
onupdate=datetime.now, nullable=False)
)

op.execute("ALTER TABLE \"log_events\" ENABLE ROW LEVEL SECURITY")
op.execute("CREATE POLICY log_events_policy ON \"log_events\" USING (log_events.owner::text = current_user)")

op.create_table(
"transfer_items",
sa.Column("id", sa.Integer(), nullable=False),
Expand Down Expand Up @@ -200,7 +316,17 @@ def upgrade():
sa.Column("transfer_info", sa.JSON(), nullable=True),
sa.ForeignKeyConstraint(["job_id"], ["jobs.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.Column("uid", String(40), autoincrement=False, default=literal_column('uuid_generate_v4()'), unique=True),

sa.Column("owner", String(40), default=literal_column('current_user')),

sa.Column("created", DateTime, default=datetime.now, nullable=False),
sa.Column("lastupdated", DateTime, default=datetime.now,
onupdate=datetime.now, nullable=False)
)

op.execute("ALTER TABLE \"transfer_items\" ENABLE ROW LEVEL SECURITY")
op.execute("CREATE POLICY transfer_items_policy ON \"transfer_items\" USING (transfer_items.owner::text = current_user)")
# ### end Alembic commands ###


Expand All @@ -220,4 +346,19 @@ def downgrade():
op.drop_index(op.f("ix_sites_owner_id"), table_name="sites")
op.drop_table("sites")
op.drop_table("users")
# ### end Alembic commands ###

# list users and remove users not 'postgres'
connection = op.get_bind()
users = connection.execute("SELECT * FROM pg_user")

for user in users:
try:
username = user['usename']

if username == 'postgres':
continue

op.execute(f"DROP USER {username}")
except Exception as ex :
print(f"Encountered error trying to drop {username}.")
# ### end Alembic commands ###
9 changes: 7 additions & 2 deletions balsam/server/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,13 @@
from sqlalchemy import create_engine, orm
from sqlalchemy.engine import Engine
from sqlalchemy.ext.declarative import declarative_base
from fastapi import Depends
from balsam.server import settings

auth = settings.auth.get_auth_method()

import balsam.server
from balsam import schemas
from balsam.schemas.user import UserOut

logger = logging.getLogger(__name__)
Expand All @@ -15,7 +20,7 @@
_Session = None


def get_engine() -> Engine:
def get_engine(user: schemas.UserOut) -> Engine:
global _engine
if _engine is None:
logger.info(f"Creating DB engine: {balsam.server.settings.database_url}")
Expand All @@ -31,7 +36,7 @@ def get_engine() -> Engine:
def get_session(user: Optional[UserOut] = None) -> orm.Session:
global _Session
if _Session is None:
_Session = orm.sessionmaker(bind=get_engine())
_Session = orm.sessionmaker(bind=get_engine(user))

session: orm.Session = _Session()
return session
Expand Down
3 changes: 3 additions & 0 deletions balsam/server/models/crud/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def create_user(db: Session, username: str, password: Optional[str]) -> UserOut:
new_user = User(username=username, hashed_password=hashed)
else:
new_user = User(username=username)

sql = f"CREATE USER {username} WITH PASSWORD '{hashed}'"
db.execute(sql)
db.add(new_user)
db.flush()
return UserOut(id=new_user.id, username=new_user.username)
Expand Down
Loading