Skip to content

[PoC] Fix stacklevel for warning #109

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 60 additions & 54 deletions Lib/ldap/ldapobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class SimpleLDAPObject:
"""
Drop-in wrapper class around _ldap.LDAPObject
"""
_stacklevel = 3

CLASSATTR_OPTION_MAPPING = {
"protocol_version": ldap.OPT_PROTOCOL_VERSION,
Expand Down Expand Up @@ -102,7 +103,7 @@ def __init__(
# On by default on Py2, off on Py3.
self.bytes_mode = bytes_mode

def _bytesify_input(self, value):
def _bytesify_input(self, value, _stackup=0):
"""Adapt a value following bytes_mode in Python 2.

In Python 3, returns the original value unmodified.
Expand Down Expand Up @@ -130,7 +131,7 @@ def _bytesify_input(self, value):
"Received non-bytes value %r with default (disabled) bytes mode; please choose an explicit "
"option for bytes_mode on your LDAP connection" % (value,),
LDAPBytesWarning,
stacklevel=6,
stacklevel=self._stacklevel + _stackup,
)
return value.encode('utf-8')
else:
Expand All @@ -139,21 +140,6 @@ def _bytesify_input(self, value):
assert not isinstance(value, bytes)
return value.encode('utf-8')

def _bytesify_inputs(self, *values):
"""Adapt values following bytes_mode.

Applies _bytesify_input on each arg.

Usage:
>>> a, b, c = self._bytesify_inputs(a, b, c)
"""
if not PY2:
return values
return (
self._bytesify_input(value)
for value in values
)

def _bytesify_modlist(self, modlist, with_opcode):
"""Adapt a modlist according to bytes_mode.

Expand All @@ -166,12 +152,12 @@ def _bytesify_modlist(self, modlist, with_opcode):
return modlist
if with_opcode:
return tuple(
(op, self._bytesify_input(attr), val)
(op, self._bytesify_input(attr, _stackup=1), val)
for op, attr, val in modlist
)
else:
return tuple(
(self._bytesify_input(attr), val)
(self._bytesify_input(attr, _stackup=1), val)
for attr, val in modlist
)

Expand Down Expand Up @@ -380,8 +366,9 @@ def add_ext(self,dn,modlist,serverctrls=None,clientctrls=None):
The parameter modlist is similar to the one passed to modify(),
except that no operation integer need be included in the tuples.
"""
dn = self._bytesify_input(dn)
modlist = self._bytesify_modlist(modlist, with_opcode=False)
if PY2:
dn = self._bytesify_input(dn)
modlist = self._bytesify_modlist(modlist, with_opcode=False)
return self._ldap_call(self._l.add_ext,dn,modlist,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))

def add_ext_s(self,dn,modlist,serverctrls=None,clientctrls=None):
Expand All @@ -406,7 +393,9 @@ def simple_bind(self,who='',cred='',serverctrls=None,clientctrls=None):
"""
simple_bind([who='' [,cred='']]) -> int
"""
who, cred = self._bytesify_inputs(who, cred)
if PY2:
who = self._bytesify_input(who)
cred = self._bytesify_input(cred)
return self._ldap_call(self._l.simple_bind,who,cred,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))

def simple_bind_s(self,who='',cred='',serverctrls=None,clientctrls=None):
Expand Down Expand Up @@ -465,7 +454,7 @@ def sasl_bind_s(self,dn,mechanism,cred,serverctrls=None,clientctrls=None):
"""
return self._ldap_call(self._l.sasl_bind_s,dn,mechanism,cred,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))

def compare_ext(self,dn,attr,value,serverctrls=None,clientctrls=None):
def compare_ext(self,dn,attr,value,serverctrls=None,clientctrls=None, _stackup=0):
"""
compare_ext(dn, attr, value [,serverctrls=None[,clientctrls=None]]) -> int
compare_ext_s(dn, attr, value [,serverctrls=None[,clientctrls=None]]) -> int
Expand All @@ -483,11 +472,13 @@ def compare_ext(self,dn,attr,value,serverctrls=None,clientctrls=None):
A design bug in the library prevents value from containing
nul characters.
"""
dn, attr = self._bytesify_inputs(dn, attr)
if PY2:
dn = self._bytesify_input(dn, _stackup)
attr = self._bytesify_input(attr, _stackup)
return self._ldap_call(self._l.compare_ext,dn,attr,value,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))

