fix pytest

This commit is contained in:
2024-12-04 14:44:32 +09:00
parent c5c4f79e4f
commit 4f17a6952d
6 changed files with 150 additions and 181 deletions

View File

@@ -1,6 +1,5 @@
from sqlalchemy import Boolean, Column, Integer, String, DateTime,ForeignKey,Table from sqlalchemy import Boolean, Column, Integer, String, DateTime,ForeignKey,Table
from sqlalchemy.ext.declarative import as_declarative from sqlalchemy.orm import relationship,as_declarative
from sqlalchemy.orm import relationship
from datetime import datetime from datetime import datetime
from app.db.session import Base from app.db.session import Base
from app.core.security import chacha20Decrypt from app.core.security import chacha20Decrypt

View File

@@ -1,6 +1,5 @@
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker,declarative_base
from sqlalchemy.orm import sessionmaker
from app.core import config from app.core import config

View File

@@ -0,0 +1,145 @@
import pytest
from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker
from fastapi.testclient import TestClient
import typing as t
from app.core import config, security
from app.db.session import Base, get_db
from app.db import models
from app.main import app
def get_test_db_url() -> str:
return f"{config.SQLALCHEMY_DATABASE_URI}"
@pytest.fixture
def test_db():
"""
Modify the db session to automatically roll back after each test.
This is to avoid tests affecting the database state of other tests.
"""
# Connect to the test database
engine = create_engine(
get_test_db_url(),
)
connection = engine.connect()
trans = connection.begin()
# Run a parent transaction that can roll back all changes
test_session_maker = sessionmaker(
autocommit=False, autoflush=False, bind=engine
)
test_session = test_session_maker()
#test_session.begin_nested()
# @event.listens_for(test_session, "after_transaction_end")
# def restart_savepoint(s, transaction):
# if transaction.nested and not transaction._parent.nested:
# s.expire_all()
# s.begin_nested()
yield test_session
# Roll back the parent transaction after the test is complete
test_session.close()
trans.rollback()
connection.close()
@pytest.fixture(scope="function")
def test_client(test_db):
"""
Get a TestClient instance that reads/write to the test database.
"""
def get_test_db():
yield test_db
app.dependency_overrides[get_db] = get_test_db
with TestClient(app) as test_client:
yield test_client
# @pytest.fixture
# def test_password() -> str:
# return "securepassword"
# def get_password_hash() -> str:
# """
# Password hashing can be expensive so a mock will be much faster
# """
# return "supersecrethash"
# @pytest.fixture
# def test_user(test_db) -> models.User:
# """
# Make a test user in the database
# """
# user = models.User(
# email="fake@email.com",
# hashed_password=get_password_hash(),
# is_active=True,
# )
# test_db.add(user)
# test_db.commit()
# return user
# @pytest.fixture
# def test_superuser(test_db) -> models.User:
# """
# Superuser for testing
# """
# user = models.User(
# email="fakeadmin@email.com",
# hashed_password=get_password_hash(),
# is_superuser=True,
# )
# test_db.add(user)
# test_db.commit()
# return user
# def verify_password_mock(first: str, second: str) -> bool:
# return True
# @pytest.fixture
# def user_token_headers(
# client: TestClient, test_user, test_password, monkeypatch
# ) -> t.Dict[str, str]:
# monkeypatch.setattr(security, "verify_password", verify_password_mock)
# login_data = {
# "username": test_user.email,
# "password": test_password,
# }
# r = client.post("/api/token", data=login_data)
# tokens = r.json()
# a_token = tokens["access_token"]
# headers = {"Authorization": f"Bearer {a_token}"}
# return headers
# @pytest.fixture
# def superuser_token_headers(
# client: TestClient, test_superuser, test_password, monkeypatch
# ) -> t.Dict[str, str]:
# monkeypatch.setattr(security, "verify_password", verify_password_mock)
# login_data = {
# "username": test_superuser.email,
# "password": test_password,
# }
# r = client.post("/api/token", data=login_data)
# tokens = r.json()
# a_token = tokens["access_token"]
# headers = {"Authorization": f"Bearer {a_token}"}
# return headers

