fix pytest
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
from sqlalchemy import Boolean, Column, Integer, String, DateTime,ForeignKey,Table
|
||||
from sqlalchemy.ext.declarative import as_declarative
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.orm import relationship,as_declarative
|
||||
from datetime import datetime
|
||||
from app.db.session import Base
|
||||
from app.core.security import chacha20Decrypt
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker,declarative_base
|
||||
|
||||
from app.core import config
|
||||
|
||||
|
||||
145
backend/app/tests/conftest.py
Normal file
145
backend/app/tests/conftest.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import pytest
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from fastapi.testclient import TestClient
|
||||
import typing as t
|
||||
|
||||
from app.core import config, security
|
||||
from app.db.session import Base, get_db
|
||||
from app.db import models
|
||||
from app.main import app
|
||||
|
||||
|
||||
def get_test_db_url() -> str:
|
||||
return f"{config.SQLALCHEMY_DATABASE_URI}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_db():
|
||||
"""
|
||||
Modify the db session to automatically roll back after each test.
|
||||
This is to avoid tests affecting the database state of other tests.
|
||||
"""
|
||||
# Connect to the test database
|
||||
engine = create_engine(
|
||||
get_test_db_url(),
|
||||
)
|
||||
|
||||
connection = engine.connect()
|
||||
trans = connection.begin()
|
||||
|
||||
# Run a parent transaction that can roll back all changes
|
||||
test_session_maker = sessionmaker(
|
||||
autocommit=False, autoflush=False, bind=engine
|
||||
)
|
||||
test_session = test_session_maker()
|
||||
#test_session.begin_nested()
|
||||
|
||||
# @event.listens_for(test_session, "after_transaction_end")
|
||||
# def restart_savepoint(s, transaction):
|
||||
# if transaction.nested and not transaction._parent.nested:
|
||||
# s.expire_all()
|
||||
# s.begin_nested()
|
||||
|
||||
yield test_session
|
||||
|
||||
# Roll back the parent transaction after the test is complete
|
||||
test_session.close()
|
||||
trans.rollback()
|
||||
connection.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def test_client(test_db):
|
||||
"""
|
||||
Get a TestClient instance that reads/write to the test database.
|
||||
"""
|
||||
|
||||
def get_test_db():
|
||||
yield test_db
|
||||
|
||||
app.dependency_overrides[get_db] = get_test_db
|
||||
with TestClient(app) as test_client:
|
||||
yield test_client
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def test_password() -> str:
|
||||
# return "securepassword"
|
||||
|
||||
|
||||
# def get_password_hash() -> str:
|
||||
# """
|
||||
# Password hashing can be expensive so a mock will be much faster
|
||||
# """
|
||||
# return "supersecrethash"
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def test_user(test_db) -> models.User:
|
||||
# """
|
||||
# Make a test user in the database
|
||||
# """
|
||||
|
||||
# user = models.User(
|
||||
# email="fake@email.com",
|
||||
# hashed_password=get_password_hash(),
|
||||
# is_active=True,
|
||||
# )
|
||||
# test_db.add(user)
|
||||
# test_db.commit()
|
||||
# return user
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def test_superuser(test_db) -> models.User:
|
||||
# """
|
||||
# Superuser for testing
|
||||
# """
|
||||
|
||||
# user = models.User(
|
||||
# email="fakeadmin@email.com",
|
||||
# hashed_password=get_password_hash(),
|
||||
# is_superuser=True,
|
||||
# )
|
||||
# test_db.add(user)
|
||||
# test_db.commit()
|
||||
# return user
|
||||
|
||||
|
||||
# def verify_password_mock(first: str, second: str) -> bool:
|
||||
# return True
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def user_token_headers(
|
||||
# client: TestClient, test_user, test_password, monkeypatch
|
||||
# ) -> t.Dict[str, str]:
|
||||
# monkeypatch.setattr(security, "verify_password", verify_password_mock)
|
||||
|
||||
# login_data = {
|
||||
# "username": test_user.email,
|
||||
# "password": test_password,
|
||||
# }
|
||||
# r = client.post("/api/token", data=login_data)
|
||||
# tokens = r.json()
|
||||
# a_token = tokens["access_token"]
|
||||
# headers = {"Authorization": f"Bearer {a_token}"}
|
||||
# return headers
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def superuser_token_headers(
|
||||
# client: TestClient, test_superuser, test_password, monkeypatch
|
||||
# ) -> t.Dict[str, str]:
|
||||
# monkeypatch.setattr(security, "verify_password", verify_password_mock)
|
||||
|
||||
# login_data = {
|
||||
# "username": test_superuser.email,
|
||||
# "password": test_password,
|
||||
# }
|
||||
# r = client.post("/api/token", data=login_data)
|
||||
# tokens = r.json()
|
||||
# a_token = tokens["access_token"]
|
||||
# headers = {"Authorization": f"Bearer {a_token}"}
|
||||
# return headers
|
||||
@@ -1,11 +1,6 @@
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from app.main import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
def test_read_main():
|
||||
response = client.get("/api/v1")
|
||||
def test_read_main(test_client):
|
||||
response = test_client.get("/api/v1")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Hello World"}
|
||||
assert response.json() == {"message": "success"}
|
||||
|
||||
@@ -1,169 +0,0 @@
|
||||
import pytest
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy_utils import database_exists, create_database, drop_database
|
||||
from fastapi.testclient import TestClient
|
||||
import typing as t
|
||||
|
||||
from app.core import config, security
|
||||
from app.db.session import Base, get_db
|
||||
from app.db import models
|
||||
from app.main import app
|
||||
|
||||
|
||||
def get_test_db_url() -> str:
|
||||
return f"{config.SQLALCHEMY_DATABASE_URI}_test"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_db():
|
||||
"""
|
||||
Modify the db session to automatically roll back after each test.
|
||||
This is to avoid tests affecting the database state of other tests.
|
||||
"""
|
||||
# Connect to the test database
|
||||
engine = create_engine(
|
||||
get_test_db_url(),
|
||||
)
|
||||
|
||||
connection = engine.connect()
|
||||
trans = connection.begin()
|
||||
|
||||
# Run a parent transaction that can roll back all changes
|
||||
test_session_maker = sessionmaker(
|
||||
autocommit=False, autoflush=False, bind=engine
|
||||
)
|
||||
test_session = test_session_maker()
|
||||
test_session.begin_nested()
|
||||
|
||||
@event.listens_for(test_session, "after_transaction_end")
|
||||
def restart_savepoint(s, transaction):
|
||||
if transaction.nested and not transaction._parent.nested:
|
||||
s.expire_all()
|
||||
s.begin_nested()
|
||||
|
||||
yield test_session
|
||||
|
||||
# Roll back the parent transaction after the test is complete
|
||||
test_session.close()
|
||||
trans.rollback()
|
||||
connection.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def create_test_db():
|
||||
"""
|
||||
Create a test database and use it for the whole test session.
|
||||
"""
|
||||
|
||||
test_db_url = get_test_db_url()
|
||||
|
||||
# Create the test database
|
||||
assert not database_exists(
|
||||
test_db_url
|
||||
), "Test database already exists. Aborting tests."
|
||||
create_database(test_db_url)
|
||||
test_engine = create_engine(test_db_url)
|
||||
Base.metadata.create_all(test_engine)
|
||||
|
||||
# Run the tests
|
||||
yield
|
||||
|
||||
# Drop the test database
|
||||
drop_database(test_db_url)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(test_db):
|
||||
"""
|
||||
Get a TestClient instance that reads/write to the test database.
|
||||
"""
|
||||
|
||||
def get_test_db():
|
||||
yield test_db
|
||||
|
||||
app.dependency_overrides[get_db] = get_test_db
|
||||
|
||||
yield TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_password() -> str:
|
||||
return "securepassword"
|
||||
|
||||
|
||||
def get_password_hash() -> str:
|
||||
"""
|
||||
Password hashing can be expensive so a mock will be much faster
|
||||
"""
|
||||
return "supersecrethash"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user(test_db) -> models.User:
|
||||
"""
|
||||
Make a test user in the database
|
||||
"""
|
||||
|
||||
user = models.User(
|
||||
email="fake@email.com",
|
||||
hashed_password=get_password_hash(),
|
||||
is_active=True,
|
||||
)
|
||||
test_db.add(user)
|
||||
test_db.commit()
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_superuser(test_db) -> models.User:
|
||||
"""
|
||||
Superuser for testing
|
||||
"""
|
||||
|
||||
user = models.User(
|
||||
email="fakeadmin@email.com",
|
||||
hashed_password=get_password_hash(),
|
||||
is_superuser=True,
|
||||
)
|
||||
test_db.add(user)
|
||||
test_db.commit()
|
||||
return user
|
||||
|
||||
|
||||
def verify_password_mock(first: str, second: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_token_headers(
|
||||
client: TestClient, test_user, test_password, monkeypatch
|
||||
) -> t.Dict[str, str]:
|
||||
monkeypatch.setattr(security, "verify_password", verify_password_mock)
|
||||
|
||||
login_data = {
|
||||
"username": test_user.email,
|
||||
"password": test_password,
|
||||
}
|
||||
r = client.post("/api/token", data=login_data)
|
||||
tokens = r.json()
|
||||
a_token = tokens["access_token"]
|
||||
headers = {"Authorization": f"Bearer {a_token}"}
|
||||
return headers
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def superuser_token_headers(
|
||||
client: TestClient, test_superuser, test_password, monkeypatch
|
||||
) -> t.Dict[str, str]:
|
||||
monkeypatch.setattr(security, "verify_password", verify_password_mock)
|
||||
|
||||
login_data = {
|
||||
"username": test_superuser.email,
|
||||
"password": test_password,
|
||||
}
|
||||
r = client.post("/api/token", data=login_data)
|
||||
tokens = r.json()
|
||||
a_token = tokens["access_token"]
|
||||
headers = {"Authorization": f"Bearer {a_token}"}
|
||||
return headers
|
||||
Binary file not shown.
Reference in New Issue
Block a user