import pytest
from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker
from fastapi.testclient import TestClient
import typing as t

from app.core import config, security
from app.db.session import Base, get_db
from app.db import models
from app.main import app


def get_test_db_url() -> str:
    """Return the database URL used by the test fixtures, as a string."""
    # NOTE(review): this is the application's configured URI unchanged, so the
    # tests connect to the same database the app uses (no separate "_test"
    # database). Isolation relies entirely on the rollback in ``test_db`` —
    # confirm this is intended.
    return f"{config.SQLALCHEMY_DATABASE_URI}"


@pytest.fixture
def test_db():
    """
    Modify the db session to automatically roll back after each test.
    This is to avoid tests affecting the database state of other tests.

    The yielded session is bound to a single connection on which an outer
    transaction has already been started. Rolling that transaction back at
    teardown discards everything the test wrote through the session.
    """
    # Connect to the test database
    engine = create_engine(get_test_db_url())
    connection = engine.connect()
    # Run a parent transaction that can roll back all changes
    trans = connection.begin()

    # Bind the session to *this* connection, not to the engine. Binding to the
    # engine would make the session check out a different pooled connection,
    # whose work happens outside ``trans`` and would never be rolled back.
    test_session_maker = sessionmaker(
        autocommit=False, autoflush=False, bind=connection
    )
    test_session = test_session_maker()

    # If tests need to call ``session.rollback()`` themselves, the session must
    # also run inside a SAVEPOINT that is restarted after each rollback:
    # test_session.begin_nested()
    # @event.listens_for(test_session, "after_transaction_end")
    # def restart_savepoint(s, transaction):
    #     if transaction.nested and not transaction._parent.nested:
    #         s.expire_all()
    #         s.begin_nested()

    try:
        yield test_session
    finally:
        # Roll back the parent transaction after the test is complete
        test_session.close()
        trans.rollback()
        connection.close()
        # The engine is private to this fixture; release its connection pool.
        engine.dispose()


@pytest.fixture(scope="function")
def test_client(test_db):
    """
    Get a TestClient instance that reads from / writes to the test database.

    Overrides the app's ``get_db`` dependency so every request handled during
    the test uses the ``test_db`` session. The override is removed at teardown
    so it cannot leak into later tests.
    """

    def get_test_db():
        yield test_db

    app.dependency_overrides[get_db] = get_test_db
    try:
        with TestClient(app) as client:
            yield client
    finally:
        # Drop only our own override; leave any others in place.
        app.dependency_overrides.pop(get_db, None)


# NOTE: the fixtures below are disabled. If re-enabled, they rely on the
# ``test_client`` fixture defined above (the original draft referred to a
# non-existent ``client`` fixture).

# @pytest.fixture
# def test_password() -> str:
#     return "securepassword"


# def get_password_hash() -> str:
#     """
#     Password hashing can be expensive so a mock will be much faster
#     """
#     return "supersecrethash"


# @pytest.fixture
# def test_user(test_db) -> models.User:
#     """
#     Make a test user in the database
#     """
#     user = models.User(
#         email="fake@email.com",
#         hashed_password=get_password_hash(),
#         is_active=True,
#     )
#     test_db.add(user)
#     test_db.commit()
#     return user


# @pytest.fixture
# def test_superuser(test_db) -> models.User:
#     """
#     Superuser for testing
#     """
#     user = models.User(
#         email="fakeadmin@email.com",
#         hashed_password=get_password_hash(),
#         is_superuser=True,
#     )
#     test_db.add(user)
#     test_db.commit()
#     return user


# def verify_password_mock(first: str, second: str) -> bool:
#     return True


# @pytest.fixture
# def user_token_headers(
#     test_client: TestClient, test_user, test_password, monkeypatch
# ) -> t.Dict[str, str]:
#     monkeypatch.setattr(security, "verify_password", verify_password_mock)
#     login_data = {
#         "username": test_user.email,
#         "password": test_password,
#     }
#     r = test_client.post("/api/token", data=login_data)
#     tokens = r.json()
#     a_token = tokens["access_token"]
#     headers = {"Authorization": f"Bearer {a_token}"}
#     return headers


# @pytest.fixture
# def superuser_token_headers(
#     test_client: TestClient, test_superuser, test_password, monkeypatch
# ) -> t.Dict[str, str]:
#     monkeypatch.setattr(security, "verify_password", verify_password_mock)
#     login_data = {
#         "username": test_superuser.email,
#         "password": test_password,
#     }
#     r = test_client.post("/api/token", data=login_data)
#     tokens = r.json()
#     a_token = tokens["access_token"]
#     headers = {"Authorization": f"Bearer {a_token}"}
#     return headers