View File

@@ -1,11 +1,6 @@
import pytest
from fastapi.testclient import TestClient
from app.main import app
client = TestClient(app) def test_read_main(test_client):
response = test_client.get("/api/v1")
def test_read_main():
response = client.get("/api/v1")
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == {"message": "Hello World"} assert response.json() == {"message": "success"}

View File

@@ -1,169 +0,0 @@
import pytest
from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker
from sqlalchemy_utils import database_exists, create_database, drop_database
from fastapi.testclient import TestClient
import typing as t
from app.core import config, security
from app.db.session import Base, get_db
from app.db import models
from app.main import app
def get_test_db_url() -> str:
return f"{config.SQLALCHEMY_DATABASE_URI}_test"
@pytest.fixture
def test_db():
"""
Modify the db session to automatically roll back after each test.
This is to avoid tests affecting the database state of other tests.
"""
# Connect to the test database
engine = create_engine(
get_test_db_url(),
)
connection = engine.connect()
trans = connection.begin()
# Run a parent transaction that can roll back all changes
test_session_maker = sessionmaker(
autocommit=False, autoflush=False, bind=engine
)
test_session = test_session_maker()
test_session.begin_nested()
@event.listens_for(test_session, "after_transaction_end")
def restart_savepoint(s, transaction):
if transaction.nested and not transaction._parent.nested:
s.expire_all()
s.begin_nested()
yield test_session
# Roll back the parent transaction after the test is complete
test_session.close()
trans.rollback()
connection.close()
@pytest.fixture(scope="session", autouse=True)
def create_test_db():
"""
Create a test database and use it for the whole test session.
"""
test_db_url = get_test_db_url()
# Create the test database
assert not database_exists(
test_db_url
), "Test database already exists. Aborting tests."
create_database(test_db_url)
test_engine = create_engine(test_db_url)
Base.metadata.create_all(test_engine)
# Run the tests
yield
# Drop the test database
drop_database(test_db_url)
@pytest.fixture
def client(test_db):
"""
Get a TestClient instance that reads/write to the test database.
"""
def get_test_db():
yield test_db
app.dependency_overrides[get_db] = get_test_db
yield TestClient(app)
@pytest.fixture
def test_password() -> str:
return "securepassword"
def get_password_hash() -> str:
"""
Password hashing can be expensive so a mock will be much faster
"""
return "supersecrethash"
@pytest.fixture
def test_user(test_db) -> models.User:
"""
Make a test user in the database
"""
user = models.User(
email="fake@email.com",
hashed_password=get_password_hash(),
is_active=True,
)
test_db.add(user)
test_db.commit()
return user
@pytest.fixture
def test_superuser(test_db) -> models.User:
"""
Superuser for testing
"""
user = models.User(
email="fakeadmin@email.com",
hashed_password=get_password_hash(),
is_superuser=True,
)
test_db.add(user)
test_db.commit()
return user
def verify_password_mock(first: str, second: str) -> bool:
return True
@pytest.fixture
def user_token_headers(
client: TestClient, test_user, test_password, monkeypatch
) -> t.Dict[str, str]:
monkeypatch.setattr(security, "verify_password", verify_password_mock)
login_data = {
"username": test_user.email,
"password": test_password,
}
r = client.post("/api/token", data=login_data)
tokens = r.json()
a_token = tokens["access_token"]
headers = {"Authorization": f"Bearer {a_token}"}
return headers
@pytest.fixture
def superuser_token_headers(
client: TestClient, test_superuser, test_password, monkeypatch
) -> t.Dict[str, str]:
monkeypatch.setattr(security, "verify_password", verify_password_mock)
login_data = {
"username": test_superuser.email,
"password": test_password,
}
r = client.post("/api/token", data=login_data)
tokens = r.json()
a_token = tokens["access_token"]
headers = {"Authorization": f"Bearer {a_token}"}
return headers

Binary file not shown.