104 lines
4.1 KiB
Python
104 lines
4.1 KiB
Python
import logging
import operator
from typing import Type, List, Optional

from pydantic import BaseModel
from sqlalchemy import and_, or_
from sqlalchemy import asc, desc
from sqlalchemy.orm import Session
from sqlalchemy.orm.query import Query

from app.core.common import ApiReturnPage
from app.db import models
|
|
|
|
class crudbase:
    """Generic create/read/update/delete helper for one SQLAlchemy model.

    An instance is bound to a model class at construction time; every method
    then operates on that model through the ``Session`` supplied by the
    caller.
    """

    _logger = logging.getLogger(__name__)

    # Operators accepted in a ``{"operator": ..., "value": ...}`` filter spec
    # that translate to exactly one AND-ed condition.  ``in`` is not listed
    # here because its conditions are collected into a separate OR group.
    _BINARY_OPERATORS = {
        "!=": operator.ne,
        "=": operator.eq,
        ">": operator.gt,
        ">=": operator.ge,
        "<": operator.lt,
        "<=": operator.le,
        "like": lambda column, value: column.like(f"%{value}%"),
    }

    def __init__(self, model: Type[models.Base]):
        """Bind this helper to ``model``."""
        self.model = model

    @staticmethod
    def _commit(db: Session) -> None:
        """Commit ``db``; on failure roll back and re-raise the error.

        Rolling back keeps the session usable for the caller after a failed
        commit instead of leaving it inside a broken transaction.
        """
        try:
            db.commit()
        except Exception:
            db.rollback()
            raise

    def _apply_filters(self, query: Query, filters: dict) -> Query:
        """Return ``query`` narrowed by ``filters``.

        ``filters`` maps attribute names of the model to either

        * a plain value, meaning ``attribute == value``; or
        * a dict ``{"operator": op, "value": v}`` where ``op`` is one of
          ``!=``, ``=``, ``>``, ``>=``, ``<``, ``<=``, ``like`` (wrapped as
          ``%v%``) or ``in``.  A dict without an ``"operator"`` key is
          compared with ``==`` like a plain value.

        Names that are not attributes of the model are skipped.  An
        unrecognised operator is skipped as well, with a warning logged.

        All ``in`` conditions whose value is a list/tuple/set are OR-ed with
        each other, and that group is AND-ed with every other condition.
        An ``in`` with a non-collection value degrades to ``==``.

        Raises:
            KeyError: if a spec has ``"operator"`` but no ``"value"`` key.
        """
        if not filters:
            return query

        and_conditions = []
        or_conditions = []

        for column_name, spec in filters.items():
            column = getattr(self.model, column_name, None)
            # Compare against None explicitly: SQL expression objects may
            # refuse to be evaluated as a boolean.
            if column is None:
                continue

            if not (isinstance(spec, dict) and "operator" in spec):
                and_conditions.append(column == spec)
                continue

            op_name = spec["operator"]
            operand = spec["value"]

            if op_name == "in":
                if isinstance(operand, (list, tuple, set, frozenset)):
                    or_conditions.append(column.in_(list(operand)))
                else:
                    and_conditions.append(column == operand)
                continue

            build = (
                self._BINARY_OPERATORS.get(op_name)
                if isinstance(op_name, str)
                else None
            )
            if build is None:
                self._logger.warning(
                    "Ignoring filter on %r: unsupported operator %r",
                    column_name,
                    op_name,
                )
                continue
            and_conditions.append(build(column, operand))

        if and_conditions:
            query = query.filter(*and_conditions)
        if or_conditions:
            query = query.filter(or_(*or_conditions))
        return query

    def _apply_sorting(self, query: Query, sort_by: Optional[str], sort_order: Optional[str]) -> Query:
        """Return ``query`` ordered by the model attribute named ``sort_by``.

        The order is descending when ``sort_order`` equals ``"desc"``
        (case-insensitive, surrounding whitespace ignored) and ascending for
        any other value.  The query is returned unchanged when ``sort_by`` is
        empty or is not an attribute of the model.
        """
        if not sort_by:
            return query

        column = getattr(self.model, sort_by, None)
        if column is None:
            return query

        descending = (
            isinstance(sort_order, str)
            and sort_order.strip().lower() == "desc"
        )
        return query.order_by(desc(column) if descending else asc(column))

    def get_all(self, db: Session) -> Query:
        """Return an unfiltered query over the model."""
        return db.query(self.model)

    def get(self, db: Session, item_id: int) -> Optional[models.Base]:
        """Return the row whose primary key is ``item_id``, or ``None``."""
        # NOTE(review): Query.get is the legacy spelling; Session.get is the
        # replacement on newer SQLAlchemy - confirm the project's version
        # before switching.
        return db.query(self.model).get(item_id)

    def create(self, db: Session, obj_in: BaseModel) -> models.Base:
        """Insert a new row built from ``obj_in`` and return it refreshed.

        The model is instantiated with the fields of ``obj_in.model_dump()``.
        If the commit fails the session is rolled back and the error is
        re-raised.
        """
        db_obj = self.model(**obj_in.model_dump())
        db.add(db_obj)
        self._commit(db)
        db.refresh(db_obj)
        return db_obj

    def update(self, db: Session, item_id: int, obj_in: BaseModel) -> Optional[models.Base]:
        """Apply the explicitly-set fields of ``obj_in`` to row ``item_id``.

        Returns the refreshed row, or ``None`` when no such row exists.  If
        the commit fails the session is rolled back and the error is
        re-raised.
        """
        db_obj = self.get(db, item_id)
        if db_obj is None:
            return None

        # exclude_unset: only fields the caller actually provided are written.
        for key, value in obj_in.model_dump(exclude_unset=True).items():
            setattr(db_obj, key, value)

        self._commit(db)
        db.refresh(db_obj)
        return db_obj

    def delete(self, db: Session, item_id: int) -> Optional[models.Base]:
        """Delete row ``item_id`` and return the deleted object.

        Returns ``None`` when no such row exists.  If the commit fails the
        session is rolled back and the error is re-raised.
        """
        db_obj = self.get(db, item_id)
        if db_obj is None:
            return None

        db.delete(db_obj)
        self._commit(db)
        return db_obj

    def get_by_conditions(self, db: Session, filters: Optional[dict] = None, sort_by: Optional[str] = None,
                          sort_order: Optional[str] = "asc") -> Query:
        """Return a query over the model with optional filters and ordering.

        ``filters`` follows the format accepted by ``_apply_filters``;
        ``sort_by`` / ``sort_order`` follow ``_apply_sorting``.  The query is
        returned unexecuted so the caller can paginate or extend it.
        """
        query = db.query(self.model)
        if filters:
            query = self._apply_filters(query, filters)
        if sort_by:
            query = self._apply_sorting(query, sort_by, sort_order)
        # Lazy %-formatting: the SQL text is only rendered when DEBUG logging
        # is enabled.
        self._logger.debug("Built query: %s", query)
        return query