Skip to content

feat: added context managers #778

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion playwright/_impl/_async_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@

import asyncio
import traceback
from typing import Any, Awaitable, Callable, Generic, TypeVar
from types import TracebackType
from typing import Any, Awaitable, Callable, Generic, Type, TypeVar

from playwright._impl._impl_to_api_mapping import ImplToApiMapping, ImplWrapper

mapping = ImplToApiMapping()


T = TypeVar("T")
Self = TypeVar("Self", bound="AsyncBase")


class AsyncEventInfo(Generic[T]):
Expand Down Expand Up @@ -79,3 +81,16 @@ def once(self, event: str, f: Any) -> None:
def remove_listener(self, event: str, f: Any) -> None:
"""Removes the function ``f`` from ``event``."""
self._impl_obj.remove_listener(event, self._wrap_handler(f))


class AsyncContextManager(AsyncBase):
async def __aenter__(self: Self) -> Self:
return self

async def __aexit__(
self: Self,
exc_type: Type[BaseException],
exc_val: BaseException,
traceback: TracebackType,
) -> None:
await self.close() # type: ignore
16 changes: 16 additions & 0 deletions playwright/_impl/_sync_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import asyncio
import traceback
from types import TracebackType
from typing import (
Any,
Awaitable,
Expand All @@ -22,6 +23,7 @@
Generic,
List,
Optional,
Type,
TypeVar,
cast,
)
Expand All @@ -34,6 +36,7 @@


T = TypeVar("T")
Self = TypeVar("Self")


class EventInfo(Generic[T]):
Expand Down Expand Up @@ -152,3 +155,16 @@ async def task() -> None:
raise exceptions[0]

return list(map(lambda action: results[action], actions))


class SyncContextManager(SyncBase):
def __enter__(self: Self) -> Self:
return self

def __exit__(
self: Self,
exc_type: Type[BaseException],
exc_val: BaseException,
traceback: TracebackType,
) -> None:
self.close() # type: ignore
13 changes: 9 additions & 4 deletions playwright/async_api/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@
StorageState,
ViewportSize,
)
from playwright._impl._async_base import AsyncBase, AsyncEventContextManager, mapping
from playwright._impl._async_base import (
AsyncBase,
AsyncContextManager,
AsyncEventContextManager,
mapping,
)
from playwright._impl._browser import Browser as BrowserImpl
from playwright._impl._browser_context import BrowserContext as BrowserContextImpl
from playwright._impl._browser_type import BrowserType as BrowserTypeImpl
Expand Down Expand Up @@ -4900,7 +4905,7 @@ async def delete(self) -> NoneType:
mapping.register(VideoImpl, Video)


class Page(AsyncBase):
class Page(AsyncContextManager):
def __init__(self, obj: PageImpl):
super().__init__(obj)

