Skip to content
This repository was archived by the owner on Nov 23, 2017. It is now read-only.

Make loop TCP APIs accept only TCP streaming sockets #453

Merged
merged 1 commit into from
Nov 9, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 45 additions & 11 deletions asyncio/base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,26 @@ def _set_reuseport(sock):
'SO_REUSEPORT defined but not implemented.')


# Linux's sock.type is a bitmask that can include extra info about socket.
_SOCKET_TYPE_MASK = 0
if hasattr(socket, 'SOCK_NONBLOCK'):
_SOCKET_TYPE_MASK |= socket.SOCK_NONBLOCK
if hasattr(socket, 'SOCK_CLOEXEC'):
_SOCKET_TYPE_MASK |= socket.SOCK_CLOEXEC
def _is_stream_socket(sock):
# Linux's socket.type is a bitmask that can include extra info
# about socket, therefore we can't do simple
# `sock_type == socket.SOCK_STREAM`.
return (sock.type & socket.SOCK_STREAM) == socket.SOCK_STREAM


def _is_dgram_socket(sock):
# Linux's socket.type is a bitmask that can include extra info
# about socket, therefore we can't do simple
# `sock_type == socket.SOCK_DGRAM`.
return (sock.type & socket.SOCK_DGRAM) == socket.SOCK_DGRAM


def _is_ip_socket(sock):
if sock.family == socket.AF_INET:
return True
if hasattr(socket, 'AF_INET6') and sock.family == socket.AF_INET6:
return True
return False


def _ipaddr_info(host, port, family, type, proto):
Expand All @@ -102,8 +116,12 @@ def _ipaddr_info(host, port, family, type, proto):
host is None:
return None

type &= ~_SOCKET_TYPE_MASK
if type == socket.SOCK_STREAM:
# Linux only:
# getaddrinfo() can raise when socket.type is a bit mask.
# So if socket.type is a bit mask of SOCK_STREAM, and say
# SOCK_NONBLOCK, we simply return None, which will trigger
# a call to getaddrinfo() letting it process this request.
proto = socket.IPPROTO_TCP
elif type == socket.SOCK_DGRAM:
proto = socket.IPPROTO_UDP
Expand All @@ -124,7 +142,9 @@ def _ipaddr_info(host, port, family, type, proto):
return None

if family == socket.AF_UNSPEC:
afs = [socket.AF_INET, socket.AF_INET6]
afs = [socket.AF_INET]
if hasattr(socket, 'AF_INET6'):
afs.append(socket.AF_INET6)
else:
afs = [family]

