Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add enter parameter to iterutils.research to allow traversing custom data types #372

Merged
merged 1 commit into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 75 additions & 34 deletions boltons/iterutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,15 +165,15 @@ def split_iter(src, sep=None, maxsplit=None):
sep_func = sep
elif not is_scalar(sep):
sep = frozenset(sep)
sep_func = lambda x: x in sep
def sep_func(x): return x in sep
else:
sep_func = lambda x: x == sep
def sep_func(x): return x == sep

cur_group = []
split_count = 0
for s in src:
if maxsplit is not None and split_count >= maxsplit:
sep_func = lambda x: False
def sep_func(x): return False
if sep_func(s):
if sep is None and not cur_group:
# If sep is none, str.split() "groups" separators
Expand Down Expand Up @@ -229,7 +229,7 @@ def rstrip(iterable, strip_value=None):
['Foo', 'Bar']

"""
return list(rstrip_iter(iterable,strip_value))
return list(rstrip_iter(iterable, strip_value))


def rstrip_iter(iterable, strip_value=None):
Expand All @@ -253,7 +253,7 @@ def rstrip_iter(iterable, strip_value=None):
else:
broken = True
break
if not broken: # Return to caller here because the end of the
if not broken: # Return to caller here because the end of the
return # iterator has been reached
yield from cache
yield i
Expand All @@ -268,10 +268,10 @@ def strip(iterable, strip_value=None):
['Foo', 'Bar', 'Bam']

"""
return list(strip_iter(iterable,strip_value))
return list(strip_iter(iterable, strip_value))


def strip_iter(iterable,strip_value=None):
def strip_iter(iterable, strip_value=None):
"""Strips values from the beginning and end of an iterable. Stripped items
will match the value of the argument strip_value. Functionality is
analogous to that of the method str.strip. Returns a generator.
Expand All @@ -280,7 +280,7 @@ def strip_iter(iterable,strip_value=None):
['Foo', 'Bar', 'Bam']

"""
return rstrip_iter(lstrip_iter(iterable,strip_value),strip_value)
return rstrip_iter(lstrip_iter(iterable, strip_value), strip_value)


def chunked(src, size, count=None, **kw):
Expand Down Expand Up @@ -340,11 +340,12 @@ def chunked_iter(src, size, **kw):
raise ValueError('got unexpected keyword arguments: %r' % kw.keys())
if not src:
return
postprocess = lambda chk: chk

