SQLAlchemy 1.0->SQLAlchemy 2.x
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from sqlalchemy import asc, desc
|
||||
from sqlalchemy import asc, desc, select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm.query import Query
|
||||
from typing import Type, List, Optional
|
||||
@@ -46,11 +46,12 @@ class crudbase:
|
||||
and_conditions.append(column == value)
|
||||
|
||||
if and_conditions:
|
||||
query = query.filter(*and_conditions)
|
||||
query = query.where(and_(*and_conditions))
|
||||
if or_conditions:
|
||||
query = query.filter(or_(*or_conditions))
|
||||
return query
|
||||
query = query.where(or_(*or_conditions))
|
||||
|
||||
return query
|
||||
|
||||
def _apply_sorting(self, query: Query, sort_by: Optional[str], sort_order: Optional[str]) -> Query:
|
||||
if sort_by:
|
||||
column = getattr(self.model, sort_by, None)
|
||||
@@ -61,12 +62,11 @@ class crudbase:
|
||||
query = query.order_by(asc(column))
|
||||
return query
|
||||
|
||||
def get_all(self, db: Session) -> Query:
|
||||
return db.query(self.model)
|
||||
def get_all(self) -> Query:
|
||||
return select(self.model)
|
||||
|
||||
|
||||
def get(self, db: Session, item_id: int) -> Optional[models.Base]:
|
||||
return db.query(self.model).get(item_id)
|
||||
return db.execute(select(self.model).filter(self.model.id == item_id)).scalar_one_or_none()
|
||||
|
||||
def create(self, db: Session, obj_in: BaseModel) -> models.Base:
|
||||
db_obj = self.model(**obj_in.model_dump())
|
||||
@@ -76,7 +76,7 @@ class crudbase:
|
||||
return db_obj
|
||||
|
||||
def update(self, db: Session, item_id: int, obj_in: BaseModel) -> Optional[models.Base]:
|
||||
db_obj = db.query(self.model).filter(self.model.id == item_id).first()
|
||||
db_obj = self.get(db,item_id)
|
||||
if db_obj:
|
||||
for key, value in obj_in.model_dump(exclude_unset=True).items():
|
||||
setattr(db_obj, key, value)
|
||||
@@ -86,16 +86,16 @@ class crudbase:
|
||||
return None
|
||||
|
||||
def delete(self, db: Session, item_id: int) -> Optional[models.Base]:
|
||||
db_obj = db.query(self.model).get(item_id)
|
||||
db_obj = self.get(db,item_id)
|
||||
if db_obj:
|
||||
db.delete(db_obj)
|
||||
db.commit()
|
||||
return db_obj
|
||||
return None
|
||||
|
||||
def get_by_conditions(self, db: Session, filters: Optional[dict] = None, sort_by: Optional[str] = None,
|
||||
def get_by_conditions(self, filters: Optional[dict] = None, sort_by: Optional[str] = None,
|
||||
sort_order: Optional[str] = "asc") -> Query:
|
||||
query = db.query(self.model)
|
||||
query = select(self.model)
|
||||
if filters:
|
||||
query = self._apply_filters(query, filters)
|
||||
if sort_by:
|
||||
|
||||
@@ -16,17 +16,17 @@ class dbuserdomain(crudbase):
|
||||
super().__init__(model=models.UserDomain)
|
||||
|
||||
def get_userdomain(self,db: Session,userid:int,domainid:int):
|
||||
return super().get_by_conditions(db,{"userid":userid,"domainid":domainid}).first()
|
||||
return db.execute(super().get_by_conditions({"userid":userid,"domainid":domainid})).scalars().first()
|
||||
|
||||
def get_userdomain_by_domainid(self,db: Session,ownerid:int,domainid:int):
|
||||
return super().get_by_conditions(db,{"domainid":domainid})
|
||||
return super().get_by_conditions({"domainid":domainid})
|
||||
|
||||
|
||||
def get_default_domains(self,db: Session,domainid:int):
|
||||
return super().get_by_conditions(db,{"domainid":domainid,"is_default":True}).all()
|
||||
return db.execute(super().get_by_conditions({"domainid":domainid,"is_default":True})).scalars().all()
|
||||
|
||||
def get_user_default_domain(self,db: Session,userid:int):
|
||||
return super().get_by_conditions(db,{"userid":userid,"is_default":True}).first()
|
||||
return db.execute(super().get_by_conditions({"userid":userid,"is_default":True})).scalars().first()
|
||||
|
||||
|
||||
dbuserdomain = dbuserdomain()
|
||||
@@ -36,10 +36,10 @@ class dbdomain(crudbase):
|
||||
super().__init__(model=models.Domain)
|
||||
|
||||
def get_domains(self,db: Session)-> ApiReturnPage[models.Base]:
|
||||
return paginate(super().get_all(db))
|
||||
return paginate(db,super().get_all())
|
||||
|
||||
def get_domains_by_owner(self,db: Session,ownerid:int)-> ApiReturnPage[models.Base]:
|
||||
return paginate( super().get_by_conditions(db,{"ownerid":ownerid}))
|
||||
return paginate(db,super().get_by_conditions({"ownerid":ownerid}))
|
||||
|
||||
def create_domain(self,db: Session, domain: schemas.DomainIn,userid:int):
|
||||
#db_domain = super().get_by_conditions(db,{"url":domain.url,"kintoneuser":domain.kintoneuser,"onwerid":userid}).first()
|
||||
@@ -79,8 +79,8 @@ class dbdomain(crudbase):
|
||||
return None
|
||||
|
||||
def add_userdomain(self,db: Session,ownerid:int,userid:int,domainid:int) -> schemas.DomainOut:
|
||||
db_domain = super().get_by_conditions(db,{"id":domainid,"is_active":True}).first()
|
||||
if db_domain:
|
||||
db_domain = super().get(db,domainid)
|
||||
if db_domain and db_domain.is_active:
|
||||
db_userdomain = dbuserdomain.get_userdomain(db,userid,domainid)
|
||||
if not db_userdomain:
|
||||
user_domain = models.UserDomain(userid = userid, domainid = domainid ,createuserid = ownerid,updateuserid = ownerid)
|
||||
|
||||
@@ -32,13 +32,13 @@ class dbuser(crudbase):
|
||||
return super().get(db,user_id)
|
||||
|
||||
def get_user_by_email(self,db: Session, email: str) -> schemas.User:
|
||||
return super().get_by_conditions(db,{"email":email}).first()
|
||||
return db.execute(super().get_by_conditions({"email":email})).scalars().first()
|
||||
|
||||
def get_users(self,db: Session) -> ApiReturnPage[models.Base]:
|
||||
return paginate(super().get_all(db))
|
||||
return paginate(db,super().get_all())
|
||||
|
||||
def get_users_not_admin(self,db: Session) -> ApiReturnPage[models.Base]:
|
||||
return paginate(super().get_by_conditions(db,{"is_superuser":False}))
|
||||
return paginate(db,super().get_by_conditions({"is_superuser":False}))
|
||||
|
||||
def create_user(self,db: Session, user: schemas.UserCreate,userid:int):
|
||||
hashed_password = get_password_hash(user.password)
|
||||
@@ -63,7 +63,7 @@ class dbuser(crudbase):
|
||||
return dbrole.get_all(db).all()
|
||||
|
||||
def get_roles_by_level(self,db: Session,level:int) -> t.List[schemas.RoleBase]:
|
||||
return dbrole.get_by_conditions(db,{"level":{"operator":">=","value":level}}).all()
|
||||
return db.execute(dbrole.get_by_conditions({"level":{"operator":">=","value":level}})).scalars().all()
|
||||
|
||||
def assign_userrole(self,db: Session, user_id: int, roles: t.List[int]):
|
||||
db_user = super().get(db,user_id)
|
||||
|
||||
Reference in New Issue
Block a user