Expand Down Expand Up @@ -771,9 +791,13 @@ def create_connection(self, protocol_factory, host=None, port=None, *,
raise OSError('Multiple exceptions: {}'.format(
', '.join(str(exc) for exc in exceptions)))

elif sock is None:
raise ValueError(
'host and port was not specified and no sock specified')
else:
if sock is None:
raise ValueError(
'host and port was not specified and no sock specified')
if not _is_stream_socket(sock) or not _is_ip_socket(sock):
raise ValueError(
'A TCP Stream Socket was expected, got {!r}'.format(sock))

transport, protocol = yield from self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname)
Expand Down Expand Up @@ -817,6 +841,9 @@ def create_datagram_endpoint(self, protocol_factory,
allow_broadcast=None, sock=None):
"""Create datagram connection."""
if sock is not None:
if not _is_dgram_socket(sock):
raise ValueError(
'A UDP Socket was expected, got {!r}'.format(sock))
if (local_addr or remote_addr or
family or proto or flags or
reuse_address or reuse_port or allow_broadcast):
Expand Down Expand Up @@ -1027,6 +1054,9 @@ def create_server(self, protocol_factory, host=None, port=None,
else:
if sock is None:
raise ValueError('Neither host/port nor sock were specified')
if not _is_stream_socket(sock) or not _is_ip_socket(sock):
raise ValueError(
'A TCP Stream Socket was expected, got {!r}'.format(sock))
sockets = [sock]

server = Server(self, sockets)
Expand All @@ -1048,6 +1078,10 @@ def connect_accepted_socket(self, protocol_factory, sock, *, ssl=None):
This method is a coroutine. When completed, the coroutine
returns a (transport, protocol) pair.
"""
if not _is_stream_socket(sock):
raise ValueError(
'A Stream Socket was expected, got {!r}'.format(sock))

transport, protocol = yield from self._create_connection_transport(
sock, protocol_factory, ssl, '', server_side=True)
if self._debug:
Expand Down
4 changes: 2 additions & 2 deletions asyncio/unix_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def create_unix_connection(self, protocol_factory, path, *,
if sock is None:
raise ValueError('no path and sock were specified')
if (sock.family != socket.AF_UNIX or
sock.type != socket.SOCK_STREAM):
not base_events._is_stream_socket(sock)):
raise ValueError(
'A UNIX Domain Stream Socket was expected, got {!r}'
.format(sock))
Expand Down Expand Up @@ -289,7 +289,7 @@ def create_unix_server(self, protocol_factory, path=None, *,
'path was not specified, and no sock specified')

if (sock.family != socket.AF_UNIX or
sock.type != socket.SOCK_STREAM):
not base_events._is_stream_socket(sock)):
raise ValueError(
'A UNIX Domain Stream Socket was expected, got {!r}'
.format(sock))
Expand Down
63 changes: 55 additions & 8 deletions tests/test_base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,13 @@ def test_ipaddr_info(self):
self.assertIsNone(
base_events._ipaddr_info('::3%lo0', 1, INET6, STREAM, TCP))

if hasattr(socket, 'SOCK_NONBLOCK'):
self.assertEqual(
None,
base_events._ipaddr_info(
'1.2.3.4', 1, INET, STREAM | socket.SOCK_NONBLOCK, TCP))


def test_port_parameter_types(self):
# Test obscure kinds of arguments for "port".
INET = socket.AF_INET
Expand Down Expand Up @@ -1040,6 +1047,43 @@ def test_create_connection_host_port_sock(self):
MyProto, 'example.com', 80, sock=object())
self.assertRaises(ValueError, self.loop.run_until_complete, coro)

@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'no Unix sockets')
def test_create_connection_wrong_sock(self):
sock = socket.socket(socket.AF_UNIX)
with sock:
coro = self.loop.create_connection(MyProto, sock=sock)
with self.assertRaisesRegex(ValueError,
'A TCP Stream Socket was expected'):
self.loop.run_until_complete(coro)

@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'no Unix sockets')
def test_create_server_wrong_sock(self):
sock = socket.socket(socket.AF_UNIX)
with sock:
coro = self.loop.create_server(MyProto, sock=sock)
with self.assertRaisesRegex(ValueError,
'A TCP Stream Socket was expected'):
self.loop.run_until_complete(coro)

@unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'),
'no socket.SOCK_NONBLOCK (linux only)')
def test_create_server_stream_bittype(self):
sock = socket.socket(
socket.AF_INET, socket.SOCK_STREAM | socket.SOCK_NONBLOCK)
with sock:
coro = self.loop.create_server(lambda: None, sock=sock)
srv = self.loop.run_until_complete(coro)
srv.close()
self.loop.run_until_complete(srv.wait_closed())

def test_create_datagram_endpoint_wrong_sock(self):
sock = socket.socket(socket.AF_INET)
with sock:
coro = self.loop.create_datagram_endpoint(MyProto, sock=sock)
with self.assertRaisesRegex(ValueError,
'A UDP Socket was expected'):
self.loop.run_until_complete(coro)

def test_create_connection_no_host_port_sock(self):
coro = self.loop.create_connection(MyProto)
self.assertRaises(ValueError, self.loop.run_until_complete, coro)
Expand Down Expand Up @@ -1487,36 +1531,39 @@ def test_create_datagram_endpoint_sock(self):
self.assertEqual('CLOSED', protocol.state)

def test_create_datagram_endpoint_sock_sockopts(self):
class FakeSock:
type = socket.SOCK_DGRAM

fut = self.loop.create_datagram_endpoint(
MyDatagramProto, local_addr=('127.0.0.1', 0), sock=object())
MyDatagramProto, local_addr=('127.0.0.1', 0), sock=FakeSock())
self.assertRaises(ValueError, self.loop.run_until_complete, fut)

fut = self.loop.create_datagram_endpoint(
MyDatagramProto, remote_addr=('127.0.0.1', 0), sock=object())
MyDatagramProto, remote_addr=('127.0.0.1', 0), sock=FakeSock())
self.assertRaises(ValueError, self.loop.run_until_complete, fut)

fut = self.loop.create_datagram_endpoint(
MyDatagramProto, family=1, sock=object())
MyDatagramProto, family=1, sock=FakeSock())
self.assertRaises(ValueError, self.loop.run_until_complete, fut)

fut = self.loop.create_datagram_endpoint(
MyDatagramProto, proto=1, sock=object())
MyDatagramProto, proto=1, sock=FakeSock())
self.assertRaises(ValueError, self.loop.run_until_complete, fut)

fut = self.loop.create_datagram_endpoint(
MyDatagramProto, flags=1, sock=object())
MyDatagramProto, flags=1, sock=FakeSock())
self.assertRaises(ValueError, self.loop.run_until_complete, fut)

fut = self.loop.create_datagram_endpoint(
MyDatagramProto, reuse_address=True, sock=object())
MyDatagramProto, reuse_address=True, sock=FakeSock())
self.assertRaises(ValueError, self.loop.run_until_complete, fut)

fut = self.loop.create_datagram_endpoint(
MyDatagramProto, reuse_port=True, sock=object())
MyDatagramProto, reuse_port=True, sock=FakeSock())
self.assertRaises(ValueError, self.loop.run_until_complete, fut)

fut = self.loop.create_datagram_endpoint(
MyDatagramProto, allow_broadcast=True, sock=object())
MyDatagramProto, allow_broadcast=True, sock=FakeSock())
self.assertRaises(ValueError, self.loop.run_until_complete, fut)

def test_create_datagram_endpoint_sockopts(self):
Expand Down
11 changes: 8 additions & 3 deletions tests/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,9 +791,9 @@ def client():
conn, _ = lsock.accept()
proto = MyProto(loop=loop)
proto.loop = loop
f = loop.create_task(
loop.run_until_complete(
loop.connect_accepted_socket(
(lambda : proto), conn, ssl=server_ssl))
(lambda: proto), conn, ssl=server_ssl))
loop.run_forever()
proto.transport.close()
lsock.close()
Expand Down Expand Up @@ -1377,6 +1377,11 @@ def datagram_received(self, data, addr):
server.transport.close()

def test_create_datagram_endpoint_sock(self):
if (sys.platform == 'win32' and
isinstance(self.loop, proactor_events.BaseProactorEventLoop)):
raise unittest.SkipTest(
'UDP is not supported with proactor event loops')

sock = None
local_address = ('127.0.0.1', 0)
infos = self.loop.run_until_complete(
Expand All @@ -1394,7 +1399,7 @@ def test_create_datagram_endpoint_sock(self):
else:
assert False, 'Can not create socket.'

f = self.loop.create_connection(
f = self.loop.create_datagram_endpoint(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm worried here that this test doesn't test much...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, you're right, it tests only that the method doesn't crash. Still it was testing a wrong method.

lambda: MyDatagramProto(loop=self.loop), sock=sock)
tr, pr = self.loop.run_until_complete(f)
self.assertIsInstance(tr, asyncio.Transport)
Expand Down
27 changes: 27 additions & 0 deletions tests/test_unix_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,33 @@ def test_create_unix_server_path_inetsock(self):
'A UNIX Domain Stream.*was expected'):
self.loop.run_until_complete(coro)

def test_create_unix_server_path_dgram(self):
sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
with sock:
coro = self.loop.create_unix_server(lambda: None, path=None,
sock=sock)
with self.assertRaisesRegex(ValueError,
'A UNIX Domain Stream.*was expected'):
self.loop.run_until_complete(coro)

@unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'),
'no socket.SOCK_NONBLOCK (linux only)')
def test_create_unix_server_path_stream_bittype(self):
sock = socket.socket(
socket.AF_UNIX, socket.SOCK_STREAM | socket.SOCK_NONBLOCK)
with tempfile.NamedTemporaryFile() as file:
fn = file.name
try:
with sock:
sock.bind(fn)
coro = self.loop.create_unix_server(lambda: None, path=None,
sock=sock)
srv = self.loop.run_until_complete(coro)
srv.close()
self.loop.run_until_complete(srv.wait_closed())
finally:
os.unlink(fn)

def test_create_unix_connection_path_inetsock(self):
sock = socket.socket()
with sock:
Expand Down