fix pytest

This commit is contained in:
2024-12-04 14:44:32 +09:00
parent c5c4f79e4f
commit 4f17a6952d
6 changed files with 150 additions and 181 deletions

View File

@@ -1,6 +1,5 @@
from sqlalchemy import Boolean, Column, Integer, String, DateTime,ForeignKey,Table
from sqlalchemy.ext.declarative import as_declarative
from sqlalchemy.orm import relationship
from sqlalchemy.orm import relationship,as_declarative
from datetime import datetime
from app.db.session import Base
from app.core.security import chacha20Decrypt

View File

@@ -1,6 +1,5 @@
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import sessionmaker,declarative_base
from app.core import config

View File

@@ -0,0 +1,145 @@
import pytest
from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker
from fastapi.testclient import TestClient
import typing as t
from app.core import config, security
from app.db.session import Base, get_db
from app.db import models
from app.main import app
def get_test_db_url() -> str:
return f"{config.SQLALCHEMY_DATABASE_URI}"
@pytest.fixture
def test_db():
"""
Modify the db session to automatically roll back after each test.
This is to avoid tests affecting the database state of other tests.
"""
# Connect to the test database
engine = create_engine(
get_test_db_url(),
)
connection = engine.connect()
trans = connection.begin()
# Run a parent transaction that can roll back all changes
test_session_maker = sessionmaker(
autocommit=False, autoflush=False, bind=engine
)
test_session = test_session_maker()
#test_session.begin_nested()
# @event.listens_for(test_session, "after_transaction_end")
# def restart_savepoint(s, transaction):
# if transaction.nested and not transaction._parent.nested:
# s.expire_all()
# s.begin_nested()
yield test_session
# Roll back the parent transaction after the test is complete
test_session.close()
trans.rollback()
connection.close()
@pytest.fixture(scope="function")
def test_client(test_db):
"""
Get a TestClient instance that reads/write to the test database.
"""
def get_test_db():
yield test_db
app.dependency_overrides[get_db] = get_test_db
with TestClient(app) as test_client:
yield test_client
# @pytest.fixture
# def test_password() -> str:
# return "securepassword"
# def get_password_hash() -> str:
# """
# Password hashing can be expensive so a mock will be much faster
# """
# return "supersecrethash"
# @pytest.fixture
# def test_user(test_db) -> models.User:
# """
# Make a test user in the database
# """
# user = models.User(
# email="fake@email.com",
# hashed_password=get_password_hash(),
# is_active=True,
# )
# test_db.add(user)
# test_db.commit()
# return user
# @pytest.fixture
# def test_superuser(test_db) -> models.User:
# """
# Superuser for testing
# """
# user = models.User(
# email="fakeadmin@email.com",
# hashed_password=get_password_hash(),
# is_superuser=True,
# )
# test_db.add(user)
# test_db.commit()
# return user
# def verify_password_mock(first: str, second: str) -> bool:
# return True
# @pytest.fixture
# def user_token_headers(
# client: TestClient, test_user, test_password, monkeypatch
# ) -> t.Dict[str, str]:
# monkeypatch.setattr(security, "verify_password", verify_password_mock)
# login_data = {
# "username": test_user.email,
# "password": test_password,
# }
# r = client.post("/api/token", data=login_data)
# tokens = r.json()
# a_token = tokens["access_token"]
# headers = {"Authorization": f"Bearer {a_token}"}
# return headers
# @pytest.fixture
# def superuser_token_headers(
# client: TestClient, test_superuser, test_password, monkeypatch
# ) -> t.Dict[str, str]:
# monkeypatch.setattr(security, "verify_password", verify_password_mock)
# login_data = {
# "username": test_superuser.email,
# "password": test_password,
# }
# r = client.post("/api/token", data=login_data)
# tokens = r.json()
# a_token = tokens["access_token"]
# headers = {"Authorization": f"Bearer {a_token}"}
# return headers

View File

