fix pytest
This commit is contained in:
@@ -1,6 +1,5 @@
|
|||||||
from sqlalchemy import Boolean, Column, Integer, String, DateTime,ForeignKey,Table
|
from sqlalchemy import Boolean, Column, Integer, String, DateTime,ForeignKey,Table
|
||||||
from sqlalchemy.ext.declarative import as_declarative
|
from sqlalchemy.orm import relationship,as_declarative
|
||||||
from sqlalchemy.orm import relationship
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from app.db.session import Base
|
from app.db.session import Base
|
||||||
from app.core.security import chacha20Decrypt
|
from app.core.security import chacha20Decrypt
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.orm import sessionmaker,declarative_base
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
|
|
||||||
from app.core import config
|
from app.core import config
|
||||||
|
|
||||||
|
|||||||
145
backend/app/tests/conftest.py
Normal file
145
backend/app/tests/conftest.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
import pytest
|
||||||
|
from sqlalchemy import create_engine, event
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
import typing as t
|
||||||
|
|
||||||
|
from app.core import config, security
|
||||||
|
from app.db.session import Base, get_db
|
||||||
|
from app.db import models
|
||||||
|
from app.main import app
|
||||||
|
|
||||||
|
|
||||||
|
def get_test_db_url() -> str:
|
||||||
|
return f"{config.SQLALCHEMY_DATABASE_URI}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_db():
|
||||||
|
"""
|
||||||
|
Modify the db session to automatically roll back after each test.
|
||||||
|
This is to avoid tests affecting the database state of other tests.
|
||||||
|
"""
|
||||||
|
# Connect to the test database
|
||||||
|
engine = create_engine(
|
||||||
|
get_test_db_url(),
|
||||||
|
)
|
||||||
|
|
||||||
|
connection = engine.connect()
|
||||||
|
trans = connection.begin()
|
||||||
|
|
||||||
|
# Run a parent transaction that can roll back all changes
|
||||||
|
test_session_maker = sessionmaker(
|
||||||
|
autocommit=False, autoflush=False, bind=engine
|
||||||
|
)
|
||||||
|
test_session = test_session_maker()
|
||||||
|
#test_session.begin_nested()
|
||||||
|
|
||||||
|
# @event.listens_for(test_session, "after_transaction_end")
|
||||||
|
# def restart_savepoint(s, transaction):
|
||||||
|
# if transaction.nested and not transaction._parent.nested:
|
||||||
|
# s.expire_all()
|
||||||
|
# s.begin_nested()
|
||||||
|
|
||||||
|
yield test_session
|
||||||
|
|
||||||
|
# Roll back the parent transaction after the test is complete
|
||||||
|
test_session.close()
|
||||||
|
trans.rollback()
|
||||||
|
connection.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def test_client(test_db):
|
||||||
|
"""
|
||||||
|
Get a TestClient instance that reads/write to the test database.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_test_db():
|
||||||
|
yield test_db
|
||||||
|
|
||||||
|
app.dependency_overrides[get_db] = get_test_db
|
||||||
|
with TestClient(app) as test_client:
|
||||||
|
yield test_client
|
||||||
|
|
||||||
|
|
||||||
|
# @pytest.fixture
|
||||||
|
# def test_password() -> str:
|
||||||
|
# return "securepassword"
|
||||||
|
|
||||||
|
|
||||||
|
# def get_password_hash() -> str:
|
||||||
|
# """
|
||||||
|
# Password hashing can be expensive so a mock will be much faster
|
||||||
|
# """
|
||||||
|
# return "supersecrethash"
|
||||||
|
|
||||||
|
|
||||||
|
# @pytest.fixture
|
||||||
|
# def test_user(test_db) -> models.User:
|
||||||
|
# """
|
||||||
|
# Make a test user in the database
|
||||||
|
# """
|
||||||
|
|
||||||
|
# user = models.User(
|
||||||
|
# email="fake@email.com",
|
||||||
|
# hashed_password=get_password_hash(),
|
||||||
|
# is_active=True,
|
||||||
|
# )
|
||||||
|
# test_db.add(user)
|
||||||
|
# test_db.commit()
|
||||||
|
# return user
|
||||||
|
|
||||||
|
|
||||||
|
# @pytest.fixture
|
||||||
|
# def test_superuser(test_db) -> models.User:
|
||||||
|
# """
|
||||||
|
# Superuser for testing
|
||||||
|
# """
|
||||||
|
|
||||||
|
# user = models.User(
|
||||||
|
# email="fakeadmin@email.com",
|
||||||
|
# hashed_password=get_password_hash(),
|
||||||
|
# is_superuser=True,
|
||||||
|
# )
|
||||||
|
# test_db.add(user)
|
||||||
|
# test_db.commit()
|
||||||
|
# return user
|
||||||
|
|
||||||
|
|
||||||
|
# def verify_password_mock(first: str, second: str) -> bool:
|
||||||
|
# return True
|
||||||
|
|
||||||
|
|
||||||
|
# @pytest.fixture
|
||||||
|
# def user_token_headers(
|
||||||
|
# client: TestClient, test_user, test_password, monkeypatch
|
||||||
|
# ) -> t.Dict[str, str]:
|
||||||
|
# monkeypatch.setattr(security, "verify_password", verify_password_mock)
|
||||||
|
|
||||||
|
# login_data = {
|
||||||
|
# "username": test_user.email,
|
||||||
|
# "password": test_password,
|
||||||
|
# }
|
||||||
|
# r = client.post("/api/token", data=login_data)
|
||||||
|
# tokens = r.json()
|
||||||
|
# a_token = tokens["access_token"]
|
||||||
|
# headers = {"Authorization": f"Bearer {a_token}"}
|
||||||
|
# return headers
|
||||||
|
|
||||||
|
|
||||||
|
# @pytest.fixture
|
||||||
|
# def superuser_token_headers(
|
||||||
|
# client: TestClient, test_superuser, test_password, monkeypatch
|
||||||
|
# ) -> t.Dict[str, str]:
|
||||||
|
# monkeypatch.setattr(security, "verify_password", verify_password_mock)
|
||||||
|
|
||||||
|
# login_data = {
|
||||||
|
# "username": test_superuser.email,
|
||||||
|
# "password": test_password,
|
||||||
|
# }
|
||||||
|
# r = client.post("/api/token", data=login_data)
|
||||||
|
# tokens = r.json()
|
||||||
|
# a_token = tokens["access_token"]
|
||||||
|
# headers = {"Authorization": f"Bearer {a_token}"}
|
||||||
|
# return headers
|
||||||
@@ -1,11 +1,6 @@
|
|||||||
|
|
||||||
import pytest
|
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
from app.main import app
|
|
||||||
|
|
||||||
client = TestClient(app)
|
def test_read_main(test_client):
|
||||||
|
response = test_client.get("/api/v1")
|
||||||
def test_read_main():
|
|
||||||
response = client.get("/api/v1")
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == {"message": "Hello World"}
|
assert response.json() == {"message": "success"}
|
||||||
|
|||||||
@@ -1,169 +0,0 @@
|
|||||||
import pytest
|
|
||||||
from sqlalchemy import create_engine, event
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
from sqlalchemy_utils import database_exists, create_database, drop_database
|
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
import typing as t
|
|
||||||
|
|
||||||
from app.core import config, security
|
|
||||||
from app.db.session import Base, get_db
|
|
||||||
from app.db import models
|
|
||||||
from app.main import app
|
|
||||||
|
|
||||||
|
|
||||||
def get_test_db_url() -> str:
|
|
||||||
return f"{config.SQLALCHEMY_DATABASE_URI}_test"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def test_db():
|
|
||||||
"""
|
|
||||||
Modify the db session to automatically roll back after each test.
|
|
||||||
This is to avoid tests affecting the database state of other tests.
|
|
||||||
"""
|
|
||||||
# Connect to the test database
|
|
||||||
engine = create_engine(
|
|
||||||
get_test_db_url(),
|
|
||||||
)
|
|
||||||
|
|
||||||
connection = engine.connect()
|
|
||||||
trans = connection.begin()
|
|
||||||
|
|
||||||
# Run a parent transaction that can roll back all changes
|
|
||||||
test_session_maker = sessionmaker(
|
|
||||||
autocommit=False, autoflush=False, bind=engine
|
|
||||||
)
|
|
||||||
test_session = test_session_maker()
|
|
||||||
test_session.begin_nested()
|
|
||||||
|
|
||||||
@event.listens_for(test_session, "after_transaction_end")
|
|
||||||
def restart_savepoint(s, transaction):
|
|
||||||
if transaction.nested and not transaction._parent.nested:
|
|
||||||
s.expire_all()
|
|
||||||
s.begin_nested()
|
|
||||||
|
|
||||||
yield test_session
|
|
||||||
|
|
||||||
# Roll back the parent transaction after the test is complete
|
|
||||||
test_session.close()
|
|
||||||
trans.rollback()
|
|
||||||
connection.close()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def create_test_db():
|
|
||||||
"""
|
|
||||||
Create a test database and use it for the whole test session.
|
|
||||||
"""
|
|
||||||
|
|
||||||
test_db_url = get_test_db_url()
|
|
||||||
|
|
||||||
# Create the test database
|
|
||||||
assert not database_exists(
|
|
||||||
test_db_url
|
|
||||||
), "Test database already exists. Aborting tests."
|
|
||||||
create_database(test_db_url)
|
|
||||||
test_engine = create_engine(test_db_url)
|
|
||||||
Base.metadata.create_all(test_engine)
|
|
||||||
|
|
||||||
# Run the tests
|
|
||||||
yield
|
|
||||||
|
|
||||||
# Drop the test database
|
|
||||||
drop_database(test_db_url)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def client(test_db):
|
|
||||||
"""
|
|
||||||
Get a TestClient instance that reads/write to the test database.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_test_db():
|
|
||||||
yield test_db
|
|
||||||
|
|
||||||
app.dependency_overrides[get_db] = get_test_db
|
|
||||||
|
|
||||||
yield TestClient(app)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def test_password() -> str:
|
|
||||||
return "securepassword"
|
|
||||||
|
|
||||||
|
|
||||||
def get_password_hash() -> str:
|
|
||||||
"""
|
|
||||||
Password hashing can be expensive so a mock will be much faster
|
|
||||||
"""
|
|
||||||
return "supersecrethash"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def test_user(test_db) -> models.User:
|
|
||||||
"""
|
|
||||||
Make a test user in the database
|
|
||||||
"""
|
|
||||||
|
|
||||||
user = models.User(
|
|
||||||
email="fake@email.com",
|
|
||||||
hashed_password=get_password_hash(),
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
test_db.add(user)
|
|
||||||
test_db.commit()
|
|
||||||
return user
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def test_superuser(test_db) -> models.User:
|
|
||||||
"""
|
|
||||||
Superuser for testing
|
|
||||||
"""
|
|
||||||
|
|
||||||
user = models.User(
|
|
||||||
email="fakeadmin@email.com",
|
|
||||||
hashed_password=get_password_hash(),
|
|
||||||
is_superuser=True,
|
|
||||||
)
|
|
||||||
test_db.add(user)
|
|
||||||
test_db.commit()
|
|
||||||
return user
|
|
||||||
|
|
||||||
|
|
||||||
def verify_password_mock(first: str, second: str) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def user_token_headers(
|
|
||||||
client: TestClient, test_user, test_password, monkeypatch
|
|
||||||
) -> t.Dict[str, str]:
|
|
||||||
monkeypatch.setattr(security, "verify_password", verify_password_mock)
|
|
||||||
|
|
||||||
login_data = {
|
|
||||||
"username": test_user.email,
|
|
||||||
"password": test_password,
|
|
||||||
}
|
|
||||||
r = client.post("/api/token", data=login_data)
|
|
||||||
tokens = r.json()
|
|
||||||
a_token = tokens["access_token"]
|
|
||||||
headers = {"Authorization": f"Bearer {a_token}"}
|
|
||||||
return headers
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def superuser_token_headers(
|
|
||||||
client: TestClient, test_superuser, test_password, monkeypatch
|
|
||||||
) -> t.Dict[str, str]:
|
|
||||||
monkeypatch.setattr(security, "verify_password", verify_password_mock)
|
|
||||||
|
|
||||||
login_data = {
|
|
||||||
"username": test_superuser.email,
|
|
||||||
"password": test_password,
|
|
||||||
}
|
|
||||||
r = client.post("/api/token", data=login_data)
|
|
||||||
tokens = r.json()
|
|
||||||
a_token = tokens["access_token"]
|
|
||||||
headers = {"Authorization": f"Bearer {a_token}"}
|
|
||||||
return headers
|
|
||||||
Binary file not shown.
Reference in New Issue
Block a user