Skip to content

Auto-detect stacklevel for bytes warning #120

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 43 additions & 31 deletions Lib/ldap/ldapobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
'LDAPBytesWarning'
]

_LDAP_WARN_SKIP_FRAME = True

if __debug__:
# Tracing is only supported in debugging mode
Expand Down Expand Up @@ -55,6 +56,7 @@ class SimpleLDAPObject:
"""
Drop-in wrapper class around _ldap.LDAPObject
"""
_warn_frameup = 2

CLASSATTR_OPTION_MAPPING = {
"protocol_version": ldap.OPT_PROTOCOL_VERSION,
Expand Down Expand Up @@ -102,7 +104,7 @@ def __init__(
# On by default on Py2, off on Py3.
self.bytes_mode = bytes_mode

def _bytesify_input(self, value):
def _bytesify_input(self, value, frameup=0):
"""Adapt a value following bytes_mode in Python 2.

In Python 3, returns the original value unmodified.
Expand All @@ -126,11 +128,17 @@ def _bytesify_input(self, value):
if self.bytes_mode_hardfail:
raise TypeError("All provided fields *must* be bytes when bytes mode is on; got %r" % (value,))
else:
stacklevel = self._warn_frameup + frameup
frame = sys._getframe(stacklevel)
# walk up the stacks until we leave the file
while frame and frame.f_globals.get('_LDAP_WARN_SKIP_FRAME'):
stacklevel += 1
frame = frame.f_back
warnings.warn(
"Received non-bytes value %r with default (disabled) bytes mode; please choose an explicit "
"option for bytes_mode on your LDAP connection" % (value,),
LDAPBytesWarning,
stacklevel=6,
stacklevel=stacklevel+1,
)
return value.encode('utf-8')
else:
Expand All @@ -139,21 +147,6 @@ def _bytesify_input(self, value):
assert not isinstance(value, bytes)
return value.encode('utf-8')

def _bytesify_inputs(self, *values):
"""Adapt values following bytes_mode.

Applies _bytesify_input on each arg.

Usage:
>>> a, b, c = self._bytesify_inputs(a, b, c)
"""
if not PY2:
return values
return (
self._bytesify_input(value)
for value in values
)

def _bytesify_modlist(self, modlist, with_opcode):
"""Adapt a modlist according to bytes_mode.

Expand All @@ -166,12 +159,12 @@ def _bytesify_modlist(self, modlist, with_opcode):
return modlist
if with_opcode:
return tuple(
(op, self._bytesify_input(attr), val)
(op, self._bytesify_input(attr, 1), val)
for op, attr, val in modlist
)
else:
return tuple(
(self._bytesify_input(attr), val)
(self._bytesify_input(attr, 1), val)
for attr, val in modlist
)

Expand Down Expand Up @@ -380,8 +373,9 @@ def add_ext(self,dn,modlist,serverctrls=None,clientctrls=None):
The parameter modlist is similar to the one passed to modify(),
except that no operation integer need be included in the tuples.
"""
dn = self._bytesify_input(dn)
modlist = self._bytesify_modlist(modlist, with_opcode=False)
if PY2:
dn = self._bytesify_input(dn)
modlist = self._bytesify_modlist(modlist, with_opcode=False)
return self._ldap_call(self._l.add_ext,dn,modlist,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))

def add_ext_s(self,dn,modlist,serverctrls=None,clientctrls=None):
Expand All @@ -406,7 +400,9 @@ def simple_bind(self,who='',cred='',serverctrls=None,clientctrls=None):
"""
simple_bind([who='' [,cred='']]) -> int
"""
who, cred = self._bytesify_inputs(who, cred)
if PY2:
who = self._bytesify_input(who)
cred = self._bytesify_input(cred)
return self._ldap_call(self._l.simple_bind,who,cred,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))

def simple_bind_s(self,who='',cred='',serverctrls=None,clientctrls=None):
Expand Down Expand Up @@ -483,7 +479,9 @@ def compare_ext(self,dn,attr,value,serverctrls=None,clientctrls=None):
A design bug in the library prevents value from containing
nul characters.
"""
dn, attr = self._bytesify_inputs(dn, attr)
if PY2:
dn = self._bytesify_input(dn)
attr = self._bytesify_input(attr)
return self._ldap_call(self._l.compare_ext,dn,attr,value,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))

def compare_ext_s(self,dn,attr,value,serverctrls=None,clientctrls=None):
Expand Down Expand Up @@ -514,7 +512,8 @@ def delete_ext(self,dn,serverctrls=None,clientctrls=None):
form returns the message id of the initiated request, and the
result can be obtained from a subsequent call to result().
"""
dn = self._bytesify_input(dn)
if PY2:
dn = self._bytesify_input(dn)
return self._ldap_call(self._l.delete_ext,dn,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))

def delete_ext_s(self,dn,serverctrls=None,clientctrls=None):
Expand Down Expand Up @@ -563,8 +562,9 @@ def modify_ext(self,dn,modlist,serverctrls=None,clientctrls=None):
"""
modify_ext(dn, modlist[,serverctrls=None[,clientctrls=None]]) -> int
"""
dn = self._bytesify_input(dn)
modlist = self._bytesify_modlist(modlist, with_opcode=True)
if PY2:
dn = self._bytesify_input(dn)
modlist = self._bytesify_modlist(modlist, with_opcode=True)
return self._ldap_call(self._l.modify_ext,dn,modlist,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))

def modify_ext_s(self,dn,modlist,serverctrls=None,clientctrls=None):
Expand Down Expand Up @@ -618,7 +618,10 @@ def modrdn_s(self,dn,newrdn,delold=1):
return self.rename_s(dn,newrdn,None,delold)

def passwd(self,user,oldpw,newpw,serverctrls=None,clientctrls=None):
user, oldpw, newpw = self._bytesify_inputs(user, oldpw, newpw)
if PY2:
user = self._bytesify_input(user)
oldpw = self._bytesify_input(oldpw)
newpw = self._bytesify_input(newpw)
return self._ldap_call(self._l.passwd,user,oldpw,newpw,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))

def passwd_s(self,user,oldpw,newpw,serverctrls=None,clientctrls=None):
Expand All @@ -640,7 +643,10 @@ def rename(self,dn,newrdn,newsuperior=None,delold=1,serverctrls=None,clientctrls
This actually corresponds to the rename* routines in the
LDAP-EXT C API library.
"""
dn, newrdn, newsuperior = self._bytesify_inputs(dn, newrdn, newsuperior)
if PY2:
dn = self._bytesify_input(dn)
newrdn = self._bytesify_input(newrdn)
newsuperior = self._bytesify_input(newsuperior)
return self._ldap_call(self._l.rename,dn,newrdn,newsuperior,delold,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))

