add tenant logic

This commit is contained in:
2024-12-15 12:28:22 +09:00
parent 2823364148
commit 39775a5179
16 changed files with 138 additions and 55 deletions

View File

@@ -1,26 +1,35 @@
from fastapi.security import OAuth2PasswordRequestForm
from fastapi import Form
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from datetime import timedelta from datetime import timedelta
from app.db.session import get_db
from app.core import security
from app.core.auth import authenticate_user, sign_up_new_user from app.core.auth import authenticate_user, sign_up_new_user
from app.core import security,tenantCacheService
from app.core.dbmanager import get_db
from sqlalchemy.orm import Session
auth_router = r = APIRouter() auth_router = r = APIRouter()
@r.post("/token") @r.post("/token")
async def login( async def login(db:Session= Depends(get_db) ,form_data: OAuth2PasswordRequestForm = Depends()):
db=Depends(get_db), form_data: OAuth2PasswordRequestForm = Depends() if not db :
):
user = authenticate_user(db, form_data.username, form_data.password)
if not user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password", detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
user = authenticate_user(db, form_data.username, form_data.password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="abcIncorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
access_token_expires = timedelta( access_token_expires = timedelta(
minutes=security.ACCESS_TOKEN_EXPIRE_MINUTES minutes=security.ACCESS_TOKEN_EXPIRE_MINUTES
) )
@@ -33,7 +42,7 @@ async def login(
permissions =";".join(list(set(perlst))) permissions =";".join(list(set(perlst)))
access_token = security.create_access_token( access_token = security.create_access_token(
data={"sub": user.id, "roles":roles,"permissions": permissions ,}, data={"sub": user.id,"roles":roles,"permissions": permissions,"tenant":user.tenantid,},
expires_delta=access_token_expires, expires_delta=access_token_expires,
) )

View File

@@ -8,7 +8,7 @@ import deepdiff
import app.core.config as config import app.core.config as config
import os import os
from pathlib import Path from pathlib import Path
from app.db.session import SessionLocal from app.core.dbmanager import get_db
from app.db.crud import get_flows_by_app,get_kintoneformat from app.db.crud import get_flows_by_app,get_kintoneformat
from app.core.auth import get_current_active_user,get_current_user from app.core.auth import get_current_active_user,get_current_user
from app.core.apiexception import APIException from app.core.apiexception import APIException
@@ -17,15 +17,15 @@ from app.db.cruddb import domainService
kinton_router = r = APIRouter() kinton_router = r = APIRouter()
def getkintoneenv(user = Depends(get_current_user)): def getkintoneenv(user = Depends(get_current_user)):
db = SessionLocal() db = get_db(user.tenantid) #SessionLocal()
domain = domainService.get_default_domain(db,user.id) #get_activedomain(db, user.id) domain = domainService.get_default_domain(db,user.id) #get_activedomain(db, user.id)
db.close() db.close()
kintoneevn = config.KINTONE_ENV(domain) kintoneevn = config.KINTONE_ENV(domain)
return kintoneevn return kintoneevn
def getkintoneformat(): def getkintoneformat(user = Depends(get_current_user)):
db = SessionLocal() db = get_db(user.tenantid)#SessionLocal()
formats = get_kintoneformat(db) formats = get_kintoneformat(db)
db.close() db.close()
return formats return formats

View File

@@ -2,8 +2,8 @@ from http import HTTPStatus
from fastapi import Query, Request,Depends, APIRouter, UploadFile,HTTPException,File from fastapi import Query, Request,Depends, APIRouter, UploadFile,HTTPException,File
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
# from app.core.operation import log_operation # from app.core.operation import log_operation
from app.db import Base,engine # from app.db import Base,engine
from app.db.session import get_db from app.core.dbmanager import get_db
from app.db.crud import * from app.db.crud import *
from app.db.schemas import * from app.db.schemas import *
from typing import List, Optional from typing import List, Optional
@@ -15,7 +15,7 @@ from app.db.cruddb import domainService,appService
import httpx import httpx
import app.core.config as config import app.core.config as config
from app.core import domainCacheService from app.core import domainCacheService,tenantCacheService
platform_router = r = APIRouter() platform_router = r = APIRouter()

View File

@@ -2,7 +2,7 @@ from fastapi import APIRouter, Request, Depends, Response, Security, encoders
import typing as t import typing as t
from app.core.common import ApiReturnModel,ApiReturnPage from app.core.common import ApiReturnModel,ApiReturnPage
from app.core.apiexception import APIException from app.core.apiexception import APIException
from app.db.session import get_db from app.core.dbmanager import get_db
from app.db.crud import ( from app.db.crud import (
get_allusers, get_allusers,
get_users, get_users,
@@ -16,6 +16,7 @@ from app.db.crud import (
from app.db.schemas import UserCreate, UserEdit, User, UserOut,RoleBase,Permission from app.db.schemas import UserCreate, UserEdit, User, UserOut,RoleBase,Permission
from app.core.auth import get_current_user,get_current_active_user, get_current_active_superuser from app.core.auth import get_current_user,get_current_active_user, get_current_active_superuser
from app.db.cruddb import userService from app.db.cruddb import userService
from app.core import tenantCacheService
users_router = r = APIRouter() users_router = r = APIRouter()

View File

@@ -1 +1,2 @@
from app.core.cache import domainCacheService from app.core.cache import domainCacheService
from app.core.cache import tenantCacheService

View File

@@ -1,7 +1,7 @@
from fastapi import HTTPException, status from fastapi import HTTPException, status,Depends
import httpx import httpx
from app.db.schemas import ErrorCreate from app.db.schemas import ErrorCreate
from app.db.session import SessionLocal from app.db.session import get_tenant_db
from app.db.crud import create_log from app.db.crud import create_log
class APIException(Exception): class APIException(Exception):
@@ -31,9 +31,9 @@ class APIException(Exception):
self.error = ErrorCreate(location=location, title=title, content=content) self.error = ErrorCreate(location=location, title=title, content=content)
super().__init__(self.error) super().__init__(self.error)
def writedblog(exc: APIException): def writedblog(exc: APIException,db = Depends(get_tenant_db())):
db = SessionLocal() #db = SessionLocal()
try: #try:
create_log(db,exc.error) create_log(db,exc.error)
finally: #finally:
db.close() #db.close()

View File

@@ -3,13 +3,14 @@ import jwt
from fastapi import Depends, HTTPException, Request, Security, status from fastapi import Depends, HTTPException, Request, Security, status
from jwt import PyJWTError from jwt import PyJWTError
from app.db import models, schemas, session from app.db import models, schemas
from app.db.crud import get_user_by_email, create_user,get_user from app.db.crud import get_user_by_email, create_user,get_user
from app.core import security from app.core import security
from app.db.cruddb import userService from app.db.cruddb import userService
from app.core.dbmanager import get_db
async def get_current_user(security_scopes: SecurityScopes, async def get_current_user(security_scopes: SecurityScopes,
db=Depends(session.get_db), token: str = Depends(security.oauth2_scheme) db=Depends(get_db), token: str = Depends(security.oauth2_scheme)
): ):
credentials_exception = HTTPException( credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@@ -17,7 +18,6 @@ async def get_current_user(security_scopes: SecurityScopes,
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
try: try:
payload = jwt.decode( payload = jwt.decode(
token, security.SECRET_KEY, algorithms=[security.ALGORITHM] token, security.SECRET_KEY, algorithms=[security.ALGORITHM]
) )
@@ -25,6 +25,10 @@ async def get_current_user(security_scopes: SecurityScopes,
if id is None: if id is None:
raise credentials_exception raise credentials_exception
tenant:str = payload.get("tenant")
if tenant is None:
raise credentials_exception
permissions: str = payload.get("permissions") permissions: str = payload.get("permissions")
if not permissions =="ALL": if not permissions =="ALL":
for scope in security_scopes.scopes: for scope in security_scopes.scopes:
@@ -59,11 +63,11 @@ async def get_current_active_superuser(
def authenticate_user(db, email: str, password: str): def authenticate_user(db, email: str, password: str):
user = get_user_by_email(db, email) user = userService.get_user_by_email(db,email) #get_user_by_email(db, email)
if not user: if not user:
return False return None
if not security.verify_password(password, user.hashed_password): if not security.verify_password(password, user.hashed_password):
return False return None
return user return user

View File

@@ -2,6 +2,7 @@ import time
from typing import Any from typing import Any
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db.cruddb import domainService from app.db.cruddb import domainService
from app.db.cruddb import tenantService
class MemoryCache: class MemoryCache:
def __init__(self, max_cache_size: int = 100, ttl: int = 60): def __init__(self, max_cache_size: int = 100, ttl: int = 60):
@@ -53,3 +54,18 @@ class domainCache:
self.memoryCache.clear() self.memoryCache.clear()
domainCacheService =domainCache() domainCacheService =domainCache()
class tenantCache:
def __init__(self):
self.memoryCache = MemoryCache(max_cache_size=50, ttl=120)
def get_tenant_db(self,db: Session, tenantid: str):
if not self.memoryCache.get(f"TENANT_{tenantid}"):
tenant = tenantService.get_tenant(db,tenantid)
if tenant:
self.memoryCache.set(f"TENANT_{tenantid}",tenant.db)
return self.memoryCache.get(f"TENANT_{tenantid}")
tenantCacheService =tenantCache()

View File

@@ -0,0 +1,12 @@
from fastapi import Depends
from app.db.session import get_tenant_db,get_user_db
from app.core import tenantCacheService
def get_db(tenant:str = "1",tenantdb = Depends(get_tenant_db)):
db_url = tenantCacheService.get_tenant_db(tenantdb,tenant)
db = get_user_db(db_url)
try:
yield db
finally:
db.close()

View File

@@ -1,3 +1,4 @@
from app.db.cruddb.dbuser import userService from app.db.cruddb.dbuser import userService
from app.db.cruddb.dbdomain import domainService from app.db.cruddb.dbdomain import domainService
from app.db.cruddb.dbapp import appService from app.db.cruddb.dbapp import appService
from app.db.cruddb.dbtenant import tenantService

View File

@@ -0,0 +1,13 @@
from app.db.cruddb.crudbase import crudbase
from app.db import models, schemas
from sqlalchemy.orm import Session
class dbtenant(crudbase):
def __init__(self):
super().__init__(model=models.Tenant)
def get_tenant(sefl,db:Session,tenantid: str):
tenant = db.execute(super().get_by_conditions({"tenantid":tenantid})).scalars().first()
return tenant
tenantService = dbtenant()

View File

@@ -1,7 +1,7 @@
from sqlalchemy import Boolean, Column, Integer, String, DateTime,ForeignKey,Table from sqlalchemy import Boolean, Column, Integer, String, DateTime,ForeignKey,Table
from sqlalchemy.orm import Mapped,relationship,as_declarative,mapped_column from sqlalchemy.orm import Mapped,relationship,as_declarative,mapped_column
from datetime import datetime from datetime import datetime
from app.db.session import Base from app.db import Base
from app.core.security import chacha20Decrypt from app.core.security import chacha20Decrypt
@as_declarative() @as_declarative()
@@ -34,6 +34,7 @@ class User(Base):
hashed_password = mapped_column(String(200), nullable=False) hashed_password = mapped_column(String(200), nullable=False)
is_active = mapped_column(Boolean, default=True) is_active = mapped_column(Boolean, default=True)
is_superuser = mapped_column(Boolean, default=False) is_superuser = mapped_column(Boolean, default=False)
tenantid = mapped_column(String(100))
createuserid = mapped_column(Integer,ForeignKey("user.id")) createuserid = mapped_column(Integer,ForeignKey("user.id"))
updateuserid = mapped_column(Integer,ForeignKey("user.id")) updateuserid = mapped_column(Integer,ForeignKey("user.id"))
createuser = relationship('User',foreign_keys=[createuserid]) createuser = relationship('User',foreign_keys=[createuserid])
@@ -148,6 +149,7 @@ class Tenant(Base):
licence = mapped_column(String(200)) licence = mapped_column(String(200))
startdate = mapped_column(DateTime) startdate = mapped_column(DateTime)
enddate = mapped_column(DateTime) enddate = mapped_column(DateTime)
db = mapped_column(String(200))
class Domain(Base): class Domain(Base):

View File

@@ -51,6 +51,7 @@ class UserCreate(UserBase):
last_name: str last_name: str
is_active:bool is_active:bool
is_superuser:bool is_superuser:bool
tenantid:t.Optional[str] = "1"
createuserid:t.Optional[int] = None createuserid:t.Optional[int] = None
updateuserid:t.Optional[int] = None updateuserid:t.Optional[int] = None

View File

@@ -1,20 +1,37 @@
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker,declarative_base from sqlalchemy.orm import sessionmaker, declarative_base, Session
from app.core import config from app.core import config
engine = create_engine(
config.SQLALCHEMY_DATABASE_URI,
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base() Base = declarative_base()
# engine = create_engine(
# config.SQLALCHEMY_DATABASE_URI,
# )
# SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Dependency class Database:
def get_db(): def __init__(self, database_url: str):
db = SessionLocal() self.database_url = database_url
self.engine = create_engine(self.database_url)
self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
self.Base = declarative_base()
def get_db(self):
db =self.SessionLocal()
return db
tenantdb = Database(config.SQLALCHEMY_DATABASE_URI)
def get_tenant_db():
db = tenantdb.get_db()
try: try:
yield db yield db
finally: finally:
db.close() db.close()
def get_user_db(database_url: str):
database = Database(database_url)
db = database.get_db()
return db

View File

@@ -8,7 +8,7 @@ from app.api.api_v1.routers.users import users_router
from app.api.api_v1.routers.auth import auth_router from app.api.api_v1.routers.auth import auth_router
from app.api.api_v1.routers.platform import platform_router from app.api.api_v1.routers.platform import platform_router
from app.core import config from app.core import config
from app.db import Base,engine #from app.db import Base,engine
from app.core.auth import get_current_active_user from app.core.auth import get_current_active_user
from app.core.celery_app import celery_app from app.core.celery_app import celery_app
from app import tasks from app import tasks
@@ -22,7 +22,7 @@ import asyncio
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
Base.metadata.create_all(bind=engine) #Base.metadata.create_all(bind=engine)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):

View File

@@ -5,7 +5,7 @@ from fastapi.testclient import TestClient
import typing as t import typing as t
from app.core import config, security from app.core import config, security
from app.db.session import Base, get_db from app.core.dbmanager import get_db
from app.db import models,schemas from app.db import models,schemas
from app.main import app from app.main import app
@@ -42,9 +42,12 @@ def test_client(test_db):
with TestClient(app) as test_client: with TestClient(app) as test_client:
yield test_client yield test_client
@pytest.fixture(scope="session")
def test_tenant_id():
return "1"
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def test_user(test_db): def test_user(test_db,test_tenant_id):
password ="test" password ="test"
user = models.User( user = models.User(
email = "test@test.com", email = "test@test.com",
@@ -52,7 +55,8 @@ def test_user(test_db):
last_name = "abc", last_name = "abc",
hashed_password = security.get_password_hash(password), hashed_password = security.get_password_hash(password),
is_active = True, is_active = True,
is_superuser = False is_superuser = False,
tenantid = test_tenant_id
) )
test_db.add(user) test_db.add(user)
test_db.commit() test_db.commit()
@@ -66,26 +70,28 @@ def password():
return "password" return "password"
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def user(password): def user(password,test_tenant_id):
user = models.User( user = models.User(
email = "user@test.com", email = "user@test.com",
first_name = "user", first_name = "user",
last_name = "abc", last_name = "abc",
hashed_password = security.get_password_hash(password), hashed_password = security.get_password_hash(password),
is_active = True, is_active = True,
is_superuser = False is_superuser = False,
tenantid = test_tenant_id
) )
return user return user
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def admin(password): def admin(password,test_tenant_id):
user = models.User( user = models.User(
email = "admin@test.com", email = "admin@test.com",
first_name = "admin", first_name = "admin",
last_name = "abc", last_name = "abc",
hashed_password = security.get_password_hash(password), hashed_password = security.get_password_hash(password),
is_active = True, is_active = True,
is_superuser = True is_superuser = True,
tenantid =test_tenant_id
) )
return user return user