From 4f17a6952d4ac86aa64b622f38b6e971b60c04f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=96=B9=20=E6=9F=8F?= Date: Wed, 4 Dec 2024 14:44:32 +0900 Subject: [PATCH] fix pytest --- backend/app/db/models.py | 3 +- backend/app/db/session.py | 3 +- backend/app/tests/conftest.py | 145 ++++++++++++++++++++++++++++ backend/app/tests/test_main.py | 11 +-- backend/conftest.py | 169 --------------------------------- backend/requirements.txt | Bin 1878 -> 1946 bytes 6 files changed, 150 insertions(+), 181 deletions(-) create mode 100644 backend/app/tests/conftest.py delete mode 100644 backend/conftest.py diff --git a/backend/app/db/models.py b/backend/app/db/models.py index 9ef599c..04ba1e3 100644 --- a/backend/app/db/models.py +++ b/backend/app/db/models.py @@ -1,6 +1,5 @@ from sqlalchemy import Boolean, Column, Integer, String, DateTime,ForeignKey,Table -from sqlalchemy.ext.declarative import as_declarative -from sqlalchemy.orm import relationship +from sqlalchemy.orm import relationship,as_declarative from datetime import datetime from app.db.session import Base from app.core.security import chacha20Decrypt diff --git a/backend/app/db/session.py b/backend/app/db/session.py index d7e2f6c..148c9a4 100644 --- a/backend/app/db/session.py +++ b/backend/app/db/session.py @@ -1,6 +1,5 @@ from sqlalchemy import create_engine -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import sessionmaker,declarative_base from app.core import config diff --git a/backend/app/tests/conftest.py b/backend/app/tests/conftest.py new file mode 100644 index 0000000..ab33ff6 --- /dev/null +++ b/backend/app/tests/conftest.py @@ -0,0 +1,145 @@ +import pytest +from sqlalchemy import create_engine, event +from sqlalchemy.orm import sessionmaker +from fastapi.testclient import TestClient +import typing as t + +from app.core import config, security +from app.db.session import Base, get_db +from app.db import models +from app.main import app + + +def get_test_db_url() -> str: + return f"{config.SQLALCHEMY_DATABASE_URI}" + + +@pytest.fixture +def test_db(): + """ + Modify the db session to automatically roll back after each test. + This is to avoid tests affecting the database state of other tests. + """ + # Connect to the test database + engine = create_engine( + get_test_db_url(), + ) + + connection = engine.connect() + trans = connection.begin() + + # Run a parent transaction that can roll back all changes + test_session_maker = sessionmaker( + autocommit=False, autoflush=False, bind=engine + ) + test_session = test_session_maker() + #test_session.begin_nested() + + # @event.listens_for(test_session, "after_transaction_end") + # def restart_savepoint(s, transaction): + # if transaction.nested and not transaction._parent.nested: + # s.expire_all() + # s.begin_nested() + + yield test_session + + # Roll back the parent transaction after the test is complete + test_session.close() + trans.rollback() + connection.close() + + +@pytest.fixture(scope="function") +def test_client(test_db): + """ + Get a TestClient instance that reads/write to the test database. + """ + + def get_test_db(): + yield test_db + + app.dependency_overrides[get_db] = get_test_db + with TestClient(app) as test_client: + yield test_client + + +# @pytest.fixture +# def test_password() -> str: +# return "securepassword" + + +# def get_password_hash() -> str: +# """ +# Password hashing can be expensive so a mock will be much faster +# """ +# return "supersecrethash" + + +# @pytest.fixture +# def test_user(test_db) -> models.User: +# """ +# Make a test user in the database +# """ + +# user = models.User( +# email="fake@email.com", +# hashed_password=get_password_hash(), +# is_active=True, +# ) +# test_db.add(user) +# test_db.commit() +# return user + + +# @pytest.fixture +# def test_superuser(test_db) -> models.User: +# """ +# Superuser for testing +# """ + +# user = models.User( +# email="fakeadmin@email.com", +# hashed_password=get_password_hash(), +# is_superuser=True, +# ) +# test_db.add(user) +# test_db.commit() +# return user + + +# def verify_password_mock(first: str, second: str) -> bool: +# return True + + +# @pytest.fixture +# def user_token_headers( +# client: TestClient, test_user, test_password, monkeypatch +# ) -> t.Dict[str, str]: +# monkeypatch.setattr(security, "verify_password", verify_password_mock) + +# login_data = { +# "username": test_user.email, +# "password": test_password, +# } +# r = client.post("/api/token", data=login_data) +# tokens = r.json() +# a_token = tokens["access_token"] +# headers = {"Authorization": f"Bearer {a_token}"} +# return headers + + +# @pytest.fixture +# def superuser_token_headers( +# client: TestClient, test_superuser, test_password, monkeypatch +# ) -> t.Dict[str, str]: +# monkeypatch.setattr(security, "verify_password", verify_password_mock) + +# login_data = { +# "username": test_superuser.email, +# "password": test_password, +# } +# r = client.post("/api/token", data=login_data) +# tokens = r.json() +# a_token = tokens["access_token"] +# headers = {"Authorization": f"Bearer {a_token}"} +# return headers diff --git a/backend/app/tests/test_main.py b/backend/app/tests/test_main.py index a1048ae..8d844d7 100644 --- a/backend/app/tests/test_main.py +++ b/backend/app/tests/test_main.py @@ -1,11 +1,6 @@ -import pytest -from fastapi.testclient import TestClient -from app.main import app -client = TestClient(app) - -def test_read_main(): - response = client.get("/api/v1") +def test_read_main(test_client): + response = test_client.get("/api/v1") assert response.status_code == 200 - assert response.json() == {"message": "Hello World"} + assert response.json() == {"message": "success"} diff --git a/backend/conftest.py b/backend/conftest.py deleted file mode 100644 index ecd831d..0000000 --- a/backend/conftest.py +++ /dev/null @@ -1,169 +0,0 @@ -import pytest -from sqlalchemy import create_engine, event -from sqlalchemy.orm import sessionmaker -from sqlalchemy_utils import database_exists, create_database, drop_database -from fastapi.testclient import TestClient -import typing as t - -from app.core import config, security -from app.db.session import Base, get_db -from app.db import models -from app.main import app - - -def get_test_db_url() -> str: - return f"{config.SQLALCHEMY_DATABASE_URI}_test" - - -@pytest.fixture -def test_db(): - """ - Modify the db session to automatically roll back after each test. - This is to avoid tests affecting the database state of other tests. - """ - # Connect to the test database - engine = create_engine( - get_test_db_url(), - ) - - connection = engine.connect() - trans = connection.begin() - - # Run a parent transaction that can roll back all changes - test_session_maker = sessionmaker( - autocommit=False, autoflush=False, bind=engine - ) - test_session = test_session_maker() - test_session.begin_nested() - - @event.listens_for(test_session, "after_transaction_end") - def restart_savepoint(s, transaction): - if transaction.nested and not transaction._parent.nested: - s.expire_all() - s.begin_nested() - - yield test_session - - # Roll back the parent transaction after the test is complete - test_session.close() - trans.rollback() - connection.close() - - -@pytest.fixture(scope="session", autouse=True) -def create_test_db(): - """ - Create a test database and use it for the whole test session. - """ - - test_db_url = get_test_db_url() - - # Create the test database - assert not database_exists( - test_db_url - ), "Test database already exists. Aborting tests." - create_database(test_db_url) - test_engine = create_engine(test_db_url) - Base.metadata.create_all(test_engine) - - # Run the tests - yield - - # Drop the test database - drop_database(test_db_url) - - -@pytest.fixture -def client(test_db): - """ - Get a TestClient instance that reads/write to the test database. - """ - - def get_test_db(): - yield test_db - - app.dependency_overrides[get_db] = get_test_db - - yield TestClient(app) - - -@pytest.fixture -def test_password() -> str: - return "securepassword" - - -def get_password_hash() -> str: - """ - Password hashing can be expensive so a mock will be much faster - """ - return "supersecrethash" - - -@pytest.fixture -def test_user(test_db) -> models.User: - """ - Make a test user in the database - """ - - user = models.User( - email="fake@email.com", - hashed_password=get_password_hash(), - is_active=True, - ) - test_db.add(user) - test_db.commit() - return user - - -@pytest.fixture -def test_superuser(test_db) -> models.User: - """ - Superuser for testing - """ - - user = models.User( - email="fakeadmin@email.com", - hashed_password=get_password_hash(), - is_superuser=True, - ) - test_db.add(user) - test_db.commit() - return user - - -def verify_password_mock(first: str, second: str) -> bool: - return True - - -@pytest.fixture -def user_token_headers( - client: TestClient, test_user, test_password, monkeypatch -) -> t.Dict[str, str]: - monkeypatch.setattr(security, "verify_password", verify_password_mock) - - login_data = { - "username": test_user.email, - "password": test_password, - } - r = client.post("/api/token", data=login_data) - tokens = r.json() - a_token = tokens["access_token"] - headers = {"Authorization": f"Bearer {a_token}"} - return headers - - -@pytest.fixture -def superuser_token_headers( - client: TestClient, test_superuser, test_password, monkeypatch -) -> t.Dict[str, str]: - monkeypatch.setattr(security, "verify_password", verify_password_mock) - - login_data = { - "username": test_superuser.email, - "password": test_password, - } - r = client.post("/api/token", data=login_data) - tokens = r.json() - a_token = tokens["access_token"] - headers = {"Authorization": f"Bearer {a_token}"} - return headers diff --git a/backend/requirements.txt b/backend/requirements.txt index 8f0139ace584b8963c9758e93b89d35f9840b6da..e9ec467ee99b3189a3d793fdb43c18430a3b0219 100644 GIT binary patch delta 83 zcmcb{H;aEm0<&x;LmopWLo!1?kWK@#(iv=l(1<~g!GJ+;awPL?(PV~fhD4w^NF_|x XWbgeSg`