diff --git a/benchmarks/sql/bench/views/structured/superhero.py b/benchmarks/sql/bench/views/structured/superhero.py index 56932947..981b1a19 100644 --- a/benchmarks/sql/bench/views/structured/superhero.py +++ b/benchmarks/sql/bench/views/structured/superhero.py @@ -3,7 +3,7 @@ from typing import Literal -from sqlalchemy import ColumnElement, Engine, Select, func, select +from sqlalchemy import ColumnElement, Engine, Float, Select, case, cast, func, select from sqlalchemy.ext.declarative import DeferredReflection from sqlalchemy.orm import aliased, declarative_base @@ -280,17 +280,36 @@ def filter_by_missing_publisher(self) -> ColumnElement: return Superhero.publisher_id == None -class SuperheroColourFilterMixin: +class SuperheroAggregationMixin: """ - Mixin for filtering the view by the superhero colour attributes. + Mixin for aggregating the view by the superhero attributes. """ - def __init__(self, *args, **kwargs) -> None: - self.eye_colour = aliased(Colour) - self.hair_colour = aliased(Colour) - self.skin_colour = aliased(Colour) + @view_aggregation() + def count_superheroes(self) -> Select: + """ + Counts the number of superheros. - super().__init__(*args, **kwargs) + Returns: + The superheros count. + """ + return self.data.with_only_columns(func.count(Superhero.id).label("count_superheroes")).group_by(Superhero.id) + + @view_aggregation() + def average_height(self) -> Select: + """ + Averages the height of the superheros. + + Returns: + The superheros average height. + """ + return self.data.with_only_columns(func.avg(Superhero.height_cm).label("average_height")).group_by(Superhero.id) + + +class SuperheroColourFilterMixin: + """ + Mixin for filtering the view by the superhero colour attributes. + """ @view_filter() def filter_by_eye_colour(self, eye_colour: str) -> ColumnElement: @@ -352,6 +371,31 @@ def filter_by_same_hair_and_skin_colour(self) -> ColumnElement: return self.hair_colour.colour == self.skin_colour.colour +class SuperheroColourAggregationMixin: + """ + Mixin for aggregating the view by the superhero colour attributes. + """ + + @view_aggregation() + def percentage_of_eye_colour(self, eye_colour: str) -> Select: + """ + Calculates the percentage of objects with eye colour. + + Args: + eye_colour: The eye colour of the object. + + Returns: + The percentage of objects with eye colour. + """ + return self.data.with_only_columns( + ( + cast(func.count(case((self.eye_colour.colour == eye_colour, Superhero.id), else_=None)), Float) + * 100 + / func.count(Superhero.id) + ).label(f"percentage_of_{eye_colour.lower()}") + ) + + class PublisherFilterMixin: """ Mixin for filtering the view by the publisher attributes. @@ -371,6 +415,31 @@ def filter_by_publisher_name(self, publisher_name: str) -> ColumnElement: return Publisher.publisher_name == publisher_name +class PublisherAggregationMixin: + """ + Mixin for aggregating the view by the publisher attributes. + """ + + @view_aggregation() + def percentage_of_publisher(self, publisher_name: str) -> Select: + """ + Calculates the percentage of objects with publisher. + + Args: + publisher_name: The name of the publisher. + + Returns: + The percentage of objects with publisher. + """ + return self.data.with_only_columns( + ( + cast(func.count(case((Publisher.publisher_name == publisher_name, Superhero.id), else_=None)), Float) + * 100 + / func.count(Superhero.id) + ).label(f"percentage_of_{publisher_name.lower()}") + ) + + class AlignmentFilterMixin: """ Mixin for filtering the view by the alignment attributes. @@ -390,6 +459,31 @@ def filter_by_alignment(self, alignment: Literal["Good", "Bad", "Neutral", "N/A" return Alignment.alignment == alignment +class AlignmentAggregationMixin: + """ + Mixin for aggregating the view by the alignment. + """ + + @view_aggregation() + def percentage_of_alignment(self, alignment: Literal["Good", "Bad", "Neutral", "N/A"]) -> Select: + """ + Calculates the percentage of objects with alignment. + + Args: + alignment: The alignment of the object. + + Returns: + The percentage of objects with alignment. + """ + return self.data.with_only_columns( + ( + cast(func.count(case((Alignment.alignment == alignment, Superhero.id), else_=None)), Float) + * 100 + / func.count(Superhero.id) + ).label(f"percentage_of_{alignment.lower()}") + ) + + class GenderFilterMixin: """ Mixin for filtering the view by the gender. @@ -409,48 +503,61 @@ def filter_by_gender(self, gender: Literal["Male", "Female", "N/A"]) -> ColumnEl return Gender.gender == gender -class RaceFilterMixin: +class GenderAggregationMixin: """ - Mixin for filtering the view by the race. + Mixin for aggregating the view by the gender. """ - @view_filter() - def filter_by_race(self, race: str) -> ColumnElement: + @view_aggregation() + def percentage_of_gender(self, gender: Literal["Male", "Female", "N/A"]) -> Select: """ - Filters the view by the object race. + Calculates the percentage of objects with gender. Args: - race: The race of the object. + gender: The gender of the object. Returns: - The filter condition. - """ - return Race.race == race + The percentage of objects with gender. + """ + return self.data.with_only_columns( + ( + cast(func.count(case((Gender.gender == gender, Superhero.id), else_=None)), Float) + * 100 + / func.count(Superhero.id) + ).label(f"percentage_of_{gender.lower()}") + ) -class SuperheroAggregationMixin: +class RaceFilterMixin: """ - Mixin for aggregating the view by the superhero attributes. + Mixin for filtering the view by the race. """ - @view_aggregation() - def count_superheroes(self) -> Select: + @view_filter() + def filter_by_race(self, race: str) -> ColumnElement: """ - Counts the number of superheros. + Filters the view by the object race. + + Args: + race: The race of the object. Returns: - The superheros count. + The filter condition. """ - return self.data.with_only_columns(func.count(Superhero.id).label("count_superheroes")).group_by(Superhero.id) + return Race.race == race class SuperheroView( DBInitMixin, - SuperheroFilterMixin, SuperheroAggregationMixin, + SuperheroFilterMixin, + SuperheroColourAggregationMixin, SuperheroColourFilterMixin, + AlignmentAggregationMixin, AlignmentFilterMixin, + GenderAggregationMixin, GenderFilterMixin, + PublisherAggregationMixin, PublisherFilterMixin, RaceFilterMixin, SqlAlchemyBaseView, @@ -460,6 +567,12 @@ class SuperheroView( publisher name, gender, race, alignment, eye colour, hair colour, skin colour. """ + def __init__(self, *args, **kwargs) -> None: + self.eye_colour = aliased(Colour) + self.hair_colour = aliased(Colour) + self.skin_colour = aliased(Colour) + super().__init__(*args, **kwargs) + def get_select(self) -> Select: """ Initializes the select object for the view.