
S3: add heuristic for S3 requests from IaC/SDK with no prefix #12678

Draft · wants to merge 4 commits into base: main

4 changes: 2 additions & 2 deletions localstack-core/localstack/aws/connect.py
@@ -515,7 +515,7 @@ def get_client(

        endpoint_url = endpoint_url or get_service_endpoint()
        if service_name == "s3" and endpoint_url:
-            if re.match(r"https?://localhost(:[0-9]+)?", endpoint_url):
+            if re.match(r"https?://localhost(:[0-9]+)$", endpoint_url):
                endpoint_url = endpoint_url.replace("://localhost", f"://{get_s3_hostname()}")

        return self._get_client(
@@ -579,7 +579,7 @@ def get_client(

        endpoint_url = endpoint_url or get_service_endpoint()
        if service_name == "s3":
-            if re.match(r"https?://localhost(:[0-9]+)?", endpoint_url):
+            if re.match(r"https?://localhost(:[0-9]+)$", endpoint_url):
                endpoint_url = endpoint_url.replace("://localhost", f"://{get_s3_hostname()}")

        # Prevent `PartialCredentialsError` when only access key ID is provided
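The `$` anchor matters because the old pattern also matched endpoints that merely start with `localhost`, such as `http://localhost.localstack.cloud:4566`, so the `get_s3_hostname()` rewrite fired on hosts it should leave alone. A minimal sketch of the difference; the `OLD`/`NEW` names and sample URLs are illustrative and not part of the diff:

```python
import re

# Illustrative comparison of the two patterns used in connect.py (names are made up here).
OLD = r"https?://localhost(:[0-9]+)?"
NEW = r"https?://localhost(:[0-9]+)$"

samples = [
    "http://localhost:4566",                   # plain localhost endpoint -> should be rewritten
    "http://localhost.localstack.cloud:4566",  # named endpoint -> should be left untouched
]

for url in samples:
    print(url, "old:", bool(re.match(OLD, url)), "new:", bool(re.match(NEW, url)))
# http://localhost:4566 old: True new: True
# http://localhost.localstack.cloud:4566 old: True new: False
```

Note that with the anchor the port group is no longer optional, so a bare `http://localhost` without a port also stops matching.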
19 changes: 16 additions & 3 deletions localstack-core/localstack/aws/protocol/parser.py
@@ -69,6 +69,7 @@
import re
from abc import ABC
from email.utils import parsedate_to_datetime
+from functools import lru_cache
from typing import IO, Any, Dict, List, Mapping, Optional, Tuple, Union
from xml.etree import ElementTree as ETree

@@ -87,6 +88,7 @@
from cbor2._decoder import loads as cbor2_loads
from werkzeug.exceptions import BadRequest, NotFound

+from localstack import config
from localstack.aws.protocol.op_router import RestServiceOperationRouter
from localstack.http import Request

@@ -1006,7 +1008,7 @@ def __init__(self, request: Request):

        def __enter__(self):
            # only modify the request if it uses the virtual host addressing
-            if bucket_name := self._is_vhost_address_get_bucket(self.request):
+            if bucket_name := self._is_vhost_address_get_bucket(self.request.host):
                # save the original path and host for restoring on context exit
                self.old_path = self.request.path
                self.old_host = self.request.host
@@ -1066,10 +1068,21 @@ def _set_request_props(
            pass

        @staticmethod
-        def _is_vhost_address_get_bucket(request: Request) -> str | None:
+        @lru_cache
+        def _is_vhost_address_get_bucket(host: str) -> str | None:
            from localstack.services.s3.utils import uses_host_addressing

-            return uses_host_addressing(request.headers)
+            if bucket_name := uses_host_addressing({"host": host}):
+                return bucket_name
+
+            # FIXME: this is a hack to allow recognizing some virtual-hosted S3 requests targeting the regular
+            # LocalStack endpoint, and not the `s3.`-prefixed one.
+            # this is the case for CDK, that doesn't allow easy configuration of service specific endpoints. However,
+            # while this allows us to understand the S3 request, this is "limited support" as it doesn't work for
+            # pre-signed URL and S3-specific CORS.
+            pattern = r"(?P<bucket>.*)\." + config.LOCALSTACK_HOST.host + r"(?::\d.+)"
+            if (match := re.match(pattern, host)) and match.group("bucket") != "s3":
+                return match.group("bucket")

    @_handle_exceptions
    def parse(self, request: Request) -> Tuple[OperationModel, Any]:
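The fallback branch above implements the heuristic from the PR title: if the `Host` header looks like `<bucket>.<LOCALSTACK_HOST>` but is not the `s3.`-prefixed endpoint, the leading label is treated as the bucket name. A rough sketch of how that pattern behaves, assuming the default `LOCALSTACK_HOST.host` of `localhost.localstack.cloud` (the hostnames below are made-up examples):

```python
import re

# Stand-in for config.LOCALSTACK_HOST.host (assumed default value).
localstack_host = "localhost.localstack.cloud"
pattern = r"(?P<bucket>.*)\." + localstack_host + r"(?::\d.+)"

for host in (
    "my-bucket.localhost.localstack.cloud:4566",  # virtual-host request on the bare endpoint
    "s3.localhost.localstack.cloud:4566",         # the s3.-prefixed endpoint itself, not a bucket
    "localhost.localstack.cloud:4566",            # path-style request, no bucket in the host
):
    match = re.match(pattern, host)
    bucket = match.group("bucket") if match and match.group("bucket") != "s3" else None
    print(host, "->", bucket)
# my-bucket.localhost.localstack.cloud:4566 -> my-bucket
# s3.localhost.localstack.cloud:4566 -> None
# localhost.localstack.cloud:4566 -> None
```

Because the helper is now keyed on the host string alone, `@lru_cache` memoizes the result per distinct `Host` value instead of re-running the lookup for every request.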
5 changes: 3 additions & 2 deletions localstack-core/localstack/services/s3/utils.py
@@ -9,13 +9,14 @@
import zlib
from enum import StrEnum
from secrets import token_bytes
-from typing import Any, Dict, Literal, NamedTuple, Optional, Protocol, Tuple, Union
+from typing import Any, Literal, NamedTuple, Optional, Protocol, Tuple, Union
from urllib import parse as urlparser
from zoneinfo import ZoneInfo

import xmltodict
from botocore.exceptions import ClientError
from botocore.utils import InvalidArnException
+from werkzeug.datastructures import Headers

from localstack import config, constants
from localstack.aws.api import CommonServiceException, RequestContext
@@ -482,7 +483,7 @@ def is_valid_canonical_id(canonical_id: str) -> bool:
    return False


-def uses_host_addressing(headers: Dict[str, str]) -> str | None:
+def uses_host_addressing(headers: dict[str, str] | Headers) -> str | None:
    """
    Determines if the request is targeting S3 with virtual host addressing
    :param headers: the request headers
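Widening the parameter to `dict[str, str] | Headers` lets the cached parser helper pass a plain `{"host": host}` mapping while regular request handling keeps passing Werkzeug `Headers`. A small usage sketch of the two call shapes (the host value is made up; the helper's body is unchanged and not shown in this diff):

```python
from werkzeug.datastructures import Headers

from localstack.services.s3.utils import uses_host_addressing

host = "my-bucket.s3.localhost.localstack.cloud:4566"  # illustrative virtual-host style Host value

# Werkzeug Headers, as seen on an incoming HTTP request
print(uses_host_addressing(Headers({"host": host})))  # expected: "my-bucket"

# plain dict, as S3RequestParser._is_vhost_address_get_bucket now passes
print(uses_host_addressing({"host": host}))  # expected: "my-bucket"
```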
31 changes: 31 additions & 0 deletions tests/aws/services/s3/test_s3.py
@@ -3980,6 +3980,37 @@ def test_s3_hostname_with_subdomain(self, aws_http_client_factory, aws_client):
        assert resp.ok
        assert b"<Bucket" in resp.content

+    @markers.aws.only_localstack
+    def test_virtual_host_parsing_with_non_prefixed_endpoint(
+        self, s3_bucket, aws_client_factory, cleanups, aws_client, s3_empty_bucket
+    ):
+        non_prefixed_endpoint = config.internal_service_url(host=LOCALHOST_HOSTNAME)
+        assert "s3." not in non_prefixed_endpoint
+        s3_client = aws_client_factory(
+            config=Config(s3={"addressing_style": "virtual"}),
+            endpoint_url=non_prefixed_endpoint,
+        ).s3
+        bucket_name = f"bucket-{short_uid()}"
+        response = s3_client.create_bucket(Bucket=bucket_name)
+        cleanups.append(lambda: aws_client.s3.delete_bucket(Bucket=bucket_name))
+        cleanups.append(lambda: s3_empty_bucket(bucket_name))
+        assert bucket_name in response["Location"]
+
+        response = s3_client.put_object(Bucket=bucket_name, Key="test/key", Body="test")
+        assert response["ETag"]
+
+        response = s3_client.get_object(Bucket=bucket_name, Key="test/key")
+        assert response["Body"].read() == b"test"
+
+        list_buckets = s3_client.list_buckets()
+        assert len(list_buckets["Buckets"]) >= 2
+        bucket_names = [bucket["Name"] for bucket in list_buckets["Buckets"]]
+        assert s3_bucket in bucket_names
+        assert bucket_name in bucket_names
+
+        list_objects = s3_client.list_objects_v2(Bucket=bucket_name)
+        assert len(list_objects["Contents"]) >= 1
+
    @pytest.mark.skipif(condition=TEST_S3_IMAGE, reason="Lambda not enabled in S3 image")
    @markers.skip_offline
    @markers.aws.validated