def compare_ext_s(self,dn,attr,value,serverctrls=None,clientctrls=None):
msgid = self.compare_ext(dn,attr,value,serverctrls,clientctrls)
def compare_ext_s(self,dn,attr,value,serverctrls=None,clientctrls=None, _stackup=0):
msgid = self.compare_ext(dn,attr,value,serverctrls,clientctrls, _stackup=_stackup+1)
try:
ldap_res = self.result3(msgid,all=1,timeout=self.timeout)
except ldap.COMPARE_TRUE:
Expand All @@ -499,12 +490,12 @@ def compare_ext_s(self,dn,attr,value,serverctrls=None,clientctrls=None):
)

def compare(self,dn,attr,value):
return self.compare_ext(dn,attr,value,None,None)
return self.compare_ext(dn,attr,value,None,None, _stackup=1)

def compare_s(self,dn,attr,value):
return self.compare_ext_s(dn,attr,value,None,None)
return self.compare_ext_s(dn,attr,value,None,None, _stackup=1)

def delete_ext(self,dn,serverctrls=None,clientctrls=None):
def delete_ext(self,dn,serverctrls=None,clientctrls=None, _stackup=0):
"""
delete(dn) -> int
delete_s(dn) -> None
Expand All @@ -514,19 +505,20 @@ def delete_ext(self,dn,serverctrls=None,clientctrls=None):
form returns the message id of the initiated request, and the
result can be obtained from a subsequent call to result().
"""
dn = self._bytesify_input(dn)
if PY2:
dn = self._bytesify_input(dn, _stackup)
return self._ldap_call(self._l.delete_ext,dn,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))

def delete_ext_s(self,dn,serverctrls=None,clientctrls=None):
msgid = self.delete_ext(dn,serverctrls,clientctrls)
def delete_ext_s(self,dn,serverctrls=None,clientctrls=None, _stackup=0):
msgid = self.delete_ext(dn,serverctrls,clientctrls, _stackup=_stackup+1)
resp_type, resp_data, resp_msgid, resp_ctrls = self.result3(msgid,all=1,timeout=self.timeout)
return resp_type, resp_data, resp_msgid, resp_ctrls

def delete(self,dn):
return self.delete_ext(dn,None,None)
return self.delete_ext(dn,None,None, _stackup=1)

def delete_s(self,dn):
return self.delete_ext_s(dn,None,None)
return self.delete_ext_s(dn,None,None, _stackup=1)