@@ -1,11 +1,6 @@
import pytest
from fastapi.testclient import TestClient
from app.main import app
client = TestClient(app)
def test_read_main():
response = client.get("/api/v1")
def test_read_main(test_client):
response = test_client.get("/api/v1")
assert response.status_code == 200
assert response.json() == {"message": "Hello World"}
assert response.json() == {"message": "success"}

View File

@@ -1,169 +0,0 @@
import pytest
from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker
from sqlalchemy_utils import database_exists, create_database, drop_database
from fastapi.testclient import TestClient
import typing as t
from app.core import config, security
from app.db.session import Base, get_db
from app.db import models
from app.main import app
def get_test_db_url() -> str:
return f"{config.SQLALCHEMY_DATABASE_URI}_test"
@pytest.fixture
def test_db():
"""
Modify the db session to automatically roll back after each test.
This is to avoid tests affecting the database state of other tests.
"""
# Connect to the test database
engine = create_engine(
get_test_db_url(),
)
connection = engine.connect()
trans = connection.begin()
# Run a parent transaction that can roll back all changes
test_session_maker = sessionmaker(
autocommit=False, autoflush=False, bind=engine
)
test_session = test_session_maker()
test_session.begin_nested()
@event.listens_for(test_session, "after_transaction_end")
def restart_savepoint(s, transaction):
if transaction.nested and not transaction._parent.nested:
s.expire_all()
s.begin_nested()
yield test_session
# Roll back the parent transaction after the test is complete
test_session.close()
trans.rollback()
connection.close()
@pytest.fixture(scope="session", autouse=True)
def create_test_db():
"""
Create a test database and use it for the whole test session.
"""
test_db_url = get_test_db_url()
# Create the test database
assert not database_exists(
test_db_url
), "Test database already exists. Aborting tests."
create_database(test_db_url)
test_engine = create_engine(test_db_url)
Base.metadata.create_all(test_engine)
# Run the tests
yield
# Drop the test database
drop_database(test_db_url)
@pytest.fixture
def client(test_db):
"""
Get a TestClient instance that reads/write to the test database.
"""
def get_test_db():
yield test_db
app.dependency_overrides[get_db] = get_test_db
yield TestClient(app)
@pytest.fixture
def test_password() -> str:
return "securepassword"
def get_password_hash() -> str:
"""
Password hashing can be expensive so a mock will be much faster
"""
return "supersecrethash"
@pytest.fixture
def test_user(test_db) -> models.User:
"""
Make a test user in the database
"""
user = models.User(
email="fake@email.com",
hashed_password=get_password_hash(),
is_active=True,
)
test_db.add(user)
test_db.commit()
return user
@pytest.fixture
def test_superuser(test_db) -> models.User:
"""
Superuser for testing
"""
user = models.User(
email="fakeadmin@email.com",
hashed_password=get_password_hash(),
is_superuser=True,
)
test_db.add(user)
test_db.commit()
return user
def verify_password_mock(first: str, second: str) -> bool:
return True
@pytest.fixture
def user_token_headers(
client: TestClient, test_user, test_password, monkeypatch
) -> t.Dict[str, str]:
monkeypatch.setattr(security, "verify_password", verify_password_mock)
login_data = {
"username": test_user.email,
"password": test_password,
}
r = client.post("/api/token", data=login_data)
tokens = r.json()
a_token = tokens["access_token"]
headers = {"Authorization": f"Bearer {a_token}"}
return headers
@pytest.fixture
def superuser_token_headers(
client: TestClient, test_superuser, test_password, monkeypatch
) -> t.Dict[str, str]:
monkeypatch.setattr(security, "verify_password", verify_password_mock)
login_data = {
"username": test_superuser.email,
"password": test_password,
}
r = client.post("/api/token", data=login_data)
tokens = r.json()
a_token = tokens["access_token"]
headers = {"Authorization": f"Bearer {a_token}"}
return headers

Binary file not shown.