Skip to content

fix typing when passing loaded config into Client.connect #998

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 21 additions & 13 deletions temporalio/envconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal, Mapping, Optional, Union, cast
from typing import Any, Dict, Literal, Mapping, Optional, Union, cast

from typing_extensions import TypeAlias, TypedDict

Expand Down Expand Up @@ -172,11 +172,11 @@ class ClientConnectConfig(TypedDict, total=False):
Experimental API.
"""

target_host: Optional[str]
namespace: Optional[str]
api_key: Optional[str]
tls: Optional[Union[bool, temporalio.service.TLSConfig]]
rpc_metadata: Optional[Mapping[str, str]]
target_host: str
namespace: str
api_key: str
tls: Union[bool, temporalio.service.TLSConfig]
rpc_metadata: Mapping[str, str]


@dataclass(frozen=True)
Expand Down Expand Up @@ -230,18 +230,26 @@ def to_dict(self) -> ClientConfigProfileDict:

def to_client_connect_config(self) -> ClientConnectConfig:
"""Create a `ClientConnectConfig` from this profile."""
config: ClientConnectConfig = {}
if self.address:
config["target_host"] = self.address
if self.namespace:
if not self.address:
raise ValueError(
"Configuration profile must contain an 'address' to be used for "
"client connection"
)

# Only include non-None values
config: Dict[str, Any] = {}
config["target_host"] = self.address
if self.namespace is not None:
config["namespace"] = self.namespace
if self.api_key:
if self.api_key is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't think this change is necessary, but I don't mind it.

config["api_key"] = self.api_key
if self.tls:
if self.tls is not None:
config["tls"] = self.tls.to_connect_tls_config()
if self.grpc_meta:
config["rpc_metadata"] = self.grpc_meta
return config

# Cast to ClientConnectConfig - this is safe because we've only included non-None values
return cast(ClientConnectConfig, config)

@staticmethod
def load(
Expand Down
17 changes: 12 additions & 5 deletions tests/test_envconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ async def test_load_client_connect_config(client: Client, tmp_path: Path):
config = ClientConfig.load_client_connect_config(config_file=str(config_file))
assert config.get("target_host") == target_host
assert config.get("namespace") == namespace
new_client = await Client.connect(**config) # type: ignore
new_client = await Client.connect(**config)
assert new_client.service_client.config.target_host == target_host
assert new_client.namespace == namespace

Expand All @@ -462,7 +462,7 @@ async def test_load_client_connect_config(client: Client, tmp_path: Path):
rpc_metadata = config.get("rpc_metadata")
assert rpc_metadata
assert "custom-header" in rpc_metadata
new_client = await Client.connect(**config) # type: ignore
new_client = await Client.connect(**config)
assert new_client.service_client.config.target_host == target_host
assert new_client.namespace == "custom-namespace"
assert (
Expand All @@ -476,7 +476,7 @@ async def test_load_client_connect_config(client: Client, tmp_path: Path):
)
assert config.get("target_host") == target_host
assert config.get("namespace") == "env-namespace-override"
new_client = await Client.connect(**config) # type: ignore
new_client = await Client.connect(**config)
assert new_client.namespace == "env-namespace-override"

# Test with env overrides disabled
Expand All @@ -487,7 +487,7 @@ async def test_load_client_connect_config(client: Client, tmp_path: Path):
)
assert config.get("target_host") == target_host
assert config.get("namespace") == namespace
new_client = await Client.connect(**config) # type: ignore
new_client = await Client.connect(**config)
assert new_client.namespace == namespace

# Test with file loading disabled (so only env is used)
Expand All @@ -500,11 +500,18 @@ async def test_load_client_connect_config(client: Client, tmp_path: Path):
)
assert config.get("target_host") == target_host
assert config.get("namespace") == "env-only-namespace"
new_client = await Client.connect(**config) # type: ignore
new_client = await Client.connect(**config)
assert new_client.service_client.config.target_host == target_host
assert new_client.namespace == "env-only-namespace"


def test_to_client_connect_config_missing_address_fails():
"""Test that to_client_connect_config raises a ValueError if address is missing."""
profile = ClientConfigProfile()
with pytest.raises(ValueError, match="must contain an 'address'"):
profile.to_client_connect_config()


def test_disables_raise_error():
"""Test that providing both disable_file and disable_env raises an error."""
with pytest.raises(RuntimeError, match="Cannot disable both"):
Expand Down
Loading