Skip to content

fix(transport): handle connection error correctly #660

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions playwright/_impl/_browser_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from pathlib import Path
from typing import Dict, List, Optional, Union, cast

Expand All @@ -21,6 +22,7 @@
ProxySettings,
ViewportSize,
)
from playwright._impl._api_types import Error
from playwright._impl._browser import Browser, normalize_context_params
from playwright._impl._browser_context import BrowserContext
from playwright._impl._connection import (
Expand All @@ -37,6 +39,7 @@
not_installed_error,
)
from playwright._impl._transport import WebSocketTransport
from playwright._impl._wait_helper import throw_on_timeout


class BrowserType(ChannelOwner):
Expand Down Expand Up @@ -172,7 +175,9 @@ async def connect(
slow_mo: float = None,
headers: Dict[str, str] = None,
) -> Browser:
transport = WebSocketTransport(ws_endpoint, timeout, headers)
if timeout is None:
timeout = 30000
transport = WebSocketTransport(self._connection._loop, ws_endpoint, headers)

connection = Connection(
self._connection._dispatcher_fiber,
Expand All @@ -182,8 +187,21 @@ async def connect(
connection._is_sync = self._connection._is_sync
connection._loop = self._connection._loop
connection._loop.create_task(connection.run())
await connection.initialize()
playwright_future = asyncio.create_task(
connection.wait_for_object_with_known_name("Playwright")
)
timeout_future = throw_on_timeout(timeout, Error("Connection timed out"))
done, pending = await asyncio.wait(
{transport.on_error_future, playwright_future, timeout_future},
return_when=asyncio.FIRST_COMPLETED,
)
if not playwright_future.done():
playwright_future.cancel()
if not timeout_future.done():
timeout_future.cancel()
playwright = next(iter(done)).result()
self._connection._child_ws_connections.append(connection)
playwright = await connection.wait_for_object_with_known_name("Playwright")
pre_launched_browser = playwright._initializer.get("preLaunchedBrowser")
assert pre_launched_browser
browser = cast(Browser, from_channel(pre_launched_browser))
Expand Down
7 changes: 7 additions & 0 deletions playwright/_impl/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,16 +165,22 @@ def __init__(
self._is_sync = False
self._api_name = ""
self._child_ws_connections: List["Connection"] = []
self._initialized = False

async def run_as_sync(self) -> None:
self._is_sync = True
await self.run()
await self.initialize()

async def run(self) -> None:
self._loop = asyncio.get_running_loop()
self._root_object = RootChannelOwner(self)
await self._transport.run()

async def initialize(self) -> None:
await self._transport.wait_until_initialized
self._initialized = True

def stop_sync(self) -> None:
self._transport.request_stop()
self._dispatcher_fiber.switch()
Expand All @@ -190,6 +196,7 @@ def cleanup(self) -> None:
ws_connection._transport.dispose()

async def wait_for_object_with_known_name(self, guid: str) -> Any:
assert self._initialized
if guid in self._objects:
return self._objects[guid]
callback = self._loop.create_future()
Expand Down
73 changes: 49 additions & 24 deletions playwright/_impl/_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,19 @@ def _get_stderr_fileno() -> Optional[int]:


class Transport(ABC):
def __init__(self) -> None:
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self.on_error_future: asyncio.Future
self.on_message = lambda _: None

self.wait_until_initialized: asyncio.Future = loop.create_future()
self._wait_until_initialized_set_success = (
lambda: self.wait_until_initialized.set_result(True)
)
self._wait_until_initialized_set_exception = (
lambda exc: self.wait_until_initialized.set_exception(exc)
)

@abstractmethod
def request_stop(self) -> None:
pass
Expand All @@ -57,7 +67,7 @@ async def wait_until_stopped(self) -> None:

async def run(self) -> None:
self._loop = asyncio.get_running_loop()
self.on_error_future: asyncio.Future = asyncio.Future()
self.on_error_future = asyncio.Future()

@abstractmethod
def send(self, message: Dict) -> None:
Expand All @@ -78,8 +88,10 @@ def deserialize_message(self, data: bytes) -> Any:


class PipeTransport(Transport):
def __init__(self, driver_executable: Path) -> None:
super().__init__()
def __init__(
self, loop: asyncio.AbstractEventLoop, driver_executable: Path
) -> None:
super().__init__(loop)
self._stopped = False
self._driver_executable = driver_executable
self._loop: asyncio.AbstractEventLoop
Expand All @@ -96,14 +108,21 @@ async def run(self) -> None:
await super().run()
self._stopped_future: asyncio.Future = asyncio.Future()

self._proc = proc = await asyncio.create_subprocess_exec(
str(self._driver_executable),
"run-driver",
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=_get_stderr_fileno(),
limit=32768,
)
try:
self._proc = proc = await asyncio.create_subprocess_exec(
str(self._driver_executable),
"run-driver",
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=_get_stderr_fileno(),
limit=32768,
)
except Exception as exc:
self._wait_until_initialized_set_exception(exc)
return

self._wait_until_initialized_set_success()

assert proc.stdout
assert proc.stdin
self._output = proc.stdin
Expand Down Expand Up @@ -138,16 +157,17 @@ def send(self, message: Dict) -> None:

class WebSocketTransport(AsyncIOEventEmitter, Transport):
def __init__(
self, ws_endpoint: str, timeout: float = None, headers: Dict[str, str] = None
self,
loop: asyncio.AbstractEventLoop,
ws_endpoint: str,
headers: Dict[str, str] = None,
) -> None:
super().__init__()
Transport.__init__(self)
super().__init__(loop)
Transport.__init__(self, loop)

self._stopped = False
self.ws_endpoint = ws_endpoint
self.timeout = timeout
self.headers = headers
self._loop: asyncio.AbstractEventLoop

def request_stop(self) -> None:
self._stopped = True
Expand All @@ -162,13 +182,17 @@ async def wait_until_stopped(self) -> None:
async def run(self) -> None:
await super().run()

options: Dict[str, Any] = {}
if self.timeout is not None:
options["close_timeout"] = self.timeout / 1000
options["ping_timeout"] = self.timeout / 1000
if self.headers is not None:
options["extra_headers"] = self.headers
self._connection = await websockets.connect(self.ws_endpoint, **options)
try:
self._connection = await websockets.connect(
self.ws_endpoint, extra_headers=self.headers
)
except Exception as err:
self._wait_until_initialized_set_exception(
Error(f"websockets.connect: {err}")
)
return

self._wait_until_initialized_set_success()

while not self._stopped:
try:
Expand All @@ -190,6 +214,7 @@ async def run(self) -> None:
except Exception as exc:
print(f"Received unhandled exception: {exc}")
self.on_error_future.set_exception(exc)
break

def send(self, message: Dict) -> None:
if self._stopped or self._connection.closed:
Expand Down
8 changes: 8 additions & 0 deletions playwright/_impl/_wait_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,11 @@ def listener(event_data: Any = None) -> None:

def result(self) -> asyncio.Future:
return self._result


def throw_on_timeout(timeout: float, exception: Exception) -> asyncio.Task:
async def throw() -> None:
await asyncio.sleep(timeout / 1000)
raise exception

return asyncio.create_task(throw())
5 changes: 4 additions & 1 deletion playwright/async_api/_context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@ def __init__(self) -> None:

async def __aenter__(self) -> AsyncPlaywright:
self._connection = Connection(
None, create_remote_object, PipeTransport(compute_driver_executable())
None,
create_remote_object,
PipeTransport(asyncio.get_event_loop(), compute_driver_executable()),
)
loop = asyncio.get_running_loop()
self._connection._loop = loop
loop.create_task(self._connection.run())
await self._connection.initialize()
playwright = AsyncPlaywright(
await self._connection.wait_for_object_with_known_name("Playwright")
)
Expand Down
2 changes: 1 addition & 1 deletion playwright/sync_api/_context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def greenlet_main() -> None:
self._connection = Connection(
dispatcher_fiber,
create_remote_object,
PipeTransport(compute_driver_executable()),
PipeTransport(asyncio.new_event_loop(), compute_driver_executable()),
)

g_self = greenlet.getcurrent()
Expand Down
10 changes: 10 additions & 0 deletions tests/async/test_browsertype_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,13 @@ async def test_prevent_getting_video_path(
== "Path is not available when using browserType.connect(). Use save_as() to save a local copy."
)
remote_server.kill()


async def test_connect_to_closed_server_without_hangs(
browser_type: BrowserType, launch_server
):
remote_server = launch_server()
remote_server.kill()
with pytest.raises(Error) as exc:
await browser_type.connect(remote_server.ws_endpoint)
assert "websockets.connect: " in exc.value.message
10 changes: 4 additions & 6 deletions tests/async/test_click.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,12 +527,10 @@ async def test_timeout_waiting_for_stable_position(page, server):
}"""
)

error = None
try:
await button.click(timeout=5000)
except Error as e:
error = e
assert "Timeout 5000ms exceeded." in error.message
with pytest.raises(Error) as exc_info:
await button.click(timeout=3000)
error = exc_info.value
assert "Timeout 3000ms exceeded." in error.message
assert "waiting for element to be visible, enabled and stable" in error.message
assert "element is not stable - waiting" in error.message

Expand Down
10 changes: 10 additions & 0 deletions tests/sync/test_browsertype_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,13 @@ def test_browser_type_connect_should_forward_close_events_to_pages(
assert events == ["page::close", "context::close", "browser::disconnected"]
remote.kill()
assert events == ["page::close", "context::close", "browser::disconnected"]


def test_connect_to_closed_server_without_hangs(
browser_type: BrowserType, launch_server
):
remote_server = launch_server()
remote_server.kill()
with pytest.raises(Error) as exc:
browser_type.connect(remote_server.ws_endpoint)
assert "websockets.connect: " in exc.value.message