170 lines
4.0 KiB
Python
170 lines
4.0 KiB
Python
import pytest
|
|
from sqlalchemy import create_engine, event
|
|
from sqlalchemy.orm import sessionmaker
|
|
from sqlalchemy_utils import database_exists, create_database, drop_database
|
|
from fastapi.testclient import TestClient
|
|
import typing as t
|
|
|
|
from app.core import config, security
|
|
from app.db.session import Base, get_db
|
|
from app.db import models
|
|
from app.main import app
|
|
|
|
|
|
def get_test_db_url() -> str:
    """
    Build the URL of the database used by the test suite.

    The URL is ``config.SQLALCHEMY_DATABASE_URI`` with ``_test`` appended.
    """
    return "{}_test".format(config.SQLALCHEMY_DATABASE_URI)
|
|
|
|
|
|
@pytest.fixture
def test_db():
    """
    Yield a db session whose changes are rolled back after each test.

    This is to avoid tests affecting the database state of other tests.

    A connection is opened on the test database and a parent transaction is
    started on it. The session is bound to that same connection, so all work
    done through the session happens inside the parent transaction and is
    discarded when that transaction is rolled back during teardown.
    """
    # Connect to the test database
    engine = create_engine(get_test_db_url())
    connection = engine.connect()

    # Parent transaction that is rolled back once the test is complete
    trans = connection.begin()

    # Bind the session to the connection (not the engine): a session bound to
    # the engine would check out its own connection and its work would fall
    # outside of `trans`, making the rollback below ineffective.
    test_session_maker = sessionmaker(
        autocommit=False, autoflush=False, bind=connection
    )
    test_session = test_session_maker()

    # Work inside a SAVEPOINT so code under test may call commit()/rollback()
    test_session.begin_nested()

    @event.listens_for(test_session, "after_transaction_end")
    def restart_savepoint(s, transaction):
        # Whenever the savepoint directly below the session's outermost
        # transaction ends, expire loaded state and open a fresh savepoint.
        if transaction.nested and not transaction._parent.nested:
            s.expire_all()
            s.begin_nested()

    try:
        yield test_session
    finally:
        # Roll back the parent transaction after the test is complete and
        # release every resource opened above, including the engine's pool.
        test_session.close()
        trans.rollback()
        connection.close()
        engine.dispose()
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
def create_test_db():
    """
    Create a test database and use it for the whole test session.

    The database at ``get_test_db_url()`` must not exist beforehand; the
    fixture refuses to run otherwise. All tables registered on
    ``Base.metadata`` are created in it, and the database is dropped again
    when the session ends.
    """
    test_db_url = get_test_db_url()

    # Refuse to touch a database this fixture did not create itself
    assert not database_exists(
        test_db_url
    ), "Test database already exists. Aborting tests."

    # Create the test database
    create_database(test_db_url)
    test_engine = create_engine(test_db_url)
    try:
        Base.metadata.create_all(test_engine)

        # Run the tests
        yield
    finally:
        # Drop the test database even when creating the tables failed;
        # otherwise the leftover database would make every later run abort
        # on the "already exists" check above.
        test_engine.dispose()
        drop_database(test_db_url)
|
|
|
|
|
|
@pytest.fixture
def client(test_db):
    """
    Get a TestClient instance that reads/writes to the test database.

    While the fixture is active, the app's ``get_db`` dependency is overridden
    with a generator that yields the ``test_db`` session. The previous state
    of ``app.dependency_overrides`` for ``get_db`` is restored afterwards so
    the override does not leak into later tests through the shared ``app``.
    """

    def get_test_db():
        yield test_db

    # Remember whether an override was already registered so it can be restored
    not_set = object()
    previous_override = app.dependency_overrides.get(get_db, not_set)
    app.dependency_overrides[get_db] = get_test_db

    try:
        yield TestClient(app)
    finally:
        if previous_override is not_set:
            app.dependency_overrides.pop(get_db, None)
        else:
            app.dependency_overrides[get_db] = previous_override
|
|
|
|
|
|
@pytest.fixture
def test_password() -> str:
    """
    Plain-text password made available to tests and other fixtures.
    """
    password = "securepassword"
    return password
|
|
|
|
|
|
def get_password_hash() -> str:
    """
    Return a fixed placeholder hash instead of hashing a real password.

    Password hashing can be expensive, so a constant stand-in is much faster.
    """
    fake_hash = "supersecrethash"
    return fake_hash
|
|
|
|
|
|
@pytest.fixture
def test_user(test_db) -> models.User:
    """
    Add a regular user (``is_active=True``) to the test database.

    The user is committed through the ``test_db`` session and returned.
    """
    fields = {
        "email": "fake@email.com",
        "hashed_password": get_password_hash(),
        "is_active": True,
    }
    user = models.User(**fields)

    test_db.add(user)
    test_db.commit()
    return user
|
|
|
|
|
|
@pytest.fixture
def test_superuser(test_db) -> models.User:
    """
    Add a superuser (``is_superuser=True``) to the test database.

    The user is committed through the ``test_db`` session and returned.
    """
    fields = {
        "email": "fakeadmin@email.com",
        "hashed_password": get_password_hash(),
        "is_superuser": True,
    }
    user = models.User(**fields)

    test_db.add(user)
    test_db.commit()
    return user
|
|
|
|
|
|
def verify_password_mock(first: str, second: str) -> bool:
    """
    Stand-in password check that accepts every pair of arguments.

    Both arguments are ignored; the result is always ``True``.
    """
    accepted = True
    return accepted
|
|
|
|
|
|
@pytest.fixture
def user_token_headers(
    client: TestClient, test_user, test_password, monkeypatch
) -> t.Dict[str, str]:
    """
    Log in as ``test_user`` and return the resulting Authorization header.

    ``security.verify_password`` is replaced with ``verify_password_mock``
    before the request, then the user's email and ``test_password`` are posted
    to ``/api/token``. The ``access_token`` from the JSON response is wrapped
    in a ``Bearer`` header.
    """
    monkeypatch.setattr(security, "verify_password", verify_password_mock)

    response = client.post(
        "/api/token",
        data={"username": test_user.email, "password": test_password},
    )
    access_token = response.json()["access_token"]
    return {"Authorization": f"Bearer {access_token}"}
|
|
|
|
|
|
@pytest.fixture
def superuser_token_headers(
    client: TestClient, test_superuser, test_password, monkeypatch
) -> t.Dict[str, str]:
    """
    Log in as ``test_superuser`` and return the resulting Authorization header.

    ``security.verify_password`` is replaced with ``verify_password_mock``
    before the request, then the superuser's email and ``test_password`` are
    posted to ``/api/token``. The ``access_token`` from the JSON response is
    wrapped in a ``Bearer`` header.
    """
    monkeypatch.setattr(security, "verify_password", verify_password_mock)

    response = client.post(
        "/api/token",
        data={"username": test_superuser.email, "password": test_password},
    )
    access_token = response.json()["access_token"]
    return {"Authorization": f"Bearer {access_token}"}
|