Expand Down Expand Up @@ -8101,7 +8106,7 @@ def expect_worker(
mapping.register(PageImpl, Page)


class BrowserContext(AsyncBase):
class BrowserContext(AsyncContextManager):
def __init__(self, obj: BrowserContextImpl):
super().__init__(obj)

Expand Down Expand Up @@ -8892,7 +8897,7 @@ async def detach(self) -> NoneType:
mapping.register(CDPSessionImpl, CDPSession)


class Browser(AsyncBase):
class Browser(AsyncContextManager):
def __init__(self, obj: BrowserImpl):
super().__init__(obj)

Expand Down
13 changes: 9 additions & 4 deletions playwright/sync_api/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,12 @@
from playwright._impl._page import Worker as WorkerImpl
from playwright._impl._playwright import Playwright as PlaywrightImpl
from playwright._impl._selectors import Selectors as SelectorsImpl
from playwright._impl._sync_base import EventContextManager, SyncBase, mapping
from playwright._impl._sync_base import (
EventContextManager,
SyncBase,
SyncContextManager,
mapping,
)
from playwright._impl._tracing import Tracing as TracingImpl
from playwright._impl._video import Video as VideoImpl

Expand Down Expand Up @@ -4873,7 +4878,7 @@ def delete(self) -> NoneType:
mapping.register(VideoImpl, Video)


class Page(SyncBase):
class Page(SyncContextManager):
def __init__(self, obj: PageImpl):
super().__init__(obj)

Expand Down Expand Up @@ -8055,7 +8060,7 @@ def expect_worker(
mapping.register(PageImpl, Page)


class BrowserContext(SyncBase):
class BrowserContext(SyncContextManager):
def __init__(self, obj: BrowserContextImpl):
super().__init__(obj)

Expand Down Expand Up @@ -8837,7 +8842,7 @@ def detach(self) -> NoneType:
mapping.register(CDPSessionImpl, CDPSession)


class Browser(SyncBase):
class Browser(SyncContextManager):
def __init__(self, obj: BrowserImpl):
super().__init__(obj)

Expand Down
13 changes: 7 additions & 6 deletions scripts/generate_async_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@ def generate(t: Any) -> None:
print("")
class_name = short_name(t)
base_class = t.__bases__[0].__name__
base_sync_class = (
"AsyncBase"
if base_class == "ChannelOwner" or base_class == "object"
else base_class
)
if class_name in ["Page", "BrowserContext", "Browser"]:
base_sync_class = "AsyncContextManager"
elif base_class in ["ChannelOwner", "object"]:
base_sync_class = "AsyncBase"
else:
base_sync_class = base_class
print(f"class {class_name}({base_sync_class}):")
print("")
print(f" def __init__(self, obj: {class_name}Impl):")
Expand Down Expand Up @@ -122,7 +123,7 @@ def generate(t: Any) -> None:
def main() -> None:
print(header)
print(
"from playwright._impl._async_base import AsyncEventContextManager, AsyncBase, mapping"
"from playwright._impl._async_base import AsyncEventContextManager, AsyncBase, AsyncContextManager, mapping"
)
print("NoneType = type(None)")

Expand Down
13 changes: 7 additions & 6 deletions scripts/generate_sync_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,12 @@ def generate(t: Any) -> None:
print("")
class_name = short_name(t)
base_class = t.__bases__[0].__name__
base_sync_class = (
"SyncBase"
if base_class == "ChannelOwner" or base_class == "object"
else base_class
)
if class_name in ["Page", "BrowserContext", "Browser"]:
base_sync_class = "SyncContextManager"
elif base_class in ["ChannelOwner", "object"]:
base_sync_class = "SyncBase"
else:
base_sync_class = base_class
print(f"class {class_name}({base_sync_class}):")
print("")
print(f" def __init__(self, obj: {class_name}Impl):")
Expand Down Expand Up @@ -123,7 +124,7 @@ def main() -> None:

print(header)
print(
"from playwright._impl._sync_base import EventContextManager, SyncBase, mapping"
"from playwright._impl._sync_base import EventContextManager, SyncBase, SyncContextManager, mapping"
)
print("NoneType = type(None)")

Expand Down
26 changes: 26 additions & 0 deletions tests/async/test_context_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (c) Microsoft Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from playwright.async_api import BrowserType


async def test_context_managers(browser_type: BrowserType, launch_arguments):
async with await browser_type.launch(**launch_arguments) as browser:
async with await browser.new_context() as context:
async with await context.new_page():
assert len(context.pages) == 1
assert len(context.pages) == 0
assert len(browser.contexts) == 1
assert len(browser.contexts) == 0
assert not browser.is_connected()
26 changes: 26 additions & 0 deletions tests/sync/test_context_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (c) Microsoft Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from playwright.sync_api import BrowserType


def test_context_managers(browser_type: BrowserType, launch_arguments):
with browser_type.launch(**launch_arguments) as browser:
with browser.new_context() as context:
with context.new_page():
assert len(context.pages) == 1
assert len(context.pages) == 0
assert len(browser.contexts) == 1
assert len(browser.contexts) == 0
assert not browser.is_connected()