def postprocess(chk): return chk
if isinstance(src, (str, bytes)):
postprocess = lambda chk, _sep=type(src)(): _sep.join(chk)
def postprocess(chk, _sep=type(src)()): return _sep.join(chk)
if isinstance(src, bytes):
postprocess = lambda chk: bytes(chk)
def postprocess(chk): return bytes(chk)
src_iter = iter(src)
while True:
cur_chunk = list(itertools.islice(src_iter, size))
Expand Down Expand Up @@ -385,15 +386,19 @@ def chunk_ranges(input_size, chunk_size, input_offset=0, overlap_size=0, align=F
>>> list(chunk_ranges(input_offset=3, input_size=15, chunk_size=5, overlap_size=1, align=True))
[(3, 5), (4, 9), (8, 13), (12, 17), (16, 18)]
"""
input_size = _validate_positive_int(input_size, 'input_size', strictly_positive=False)
input_size = _validate_positive_int(
input_size, 'input_size', strictly_positive=False)
chunk_size = _validate_positive_int(chunk_size, 'chunk_size')
input_offset = _validate_positive_int(input_offset, 'input_offset', strictly_positive=False)
overlap_size = _validate_positive_int(overlap_size, 'overlap_size', strictly_positive=False)
input_offset = _validate_positive_int(
input_offset, 'input_offset', strictly_positive=False)
overlap_size = _validate_positive_int(
overlap_size, 'overlap_size', strictly_positive=False)

input_stop = input_offset + input_size

if align:
initial_chunk_len = chunk_size - input_offset % (chunk_size - overlap_size)
initial_chunk_len = chunk_size - \
input_offset % (chunk_size - overlap_size)
if initial_chunk_len != overlap_size:
yield (input_offset, min(input_offset + initial_chunk_len, input_stop))
if input_offset + initial_chunk_len >= input_stop:
Expand Down Expand Up @@ -479,7 +484,7 @@ def windowed_iter(src, size, fill=_UNSET):

With *fill* set, the iterator always yields a number of windows
equal to the length of the *src* iterable.

>>> windowed(range(4), 3, fill=None)
[(0, 1, 2), (1, 2, 3), (2, 3, None), (3, None, None)]

Expand All @@ -495,17 +500,16 @@ def windowed_iter(src, size, fill=_UNSET):
except StopIteration:
return zip([])
return zip(*tees)

for i, t in enumerate(tees):
for _ in range(i):
for _ in range(i):
try:
next(t)
except StopIteration:
continue
return zip_longest(*tees, fillvalue=fill)



def xfrange(stop, start=None, step=1.0):
"""Same as :func:`frange`, but generator-based instead of returning a
list.
Expand Down Expand Up @@ -726,21 +730,21 @@ def bucketize(src, key=bool, value_transform=None, key_filter=None):
src = zip(key, src)

if isinstance(key, str):
key_func = lambda x: getattr(x, key, x)
def key_func(x): return getattr(x, key, x)
elif callable(key):
key_func = key
elif isinstance(key, list):
key_func = lambda x: x[0]
def key_func(x): return x[0]
else:
raise TypeError('expected key to be callable or a string or a list')

if value_transform is None:
value_transform = lambda x: x
def value_transform(x): return x
if not callable(value_transform):
raise TypeError('expected callable value transform function')
if isinstance(key, list):
f = value_transform
value_transform=lambda x: f(x[1])
def value_transform(x): return f(x[1])

ret = {}
for val in src:
Expand Down Expand Up @@ -807,11 +811,11 @@ def unique_iter(src, key=None):
if not is_iterable(src):
raise TypeError('expected an iterable, not %r' % type(src))
if key is None:
key_func = lambda x: x
def key_func(x): return x
elif callable(key):
key_func = key
elif isinstance(key, str):
key_func = lambda x: getattr(x, key, x)
def key_func(x): return getattr(x, key, x)
else:
raise TypeError('"key" expected a string or callable, not %r' % key)
seen = set()
Expand Down Expand Up @@ -862,7 +866,7 @@ def redundant(src, key=None, groups=False):
elif callable(key):
key_func = key
elif isinstance(key, (str, bytes)):
key_func = lambda x: getattr(x, key, x)
def key_func(x): return getattr(x, key, x)
else:
raise TypeError('"key" expected a string or callable, not %r' % key)
seen = {} # key to first seen item
Expand Down Expand Up @@ -964,6 +968,7 @@ def flatten_iter(iterable):
else:
yield item


def flatten(iterable):
"""``flatten()`` returns a collapsed list of all the elements from
*iterable* while collapsing any nested iterables.
Expand Down Expand Up @@ -1006,6 +1011,7 @@ def default_visit(path, key, value):
# print('visit(%r, %r, %r)' % (path, key, value))
return key, value


# enable the extreme: monkeypatching iterutils with a different default_visit
_orig_default_visit = default_visit

Expand Down Expand Up @@ -1128,6 +1134,9 @@ def remap(root, visit=default_visit, enter=default_enter, exit=default_exit,
callable. When set to ``False``, remap ignores any errors
raised by the *visit* callback. Items causing exceptions
are kept. See examples for more details.
trace (bool or tuple): Pass ``trace=True`` to print out the entire
traversal, or pass a tuple of event names (``'visit'``, ``'enter'``,
and/or ``'exit'``) to print only the selected events.

remap is designed to cover the majority of cases with just the
*visit* callable. While passing in multiple callables is very
Expand Down Expand Up @@ -1156,6 +1165,15 @@ def remap(root, visit=default_visit, enter=default_enter, exit=default_exit,
if not callable(exit):
raise TypeError('exit expected callable, not: %r' % exit)
reraise_visit = kwargs.pop('reraise_visit', True)
trace = kwargs.pop('trace', ())
if trace is True:
trace = ('visit', 'enter', 'exit')
elif isinstance(trace, str):
trace = (trace,)
if not isinstance(trace, (tuple, list, set)):
raise TypeError('trace expected tuple of event names, not: %r' % trace)
trace_enter, trace_exit, trace_visit = 'enter' in trace, 'exit' in trace, 'visit' in trace

if kwargs:
raise TypeError('unexpected keyword arguments: %r' % kwargs.keys())

Expand All @@ -1168,14 +1186,23 @@ def remap(root, visit=default_visit, enter=default_enter, exit=default_exit,
key, new_parent, old_parent = value
id_value = id(old_parent)
path, new_items = new_items_stack.pop()
if trace_exit:
print(' .. remap exit:', path, '-', key, '-',
old_parent, '-', new_parent, '-', new_items)
value = exit(path, key, old_parent, new_parent, new_items)
if trace_exit:
print(' .. remap exit result:', value)
registry[id_value] = value
if not new_items_stack:
continue
elif id_value in registry:
value = registry[id_value]
else:
if trace_enter:
print(' .. remap enter:', path, '-', key, '-', value)
res = enter(path, key, value)
if trace_enter:
print(' .. remap enter result:', res)
try:
new_parent, new_items = res
except TypeError:
Expand All @@ -1191,21 +1218,29 @@ def remap(root, visit=default_visit, enter=default_enter, exit=default_exit,
stack.append((_REMAP_EXIT, (key, new_parent, value)))
if new_items:
stack.extend(reversed(list(new_items)))
if trace_enter:
print(' .. remap stack size now:', len(stack))
continue
if visit is _orig_default_visit:
# avoid function call overhead by inlining identity operation
visited_item = (key, value)
else:
try:
if trace_visit:
print(' .. remap visit:', path, '-', key, '-', value)
visited_item = visit(path, key, value)
except Exception:
if reraise_visit:
raise
visited_item = True
if visited_item is False:
if trace_visit:
print(' .. remap visit result: <drop>')
continue # drop
elif visited_item is True:
visited_item = (key, value)
if trace_visit:
print(' .. remap visit result:', visited_item)
# TODO: typecheck?
# raise TypeError('expected (key, value) from visit(),'
# ' not: %r' % visited_item)
Expand All @@ -1221,6 +1256,7 @@ class PathAccessError(KeyError, IndexError, TypeError):
representing what can occur when looking up a path in a nested
object.
"""

def __init__(self, exc, seg, path):
self.exc = exc
self.seg = seg
Expand Down Expand Up @@ -1296,7 +1332,7 @@ def get_path(root, path, default=_UNSET):
return cur


def research(root, query=lambda p, k, v: True, reraise=False):
def research(root, query=lambda p, k, v: True, reraise=False, enter=default_enter):
"""The :func:`research` function uses :func:`remap` to recurse over
any data nested in *root*, and find values which match a given
criterion, specified by the *query* callable.
Expand Down Expand Up @@ -1343,16 +1379,16 @@ def research(root, query=lambda p, k, v: True, reraise=False):
if not callable(query):
raise TypeError('query expected callable, not: %r' % query)

def enter(path, key, value):
def _enter(path, key, value):
try:
if query(path, key, value):
ret.append((path + (key,), value))
except Exception:
if reraise:
raise
return default_enter(path, key, value)
return enter(path, key, value)

remap(root, enter=enter)
remap(root, enter=_enter)
return ret


Expand Down Expand Up @@ -1383,6 +1419,7 @@ class GUIDerator:
detect a fork on next iteration and reseed accordingly.

"""

def __init__(self, size=24):
self.size = size
if size < 20 or size > 36:
Expand Down Expand Up @@ -1495,13 +1532,16 @@ def soft_sorted(iterable, first=None, last=None, key=None, reverse=False):
last = last or []
key = key or (lambda x: x)
seq = list(iterable)
other = [x for x in seq if not ((first and key(x) in first) or (last and key(x) in last))]
other = [x for x in seq if not (
(first and key(x) in first) or (last and key(x) in last))]
other.sort(key=key, reverse=reverse)

if first:
first = sorted([x for x in seq if key(x) in first], key=lambda x: first.index(key(x)))
first = sorted([x for x in seq if key(x) in first],
key=lambda x: first.index(key(x)))
if last:
last = sorted([x for x in seq if key(x) in last], key=lambda x: last.index(key(x)))
last = sorted([x for x in seq if key(x) in last],
key=lambda x: last.index(key(x)))
return first + other + last


Expand Down Expand Up @@ -1536,7 +1576,7 @@ def __lt__(self, other):
ret = obj < other
except TypeError:
ret = ((type(obj).__name__, id(type(obj)), obj)
< (type(other).__name__, id(type(other)), other))
< (type(other).__name__, id(type(other)), other))
return ret

