diff --git a/CHANGES.txt b/CHANGES.txt index fa53044..0a40619 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,3 +1,10 @@ +Development +----------- + +- Add attrlist to LDAPQuery, set_login_query, set_groups_query to allow + retrieved attributes to be filtered. (Default: All attributes on users, no + attributes on groups). + 0.1 --- diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 80fbb23..84bb893 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -103,4 +103,4 @@ Contributors ------------ - Chris McDonough, 2013/04/24 - +- Charles Duffy (Indeed.com), 2013/05/03 diff --git a/docs/index.rst b/docs/index.rst index fc4a475..61ab9dc 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -95,7 +95,8 @@ the startup phase of your Pyramid application. ``Configurator.ldap_set_login_query`` This configurator method accepts parameters which tell ``pyramid_ldap`` - how to find a user based on a login. Invoking this method allows the LDAP + how to find a user based on a login, and which LDAP attributes to retrieve + during user lookups. Invoking this method allows the LDAP connector's ``authenticate`` method to work. See :func:`pyramid_ldap.ldap_set_login_query` for argument details. @@ -105,8 +106,9 @@ the startup phase of your Pyramid application. ``Configurator.ldap_set_groups_query`` This configurator method accepts parameters which tell ``pyramid_ldap`` - how to find groups based on a user DN. Invoking this method allows the - connector's ``user_groups`` method to work. See + how to find groups based on a user DN, and which group attributes to + retrieve during lookups. Invoking this method allows the connector's + ``user_groups`` method to work. See :func:`pyramid_ldap.ldap_set_groups_query` for argument details. If ``ldap_set_groups_query`` is not called, the diff --git a/pyramid_ldap/__init__.py b/pyramid_ldap/__init__.py index ba56ea6..e87f016 100644 --- a/pyramid_ldap/__init__.py +++ b/pyramid_ldap/__init__.py @@ -27,17 +27,22 @@ def __init__(self, *arg, **kw): class _LDAPQuery(object): """ Represents an LDAP query. Provides rudimentary in-RAM caching of query results.""" - def __init__(self, base_dn, filter_tmpl, scope, cache_period): + def __init__(self, base_dn, filter_tmpl, scope, attrlist, cache_period): self.base_dn = base_dn self.filter_tmpl = filter_tmpl self.scope = scope + if attrlist is not None: + self.attrlist = tuple(sorted(attrlist)) + else: + self.attrlist = None self.cache_period = cache_period self.last_timeslice = 0 self.cache = {} def __str__(self): return ('base_dn=%(base_dn)s, filter_tmpl=%(filter_tmpl)s, ' - 'scope=%(scope)s, cache_period=%(cache_period)s' % + 'scope=%(scope)s, attrlist=%(attrlist)r, ' + 'cache_period=%(cache_period)s' % self.__dict__) def query_cache(self, cache_key): @@ -61,7 +66,8 @@ def execute(self, conn, **kw): cache_key = ( bytes_(self.base_dn % kw, 'utf-8'), self.scope, - bytes_(self.filter_tmpl % kw, 'utf-8') + bytes_(self.filter_tmpl % kw, 'utf-8'), + self.attrlist, ) logger.debug('searching for %r' % (cache_key,)) @@ -159,7 +165,8 @@ def user_groups(self, userdn): return None def ldap_set_login_query(config, base_dn, filter_tmpl, - scope=ldap.SCOPE_ONELEVEL, cache_period=0): + scope=ldap.SCOPE_ONELEVEL, cache_period=0, + attrlist=None): """ Configurator method to set the LDAP login search. ``base_dn`` is the DN at which to begin the search. ``filter_tmpl`` is a string which can be used as an LDAP filter: it should contain the replacement value @@ -179,7 +186,7 @@ def ldap_set_login_query(config, base_dn, filter_tmpl, The registered search must return one and only one value to be considered a valid login. """ - query = _LDAPQuery(base_dn, filter_tmpl, scope, cache_period) + query = _LDAPQuery(base_dn, filter_tmpl, scope, attrlist, cache_period) def register(): config.registry.ldap_login_query = query @@ -193,7 +200,8 @@ def register(): config.action('ldap-set-login-query', register, introspectables=(intr,)) def ldap_set_groups_query(config, base_dn, filter_tmpl, - scope=ldap.SCOPE_SUBTREE, cache_period=0): + scope=ldap.SCOPE_SUBTREE, cache_period=0, + attrlist=('',)): """ Configurator method to set the LDAP groups search. ``base_dn`` is the DN at which to begin the search. ``filter_tmpl`` is a string which can be used as an LDAP filter: it should contain the replacement value @@ -211,7 +219,7 @@ def ldap_set_groups_query(config, base_dn, filter_tmpl, ) """ - query = _LDAPQuery(base_dn, filter_tmpl, scope, cache_period) + query = _LDAPQuery(base_dn, filter_tmpl, scope, attrlist, cache_period) def register(): config.registry.ldap_groups_query = query intr = config.introspectable( diff --git a/pyramid_ldap/tests.py b/pyramid_ldap/tests.py index 72d8b4c..6612048 100644 --- a/pyramid_ldap/tests.py +++ b/pyramid_ldap/tests.py @@ -118,6 +118,7 @@ def test_it_defaults(self): self._callFUT(config, 'dn', 'tmpl') self.assertEqual(config.registry.ldap_groups_query.base_dn, 'dn') self.assertEqual(config.registry.ldap_groups_query.filter_tmpl, 'tmpl') + self.assertEqual(config.registry.ldap_groups_query.attrlist, ('',)) self.assertEqual(config.registry.ldap_groups_query.scope, ldap.SCOPE_SUBTREE) self.assertEqual(config.registry.ldap_groups_query.cache_period, 0) @@ -133,6 +134,7 @@ def test_it_defaults(self): self._callFUT(config, 'dn', 'tmpl') self.assertEqual(config.registry.ldap_login_query.base_dn, 'dn') self.assertEqual(config.registry.ldap_login_query.filter_tmpl, 'tmpl') + self.assertEqual(config.registry.ldap_login_query.attrlist, None) self.assertEqual(config.registry.ldap_login_query.scope, ldap.SCOPE_ONELEVEL) self.assertEqual(config.registry.ldap_login_query.cache_period, 0) @@ -192,7 +194,7 @@ def test_user_groups_execute_raises(self): class Test_LDAPQuery(unittest.TestCase): def _makeOne(self, base_dn, filter_tmpl, scope, cache_period): from pyramid_ldap import _LDAPQuery - return _LDAPQuery(base_dn, filter_tmpl, scope, cache_period) + return _LDAPQuery(base_dn, filter_tmpl, scope, None, cache_period) def test_query_cache_no_rollover(self): inst = self._makeOne(None, None, None, 1) @@ -212,23 +214,23 @@ def test_execute_no_cache_period(self): conn = DummyConnection('abc') result = inst.execute(conn, login='foo') self.assertEqual(result, 'abc') - self.assertEqual(conn.arg, ('foo', None, 'foo')) + self.assertEqual(conn.arg, ('foo', None, 'foo', None)) def test_execute_with_cache_period_miss(self): inst = self._makeOne('%(login)s', '%(login)s', None, 1) conn = DummyConnection('abc') result = inst.execute(conn, login='foo') self.assertEqual(result, 'abc') - self.assertEqual(conn.arg, ('foo', None, 'foo')) + self.assertEqual(conn.arg, ('foo', None, 'foo', None)) def test_execute_with_cache_period_hit(self): inst = self._makeOne('%(login)s', '%(login)s', None, 1) inst.last_timeslice = sys.maxint - inst.cache[('foo', None, 'foo')] = 'def' + inst.cache[('foo', None, 'foo', None)] = 'def' conn = DummyConnection('abc') result = inst.execute(conn, login='foo') self.assertEqual(result, 'def') - + class DummyLDAPConnector(object): def __init__(self, dn, group_list): self.dn = dn