diff --git a/setup.cfg b/setup.cfg index 68e9e6a..084b7b7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,4 +36,19 @@ dev = pytest pytest-mock pytest-asyncio - asynctest; python_version<'3.8' \ No newline at end of file + pytest-cov + # asynctest; python_version<'3.8' + +[coverage:run] +source = announce_server +branch = True + +[coverage:report] +# Fail the test if coverage is below a certain percentage +fail_under = 90 +show_missing = True +exclude_lines = + if __name__ == .__main__.: + +[tool:pytest] +addopts = --cov --cov-report term-missing \ No newline at end of file diff --git a/src/announce_server/decorator.py b/src/announce_server/decorator.py index 114ce5f..e4e0a72 100644 --- a/src/announce_server/decorator.py +++ b/src/announce_server/decorator.py @@ -44,6 +44,31 @@ async def _announce_server(**kwargs): await main() +# def announce_server(task=None, loop=None, **outer_kwargs): +# if task is None: +# return lambda f: announce_server(f, loop=loop, **outer_kwargs) + +# @wraps(task) +# def wrapper(*args, **kwargs): +# async def main(*args, **kwargs): +# if loop is not None: +# host_block_thread = loop.run_in_executor(None, task) +# else: +# host_block_thread = asyncio.to_thread(task) # python 3.9+ + +# # Announce the server to the host +# await _announce_server(**outer_kwargs) + +# # Wait for host_block to finish +# await host_block_thread + +# if loop is not None: +# task = loop.create_task(main(*args, **kwargs)) +# else: +# task = asyncio.run(main(*args, **kwargs)) +# return task +# return wrapper + # def announce_server(task=None, loop=None, **outer_kwargs): # if task is None: # return lambda f: announce_server(f, loop=loop, **outer_kwargs) @@ -65,3 +90,52 @@ async def _announce_server(**kwargs): # await host_block_thread # return wrapper + + +# def announce_server(task=None, loop=None, **outer_kwargs): +# if task is None: +# return lambda f: announce_server(f, loop=loop, **outer_kwargs) + +# if loop is None: +# loop = asyncio.get_event_loop() + +# @wraps(task) +# def wrapper(*args, **kwargs): +# async def main(*args, **kwargs): +# if asyncio.iscoroutinefunction(task): +# # If the task is async, just await it +# host_block_thread = task(*args, **kwargs) +# else: +# host_block_thread = loop.run_in_executor(None, task, *args, **kwargs) + +# # Announce the server to the host +# await _announce_server(**outer_kwargs) + +# # Wait for host_block to finish +# await host_block_thread + +# task = loop.create_task(main(*args, **kwargs)) +# return task + +# return wrapper + + +def announce_server(task=None, **outer_kwargs): + if task is None: + return lambda f: announce_server(f, **outer_kwargs) + + @wraps(task) + def wrapper(*args, **kwargs): + async def main(*args, **kwargs): + loop = asyncio.get_event_loop() + host_block_thread = loop.run_in_executor(None, task) + + # Announce the server to the host + await _announce_server(**outer_kwargs) + + # Wait for host_block to finish + await host_block_thread + + return asyncio.run(main()) + + return wrapper diff --git a/tests/test_announce.py b/tests/test_announce.py index 63fd3dc..b76f454 100644 --- a/tests/test_announce.py +++ b/tests/test_announce.py @@ -1,54 +1,39 @@ import asyncio -import sys -from unittest.mock import patch, MagicMock - -if sys.version_info >= (3, 8): - from unittest.mock import AsyncMock -else: - from asynctest import CoroutineMock as AsyncMock +import subprocess +from unittest.mock import MagicMock, patch import pytest from announce_server.decorator import _announce_server, announce_server -@pytest.mark.asyncio -# @patch('announce_server.decorator.announce_server', new=MagicMock()) -async def test_announce_server_decorator(mocked_announce_server, event_loop): - # Sample function to be decorated - async def sample_async_function(): - await asyncio.sleep(1) - return "Hello, world!" - +@patch("announce_server.decorator._announce_server") +def test_announce_server_decorator(mock_announce_server): # Mock the _announce_server function to prevent actual connections - mocked_announce_server.return_value = lambda x: x + mock_announce_server.return_value = MagicMock() # Decorate the sample function with announce_server - decorated_function = announce_server( + @announce_server( name="test_server", ip="127.0.0.1", port=8000, host_ip="127.0.0.1", host_port=5000, - loop=event_loop, # Pass the current event loop - )(sample_async_function) + ) + def http_server(): + server = subprocess.Popen(["python3", "-m", "http.server", "13373"]) + yield + server.terminate() + server.wait() # Run the decorated function - coro = asyncio.to_thread(decorated_function) - task = await asyncio.gather(coro) - await asyncio.sleep(1.1) # Sleep slightly longer than sample_async_function - task.cancel() # Cancel the task + http_server() # Check if the _announce_server function was called with the correct arguments - mocked_announce_server.assert_called_once_with( + mock_announce_server.assert_called_once_with( name="test_server", ip="127.0.0.1", port=8000, host_ip="127.0.0.1", host_port=5000, - loop=event_loop, ) - - # Check if the decorated function returns the expected result - result = await sample_async_function() - assert result == "Hello, world!" \ No newline at end of file diff --git a/tests/test_get_ip.py b/tests/test_get_ip.py index 998f1db..b584c08 100644 --- a/tests/test_get_ip.py +++ b/tests/test_get_ip.py @@ -1,4 +1,6 @@ +import asyncio import socket +from concurrent.futures import ThreadPoolExecutor from unittest.mock import MagicMock, patch import pytest @@ -6,7 +8,8 @@ import pytest from announce_server import get_ip_address -def test_get_ip_address(): +@pytest.mark.asyncio +async def test_get_ip_address(): with patch("socket.socket") as mock_socket: # Create a MagicMock object for the socket object mock_socket_instance = MagicMock() @@ -19,6 +22,11 @@ def test_get_ip_address(): mock_socket_instance.getsockname.return_value = (expected_ip, 0) # Test the get_ip_address function + # result_ip = await asyncio.to_thread(get_ip_address) + # loop = asyncio.get_event_loop() + # with ThreadPoolExecutor() as pool: + # result_ip = await loop.run_in_executor(pool, get_ip_address) + result_ip = get_ip_address() # Check if the result matches the expected IP address