if key is not None and not callable(key):
Expand All @@ -1545,6 +1585,7 @@ def __lt__(self, other):

return sorted(iterable, key=_Wrapper, reverse=reverse)


"""
May actually be faster to do an isinstance check for a str path

Expand Down
22 changes: 22 additions & 0 deletions tests/test_iterutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,28 @@ def broken_query(p, k, v):
assert research(root, broken_query) == []


def test_research_custom_enter():
    """Check that ``research`` honors a caller-supplied *enter* callable.

    The root is a ``SimpleNamespace`` tree. With the default enter the
    call is expected to raise ``TypeError``; with a custom enter that
    exposes each namespace's attributes as items, the matching string
    values are found. See #368.
    """
    from types import SimpleNamespace

    tree = SimpleNamespace(a='a', b='b', c=SimpleNamespace(aa='aa'))

    def starts_with_a(path, key, value):
        return value.startswith('a')

    def namespace_enter(path, key, value):
        # Anything that is not a namespace is handled by the stock enter.
        if not isinstance(value, SimpleNamespace):
            return default_enter(path, key, value)
        # Traverse a namespace through its attribute name/value pairs.
        return [], value.__dict__.items()

    with pytest.raises(TypeError):
        research(tree, starts_with_a)

    expected = [(('a',), 'a'), (('c', 'aa'), 'aa')]
    assert research(tree, starts_with_a, enter=namespace_enter) == expected



def test_backoff_basic():
from boltons.iterutils import backoff

Expand Down
Loading