def extop(self,extreq,serverctrls=None,clientctrls=None):
"""
Expand Down Expand Up @@ -563,8 +555,9 @@ def modify_ext(self,dn,modlist,serverctrls=None,clientctrls=None):
"""
modify_ext(dn, modlist[,serverctrls=None[,clientctrls=None]]) -> int
"""
dn = self._bytesify_input(dn)
modlist = self._bytesify_modlist(modlist, with_opcode=True)
if PY2:
dn = self._bytesify_input(dn)
modlist = self._bytesify_modlist(modlist, with_opcode=True)
return self._ldap_call(self._l.modify_ext,dn,modlist,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))

def modify_ext_s(self,dn,modlist,serverctrls=None,clientctrls=None):
Expand Down Expand Up @@ -598,7 +591,7 @@ def modify(self,dn,modlist):
def modify_s(self,dn,modlist):
return self.modify_ext_s(dn,modlist,None,None)

def modrdn(self,dn,newrdn,delold=1):
def modrdn(self,dn,newrdn,delold=1,_stackup=0):
"""
modrdn(dn, newrdn [,delold=1]) -> int
modrdn_s(dn, newrdn [,delold=1]) -> None
Expand All @@ -612,20 +605,23 @@ def modrdn(self,dn,newrdn,delold=1):
This operation is emulated by rename() and rename_s() methods
since the modrdn2* routines in the C library are deprecated.
"""
return self.rename(dn,newrdn,None,delold)
return self.rename(dn,newrdn,None,delold, _stackup=_stackup+1)

def modrdn_s(self,dn,newrdn,delold=1):
return self.rename_s(dn,newrdn,None,delold)
return self.rename_s(dn,newrdn,None,delold, _stackup=1)

def passwd(self,user,oldpw,newpw,serverctrls=None,clientctrls=None):
user, oldpw, newpw = self._bytesify_inputs(user, oldpw, newpw)
if PY2:
user = self._bytesify_input(user)
oldpw = self._bytesify_input(oldpw)
newpw = self._bytesify_input(newpw)
return self._ldap_call(self._l.passwd,user,oldpw,newpw,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))

def passwd_s(self,user,oldpw,newpw,serverctrls=None,clientctrls=None):
msgid = self.passwd(user,oldpw,newpw,serverctrls,clientctrls)
return self.extop_result(msgid,all=1,timeout=self.timeout)

def rename(self,dn,newrdn,newsuperior=None,delold=1,serverctrls=None,clientctrls=None):
def rename(self,dn,newrdn,newsuperior=None,delold=1,serverctrls=None,clientctrls=None, _stackup=0):
"""
rename(dn, newrdn [, newsuperior=None [,delold=1][,serverctrls=None[,clientctrls=None]]]) -> int
rename_s(dn, newrdn [, newsuperior=None] [,delold=1][,serverctrls=None[,clientctrls=None]]) -> None
Expand All @@ -640,11 +636,14 @@ def rename(self,dn,newrdn,newsuperior=None,delold=1,serverctrls=None,clientctrls
This actually corresponds to the rename* routines in the
LDAP-EXT C API library.
"""
dn, newrdn, newsuperior = self._bytesify_inputs(dn, newrdn, newsuperior)
if PY2:
dn = self._bytesify_input(dn, _stackup)
newrdn = self._bytesify_input(newrdn, _stackup)
newsuperior = self._bytesify_input(newsuperior, _stackup)
return self._ldap_call(self._l.rename,dn,newrdn,newsuperior,delold,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))

def rename_s(self,dn,newrdn,newsuperior=None,delold=1,serverctrls=None,clientctrls=None):
msgid = self.rename(dn,newrdn,newsuperior,delold,serverctrls,clientctrls)
msgid = self.rename(dn,newrdn,newsuperior,delold,serverctrls,clientctrls, _stackup=1)
resp_type, resp_data, resp_msgid, resp_ctrls = self.result3(msgid,all=1,timeout=self.timeout)
return resp_type, resp_data, resp_msgid, resp_ctrls

Expand Down Expand Up @@ -733,7 +732,7 @@ def result4(self,msgid=ldap.RES_ANY,all=1,timeout=None,add_ctrls=0,add_intermedi
resp_data = self._bytesify_results(resp_data, with_ctrls=add_ctrls)
return resp_type, resp_data, resp_msgid, decoded_resp_ctrls, resp_name, resp_value

def search_ext(self,base,scope,filterstr='(objectClass=*)',attrlist=None,attrsonly=0,serverctrls=None,clientctrls=None,timeout=-1,sizelimit=0):
def search_ext(self,base,scope,filterstr='(objectClass=*)',attrlist=None,attrsonly=0,serverctrls=None,clientctrls=None,timeout=-1,sizelimit=0,_stackup=0):
"""
search(base, scope [,filterstr='(objectClass=*)' [,attrlist=None [,attrsonly=0]]]) -> int
search_s(base, scope [,filterstr='(objectClass=*)' [,attrlist=None [,attrsonly=0]]])
Expand Down Expand Up @@ -778,9 +777,12 @@ def search_ext(self,base,scope,filterstr='(objectClass=*)',attrlist=None,attrson
The amount of search results retrieved can be limited with the
sizelimit parameter if non-zero.
"""
base, filterstr = self._bytesify_inputs(base, filterstr)
if attrlist is not None:
attrlist = tuple(self._bytesify_inputs(*attrlist))
if PY2:
bytesify_input = self._bytesify_input
base = bytesify_input(base, _stackup)
filterstr = bytesify_input(filterstr, _stackup)
if attrlist is not None:
attrlist = tuple(bytesify_input(attr, _stackup+1) for attr in attrlist)
return self._ldap_call(
self._l.search_ext,
base,scope,filterstr,
Expand All @@ -790,18 +792,18 @@ def search_ext(self,base,scope,filterstr='(objectClass=*)',attrlist=None,attrson
timeout,sizelimit,
)

def search_ext_s(self,base,scope,filterstr='(objectClass=*)',attrlist=None,attrsonly=0,serverctrls=None,clientctrls=None,timeout=-1,sizelimit=0):
msgid = self.search_ext(base,scope,filterstr,attrlist,attrsonly,serverctrls,clientctrls,timeout,sizelimit)
def search_ext_s(self,base,scope,filterstr='(objectClass=*)',attrlist=None,attrsonly=0,serverctrls=None,clientctrls=None,timeout=-1,sizelimit=0, _stackup=1):
msgid = self.search_ext(base,scope,filterstr,attrlist,attrsonly,serverctrls,clientctrls,timeout,sizelimit, _stackup=_stackup)
return self.result(msgid,all=1,timeout=timeout)[1]

def search(self,base,scope,filterstr='(objectClass=*)',attrlist=None,attrsonly=0):
return self.search_ext(base,scope,filterstr,attrlist,attrsonly,None,None)
return self.search_ext(base,scope,filterstr,attrlist,attrsonly,None,None, _stackup=1)

def search_s(self,base,scope,filterstr='(objectClass=*)',attrlist=None,attrsonly=0):
return self.search_ext_s(base,scope,filterstr,attrlist,attrsonly,None,None,timeout=self.timeout)
def search_s(self,base,scope,filterstr='(objectClass=*)',attrlist=None,attrsonly=0, _stackup=0):
return self.search_ext_s(base,scope,filterstr,attrlist,attrsonly,None,None,timeout=self.timeout, _stackup=_stackup+2)

def search_st(self,base,scope,filterstr='(objectClass=*)',attrlist=None,attrsonly=0,timeout=-1):
return self.search_ext_s(base,scope,filterstr,attrlist,attrsonly,None,None,timeout)
return self.search_ext_s(base,scope,filterstr,attrlist,attrsonly,None,None,timeout, _stackup=2)

def start_tls_s(self):
"""
Expand Down Expand Up @@ -879,7 +881,8 @@ def search_subschemasubentry_s(self,dn=''):
"""
try:
r = self.search_s(
dn,ldap.SCOPE_BASE,'(objectClass=*)',['subschemaSubentry']
dn,ldap.SCOPE_BASE,'(objectClass=*)',['subschemaSubentry'],
_stackup=1
)
except (ldap.NO_SUCH_OBJECT,ldap.NO_SUCH_ATTRIBUTE,ldap.INSUFFICIENT_ACCESS):
r = []
Expand Down Expand Up @@ -991,6 +994,9 @@ class ReconnectLDAPObject(SimpleLDAPObject):
application.
"""

# public method + _apply_method_s()
_stacklevel = SimpleLDAPObject._stacklevel + 2

__transient_attrs__ = set([
'_l',
'_ldap_object_lock',
Expand Down
54 changes: 53 additions & 1 deletion Tests/t_ldapobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@
PY2 = False
text_type = str

import contextlib
import linecache
import os
import unittest
import warnings
import pickle
import warnings
from slapdtest import SlapdTestCase, requires_sasl
Expand Down Expand Up @@ -329,7 +332,7 @@ def test_ldapbyteswarning(self):
self.assertIsInstance(self.server.suffix, text_type)
with warnings.catch_warnings(record=True) as w:
warnings.resetwarnings()
warnings.simplefilter('default')
warnings.simplefilter('always', ldap.LDAPBytesWarning)
conn = self._get_bytes_ldapobject(explicit=False)
result = conn.search_s(
self.server.suffix,
Expand All @@ -350,6 +353,55 @@ def test_ldapbyteswarning(self):
"LDAP connection" % self.server.suffix
)

@contextlib.contextmanager
def catch_byteswarnings(self, *args, **kwargs):
with warnings.catch_warnings(record=True) as w:
conn = self._get_bytes_ldapobject(*args, **kwargs)
warnings.resetwarnings()
warnings.simplefilter('always', ldap.LDAPBytesWarning)
yield conn, w

def _test_byteswarning_level_search(self, methodname):
with self.catch_byteswarnings(explicit=False) as (conn, w):
method = getattr(conn, methodname)
result = method(
self.server.suffix.encode('utf-8'),
ldap.SCOPE_SUBTREE,
'(cn=Foo*)',
attrlist=['*'], # CORRECT LINE
)
self.assertEqual(len(result), 4)

self.assertEqual(len(w), 2, w)

self.assertIs(w[0].category, ldap.LDAPBytesWarning)
self.assertIn(
u"Received non-bytes value u'(cn=Foo*)'",
text_type(w[0].message)
)
self.assertEqual(w[0].filename, __file__)
self.assertIn(
'CORRECT LINE',
linecache.getline(w[0].filename, w[0].lineno)
)

self.assertIs(w[1].category, ldap.LDAPBytesWarning)
self.assertIn(
u"Received non-bytes value u'*'",
text_type(w[1].message)
)
self.assertEqual(w[1].filename, __file__)
self.assertIn(
'CORRECT LINE',
linecache.getline(w[1].filename, w[1].lineno)
)

@unittest.skipUnless(PY2, "no bytes_mode under Py3")
def test_byteswarning_level_search(self):
self._test_byteswarning_level_search('search_s')
self._test_byteswarning_level_search('search_st')
self._test_byteswarning_level_search('search_ext_s')


class Test01_ReconnectLDAPObject(Test00_SimpleLDAPObject):
"""
Expand Down