Skip to content

Commit

Permalink
AIP-65: Update dag source endpoint to support versioning (apache#43492)
Browse files Browse the repository at this point in the history
* AIP-65: Update dag source endpoint to support versioning

Enhanced the DAG source endpoint to support version-based retrieval

Refactored the get_dag_source function to allow fetching specific versions of DAG source code using dag_id, version_name, and version_number parameters.

Replaced file_token with dag_id in endpoint paths and removed unnecessary token-based access.

Updated OpenAPI specifications and requested serializers to include new versioning parameters.

Modified API response schema to include dag_id, version_name, and version_number for improved version tracking.

Added/updated tests

* Remove version_name

* fixup! Remove version_name

* fixup! fixup! Remove version_name

* Fix test in fab provider

* fix conflicts

* Remove async def

* fix conflicts
  • Loading branch information
ephraimbuddy authored Nov 19, 2024
1 parent 21fc21d commit cc8aa7b
Show file tree
Hide file tree
Showing 31 changed files with 480 additions and 419 deletions.
58 changes: 34 additions & 24 deletions airflow/api_connexion/endpoints/dag_source_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
from http import HTTPStatus
from typing import TYPE_CHECKING, Sequence

from flask import Response, current_app, request
from itsdangerous import BadSignature, URLSafeSerializer
from flask import Response, request
from sqlalchemy import select

from airflow.api_connexion import security
from airflow.api_connexion.exceptions import NotFound, PermissionDenied
from airflow.api_connexion.schemas.dag_source_schema import dag_source_schema
from airflow.auth.managers.models.resource_details import DagAccessEntity, DagDetails
from airflow.models.dag import DagModel
from airflow.models.dagcode import DagCode
from airflow.models.dag_version import DagVersion
from airflow.utils.api_migration import mark_fastapi_migration_done
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.www.extensions.init_auth_manager import get_auth_manager
Expand All @@ -41,32 +41,42 @@
@mark_fastapi_migration_done
@security.requires_access_dag("GET", DagAccessEntity.CODE)
@provide_session
def get_dag_source(
    *,
    dag_id: str,
    version_number: int | None = None,
    session: Session = NEW_SESSION,
) -> Response:
    """
    Get the source code of a DAG version from DagCode.

    :param dag_id: ID of the DAG whose source code is requested.
    :param version_number: Version to fetch. It is passed unchanged to
        ``DagVersion.get_version``; the version number actually returned in the
        response is taken from the DagVersion row that was found.
    :param session: The database session.
    :return: The source as ``text/plain``, or a JSON document with ``content``,
        ``dag_id`` and ``version_number``, depending on the ``Accept`` header;
        a 406 response when neither mimetype is acceptable.
    :raises NotFound: If no matching DagVersion exists.
    :raises PermissionDenied: If the user lacks read access to any DAG defined
        in the same file as the requested DAG.
    """
    dag_version = DagVersion.get_version(dag_id, version_number, session=session)
    if not dag_version:
        raise NotFound(f"The source code of the DAG {dag_id}, version_number {version_number} was not found")

    path = dag_version.dag_code.fileloc
    # ``session.scalars`` yields plain dag_id strings (not Row tuples), so each value
    # is used as-is. Indexing it with ``[0]`` would authorize against the first
    # character of the dag_id instead of the DAG itself.
    file_dag_ids = session.scalars(select(DagModel.dag_id).where(DagModel.fileloc == path)).all()
    requests: Sequence[IsAuthorizedDagRequest] = [
        {
            "method": "GET",
            "details": DagDetails(id=file_dag_id),
        }
        for file_dag_id in file_dag_ids
    ]

    # Check if user has read access to all the DAGs defined in the file
    if not get_auth_manager().batch_is_authorized_dag(requests):
        raise PermissionDenied()

    dag_source = dag_version.dag_code.source_code
    # Report the version that was actually resolved, which matters when the caller
    # did not ask for a specific one.
    version_number = dag_version.version_number

    return_type = request.accept_mimetypes.best_match(["text/plain", "application/json"])
    if return_type == "text/plain":
        return Response(dag_source, headers={"Content-Type": return_type})
    if return_type == "application/json":
        content = dag_source_schema.dumps(
            {
                "content": dag_source,
                "dag_id": dag_id,
                "version_number": version_number,
            }
        )
        return Response(content, headers={"Content-Type": return_type})
    return Response("Not Allowed Accept Header", status=HTTPStatus.NOT_ACCEPTABLE)
16 changes: 13 additions & 3 deletions airflow/api_connexion/openapi/v1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2263,17 +2263,18 @@ paths:
"403":
$ref: "#/components/responses/PermissionDenied"

/dagSources/{file_token}:
/dagSources/{dag_id}:
parameters:
- $ref: "#/components/parameters/FileToken"

- $ref: "#/components/parameters/DAGID"
get:
summary: Get a source code
description: >
Get the source code of a DAG by its DAG ID, optionally for a specific version number.
x-openapi-router-controller: airflow.api_connexion.endpoints.dag_source_endpoint
operationId: get_dag_source
tags: [DAG]
parameters:
- $ref: "#/components/parameters/VersionNumber"
responses:
"200":
description: Success.
Expand Down Expand Up @@ -5860,6 +5861,15 @@ components:
description: |
List of field for return.
VersionNumber:
in: query
name: version_number
schema:
type: integer
description: |
The version number.
# Reusable request bodies
requestBodies: {}

Expand Down
2 changes: 2 additions & 0 deletions airflow/api_connexion/schemas/dag_source_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class DagSourceSchema(Schema):
"""Dag Source schema."""

content = fields.String(dump_only=True)
dag_id = fields.String(dump_only=True)
version_number = fields.Integer(dump_only=True)


dag_source_schema = DagSourceSchema()
2 changes: 2 additions & 0 deletions airflow/api_fastapi/core_api/datamodels/dag_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,5 @@ class DAGSourceResponse(BaseModel):
"""DAG Source serializer for responses."""

content: str | None
dag_id: str
version_number: int | None
24 changes: 21 additions & 3 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1602,20 +1602,28 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
/public/dagSources/{file_token}:
/public/dagSources/{dag_id}:
get:
tags:
- DagSource
summary: Get Dag Source
description: Get source code using file token.
operationId: get_dag_source
parameters:
- name: file_token
- name: dag_id
in: path
required: true
schema:
type: string
title: File Token
title: Dag Id
- name: version_number
in: query
required: false
schema:
anyOf:
- type: integer
- type: 'null'
title: Version Number
- name: accept
in: header
required: false
Expand Down Expand Up @@ -5132,9 +5140,19 @@ components:
- type: string
- type: 'null'
title: Content
dag_id:
type: string
title: Dag Id
version_number:
anyOf:
- type: integer
- type: 'null'
title: Version Number
type: object
required:
- content
- dag_id
- version_number
title: DAGSourceResponse
description: DAG Source serializer for responses.
DAGTagCollectionResponse:
Expand Down
25 changes: 13 additions & 12 deletions airflow/api_fastapi/core_api/routes/public/dag_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,13 @@
from typing import Annotated

from fastapi import Depends, Header, HTTPException, Request, Response, status
from itsdangerous import BadSignature, URLSafeSerializer
from sqlalchemy.orm import Session

from airflow.api_fastapi.common.db.common import get_session
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.datamodels.dag_sources import DAGSourceResponse
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.models.dagcode import DagCode
from airflow.models.dag_version import DagVersion

dag_sources_router = AirflowRouter(tags=["DagSource"], prefix="/dagSources")

Expand All @@ -36,7 +35,7 @@


@dag_sources_router.get(
"/{file_token}",
"/{dag_id}",
responses={
**create_openapi_http_exception_doc(
[
Expand All @@ -55,21 +54,23 @@
response_model=DAGSourceResponse,
)
def get_dag_source(
file_token: str,
dag_id: str,
session: Annotated[Session, Depends(get_session)],
request: Request,
accept: Annotated[str, Header()] = mime_type_any,
version_number: int | None = None,
):
"""Get source code using file token."""
auth_s = URLSafeSerializer(request.app.state.secret_key)

try:
path = auth_s.loads(file_token)
dag_source_model = DAGSourceResponse(
content=DagCode.code(path, session=session),
dag_version = DagVersion.get_version(dag_id, version_number, session=session)
if not dag_version:
raise HTTPException(
status.HTTP_404_NOT_FOUND,
f"The source code of the DAG {dag_id}, version_number {version_number} was not found",
)
except (BadSignature, FileNotFoundError):
raise HTTPException(status.HTTP_404_NOT_FOUND, "DAG source not found")

dag_source = dag_version.dag_code.source_code
version_number = dag_version.version_number
dag_source_model = DAGSourceResponse(dag_id=dag_id, content=dag_source, version_number=version_number)

if accept.startswith(mime_type_text):
return Response(dag_source_model.content, media_type=mime_type_text)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def upgrade():
"dag_version",
sa.Column("id", UUIDType(binary=False), nullable=False),
sa.Column("version_number", sa.Integer(), nullable=False),
sa.Column("version_name", StringID()),
sa.Column("dag_id", StringID(), nullable=False),
sa.Column("created_at", UtcDateTime(), nullable=False, default=timezone.utcnow),
sa.ForeignKeyConstraint(
Expand Down
1 change: 0 additions & 1 deletion airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,6 @@ class DAG(TaskSDKDag, LoggingMixin):
**Warning**: A fail stop dag can only have tasks with the default trigger rule ("all_success").
An exception will be thrown if any task in a fail stop dag has a non default trigger rule.
:param dag_display_name: The display name of the DAG which appears on the UI.
:param version_name: The version name to use in storing the dag to the DB.
"""

partial: bool = False
Expand Down
23 changes: 7 additions & 16 deletions airflow/models/dag_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ class DagVersion(Base):
__tablename__ = "dag_version"
id = Column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7)
version_number = Column(Integer, nullable=False, default=1)
version_name = Column(StringID())
dag_id = Column(StringID(), ForeignKey("dag.dag_id", ondelete="CASCADE"), nullable=False)
dag_model = relationship("DagModel", back_populates="dag_versions")
dag_code = relationship(
Expand Down Expand Up @@ -78,7 +77,6 @@ def write_dag(
cls,
*,
dag_id: str,
version_name: str | None = None,
version_number: int = 1,
session: Session = NEW_SESSION,
) -> DagVersion:
Expand All @@ -88,7 +86,6 @@ def write_dag(
Checks if a version of the DAG exists and increments the version number if it does.
:param dag_id: The DAG ID.
:param version_name: The version name.
:param version_number: The version number.
:param session: The database session.
:return: The DagVersion object.
Expand All @@ -102,7 +99,6 @@ def write_dag(
dag_version = DagVersion(
dag_id=dag_id,
version_number=version_number,
version_name=version_name,
)
log.debug("Writing DagVersion %s to the DB", dag_version)
session.add(dag_version)
Expand Down Expand Up @@ -136,7 +132,7 @@ def get_latest_version(cls, dag_id: str, *, session: Session = NEW_SESSION) -> D
def get_version(
    cls,
    dag_id: str,
    version_number: int | None = None,
    *,
    session: Session = NEW_SESSION,
) -> DagVersion | None:
    """
    Get a version of the DAG.

    :param dag_id: The DAG ID.
    :param version_number: The version number to look up. When ``None``, no
        version filter is applied and the row with the highest ``id`` for the
        DAG is returned.
    :param session: The database session.
    :return: The version of the DAG or None if not found.
    """
    version_select_obj = select(cls).where(cls.dag_id == dag_id)
    # Compare against None explicitly: a truthiness check would treat an explicit
    # ``0`` as "not provided" and silently return the latest version instead of
    # reporting that no such version exists.
    if version_number is not None:
        version_select_obj = version_select_obj.where(cls.version_number == version_number)

    return session.scalar(version_select_obj.order_by(cls.id.desc()).limit(1))

@property
def version(self) -> str:
"""A human-friendly representation of the version."""
name = f"{self.version_number}"
if self.version_name:
name = f"{self.version_name}-{self.version_number}"
return name
return f"{self.dag_id}-{self.version_number}"
1 change: 0 additions & 1 deletion airflow/models/serialized_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,6 @@ def write_dag(
log.debug("Serialized DAG (%s) is unchanged. Skipping writing to DB", dag.dag_id)
return False
dagv = DagVersion.write_dag(
version_name=dag.version_name,
dag_id=dag.dag_id,
session=session,
)
Expand Down
1 change: 0 additions & 1 deletion airflow/serialization/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@
},
"dag_display_name": { "type" : "string"},
"description": { "type" : "string"},
"version_name": {"type": "string"},
"_concurrency": { "type" : "number"},
"max_active_tasks": { "type" : "number"},
"max_active_runs": { "type" : "number"},
Expand Down
8 changes: 5 additions & 3 deletions airflow/ui/openapi-gen/queries/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -385,15 +385,17 @@ export const useDagSourceServiceGetDagSourceKey =
export const UseDagSourceServiceGetDagSourceKeyFn = (
{
accept,
fileToken,
dagId,
versionNumber,
}: {
accept?: string;
fileToken: string;
dagId: string;
versionNumber?: number;
},
queryKey?: Array<unknown>,
) => [
useDagSourceServiceGetDagSourceKey,
...(queryKey ?? [{ accept, fileToken }]),
...(queryKey ?? [{ accept, dagId, versionNumber }]),
];
export type DagStatsServiceGetDagStatsDefaultResponse = Awaited<
ReturnType<typeof DagStatsService.getDagStats>
Expand Down
Loading

0 comments on commit cc8aa7b

Please sign in to comment.