diff --git a/asyncqlio/orm/operators.py b/asyncqlio/orm/operators.py index 27f70c4..4485602 100644 --- a/asyncqlio/orm/operators.py +++ b/asyncqlio/orm/operators.py @@ -246,6 +246,41 @@ def generate_sql(self, emitter: typing.Callable[[str], str]): return OperatorResponse(sql, params) +class Between(BaseOperator): + def __init__(self, column: 'md_column.Column', min: typing.Any, max: typing.Any): + self.column = column + self.min = min + self.max = max + + def generate_sql(self, emitter: typing.Callable[[str], str]): + # generate a dict of params + params = {} + l = [] + + for p in ['min', 'max']: + emitted, name = emitter() + params[name] = getattr(self, p) + l.append(emitted) + + sql = "{} BETWEEN {}".format(self.column.quoted_fullname, ' AND '.join(l)) + return OperatorResponse(sql, params) + + +class NotBetween(Between): + def generate_sql(self, emitter: typing.Callable[[str], str]): + # generate a dict of params + params = {} + l = [] + + for p in ['min', 'max']: + emitted, name = emitter() + params[name] = getattr(self, p) + l.append(emitted) + + sql = "{} NOT BETWEEN {}".format(self.column.quoted_fullname, ' AND '.join(l)) + return OperatorResponse(sql, params) + + class ComparisonOp(ColumnValueMixin, BaseOperator): """ A helper class that implements easy generation of comparison-based operators. diff --git a/asyncqlio/orm/schema/column.py b/asyncqlio/orm/schema/column.py index 0209ce5..8d2f8b5 100644 --- a/asyncqlio/orm/schema/column.py +++ b/asyncqlio/orm/schema/column.py @@ -311,6 +311,24 @@ def eq(self, other) -> 'md_operators.Eq': """ return md_operators.Eq(self, other) + def isin(self, other) -> 'md_operators.In': + """ + Checks if this column is in supplied list + """ + return md_operators.In(self, other) + + def between(self, min, max) -> 'md_operators.Between': + """ + Checks if this column is between the min and max value + """ + return md_operators.Between(self, min, max) + + def nbetween(self, min, max) -> 'md_operators.NotBetween': + """ + Checks if this column is not between the min and max value + """ + return md_operators.NotBetween(self, min, max) + def ne(self, other) -> 'md_operators.NEq': """ Checks if this column is not equal to something else. diff --git a/tests/test_3session.py b/tests/test_3session.py index bc250b8..64b1db3 100644 --- a/tests/test_3session.py +++ b/tests/test_3session.py @@ -124,3 +124,23 @@ async def test_numeric_decimal(db: DatabaseInterface, table: Table): assert str(res.lat) == "12.010" assert str(res.lon) == "12.01" + + +async def test_between(db: DatabaseInterface, table: Table): + test_ids = [210, 211, 212] + + async with db.get_session() as sess: + await sess.insert.rows(table(id=210, name="between0", email="0", lat=55.010, lon=55.010)).run() + await sess.insert.rows(table(id=211, name="between1", email="1", lat=55.110, lon=55.110)).run() + await sess.insert.rows(table(id=212, name="between2", email="2", lat=55.110, lon=55.110)).run() + + async with db.get_session() as sess: + res_first = await sess.select(table).where(table.id.between(200, 220)).first() + res_all = await sess.select(table).where(table.id.between(200, 220)).all() + res_not = await sess.select(table).where(table.id.nbetween(200, 220)).first() + + all = [r async for r in res_all] + + assert res_first.id == 210 + assert len(all) == 3 + assert res_not not in test_ids