def rename_s(self,dn,newrdn,newsuperior=None,delold=1,serverctrls=None,clientctrls=None):
Expand Down Expand Up @@ -778,9 +784,12 @@ def search_ext(self,base,scope,filterstr='(objectClass=*)',attrlist=None,attrson
The amount of search results retrieved can be limited with the
sizelimit parameter if non-zero.
"""
base, filterstr = self._bytesify_inputs(base, filterstr)
if attrlist is not None:
attrlist = tuple(self._bytesify_inputs(*attrlist))
if PY2:
bytesify_input = self._bytesify_input
base = bytesify_input(base)
filterstr = bytesify_input(filterstr)
if attrlist is not None:
attrlist = tuple(bytesify_input(attr, 1) for attr in attrlist)
return self._ldap_call(
self._l.search_ext,
base,scope,filterstr,
Expand Down Expand Up @@ -991,6 +1000,9 @@ class ReconnectLDAPObject(SimpleLDAPObject):
application.
"""

# public method + _apply_method_s()
_warn_frameup = SimpleLDAPObject._warn_frameup + 2

__transient_attrs__ = set([
'_l',
'_ldap_object_lock',
Expand Down
54 changes: 53 additions & 1 deletion Tests/t_ldapobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@
PY2 = False
text_type = str

import contextlib
import linecache
import os
import unittest
import warnings
import pickle
import warnings
from slapdtest import SlapdTestCase, requires_sasl
Expand Down Expand Up @@ -329,7 +332,7 @@ def test_ldapbyteswarning(self):
self.assertIsInstance(self.server.suffix, text_type)
with warnings.catch_warnings(record=True) as w:
warnings.resetwarnings()
warnings.simplefilter('default')
warnings.simplefilter('always', ldap.LDAPBytesWarning)
conn = self._get_bytes_ldapobject(explicit=False)
result = conn.search_s(
self.server.suffix,
Expand All @@ -350,6 +353,55 @@ def test_ldapbyteswarning(self):
"LDAP connection" % self.server.suffix
)

@contextlib.contextmanager
def catch_byteswarnings(self, *args, **kwargs):
with warnings.catch_warnings(record=True) as w:
conn = self._get_bytes_ldapobject(*args, **kwargs)
warnings.resetwarnings()
warnings.simplefilter('always', ldap.LDAPBytesWarning)
yield conn, w

def _test_byteswarning_level_search(self, methodname):
with self.catch_byteswarnings(explicit=False) as (conn, w):
method = getattr(conn, methodname)
result = method(
self.server.suffix.encode('utf-8'),
ldap.SCOPE_SUBTREE,
'(cn=Foo*)',
attrlist=['*'], # CORRECT LINE
)
self.assertEqual(len(result), 4)

self.assertEqual(len(w), 2, w)

self.assertIs(w[0].category, ldap.LDAPBytesWarning)
self.assertIn(
u"Received non-bytes value u'(cn=Foo*)'",
text_type(w[0].message)
)
self.assertEqual(w[0].filename, __file__)
self.assertIn(
'CORRECT LINE',
linecache.getline(w[0].filename, w[0].lineno)
)

self.assertIs(w[1].category, ldap.LDAPBytesWarning)
self.assertIn(
u"Received non-bytes value u'*'",
text_type(w[1].message)
)
self.assertEqual(w[1].filename, __file__)
self.assertIn(
'CORRECT LINE',
linecache.getline(w[1].filename, w[1].lineno)
)

@unittest.skipUnless(PY2, "no bytes_mode under Py3")
def test_byteswarning_level_search(self):
self._test_byteswarning_level_search('search_s')
self._test_byteswarning_level_search('search_st')
self._test_byteswarning_level_search('search_ext_s')


class Test01_ReconnectLDAPObject(Test00_SimpleLDAPObject):
"""
Expand Down