Skip to content
Snippets Groups Projects
Commit b0d4b1c0 authored by Pierre Penninckx's avatar Pierre Penninckx Committed by GitHub
Browse files

Merge pull request #126 from PyCQA/feature/find_iter

Issue #117: Fix bug with finding comment node
parents ee74b0da 2866320e
No related branches found
No related tags found
No related merge requests found
Loading
Loading
@@ -326,11 +326,13 @@ class NodeList(UserList, GenericNodesUtils):
def __setitem__(self, key, value):
self.data[key] = self._convert_input_to_node_object(value, parent=self.parent, on_attribute=self.on_attribute)
 
def find_iter(self, identifier, *args, **kwargs):
for node in self.data:
for matched_node in node.find_iter(identifier, *args, **kwargs):
yield matched_node
def find_all(self, identifier, *args, **kwargs):
to_return = NodeList([])
for i in self.data:
to_return += i.find_all(identifier, *args, **kwargs)
return to_return
return NodeList(list(self.find_iter(identifier, *args, **kwargs)))
 
findAll = find_all
__call__ = find_all
Loading
Loading
@@ -668,40 +670,6 @@ class Node(GenericNodesUtils):
 
return in_list
 
def find(self, identifier, *args, **kwargs):
if "recursive" in kwargs:
recursive = kwargs["recursive"]
kwargs = kwargs.copy()
del kwargs["recursive"]
else:
recursive = True
if self._node_match_query(self, identifier, *args, **kwargs):
return self
if not recursive:
return None
for kind, key, _ in filter(lambda x: x[0] in ("list", "key"), self._render()):
if kind == "key":
i = getattr(self, key)
if not i:
continue
found = i.find(identifier, *args, **kwargs)
if found is not None:
return found
elif kind == "list":
attr = getattr(self, key).node_list if isinstance(getattr(self, key), ProxyList) else getattr(self, key)
for i in attr:
found = i.find(identifier, *args, **kwargs)
if found is not None:
return found
else:
raise Exception()
def __getattr__(self, key):
if key.endswith("_") and key[:-1] in self._dict_keys + self._list_keys + self._str_keys:
return getattr(self, key[:-1])
Loading
Loading
@@ -762,8 +730,7 @@ class Node(GenericNodesUtils):
else:
raise AttributeError("__delitem__")
 
def find_all(self, identifier, *args, **kwargs):
to_return = NodeList([])
def find_iter(self, identifier, *args, **kwargs):
if "recursive" in kwargs:
recursive = kwargs["recursive"]
kwargs = kwargs.copy()
Loading
Loading
@@ -772,33 +739,29 @@ class Node(GenericNodesUtils):
recursive = True
 
if self._node_match_query(self, identifier, *args, **kwargs):
to_return.append(self)
yield self
if recursive:
for (kind, key, _) in self._render():
if kind == "key":
node = getattr(self, key)
if not isinstance(node, Node):
continue
for matched_node in node.find_iter(identifier, *args, **kwargs):
yield matched_node
elif kind in ("list", "formatting"):
nodes = getattr(self, key)
if isinstance(nodes, ProxyList):
nodes = nodes.node_list
for node in nodes:
for matched_node in node.find_iter(identifier, *args, **kwargs):
yield matched_node
 
if not recursive:
return to_return
for kind, key, _ in filter(
lambda x: x[0] in ("list", "formatting") or (x[0] == "key" and isinstance(getattr(self, x[1]), Node)),
self._render()):
if kind == "key":
i = getattr(self, key)
if not i:
continue
to_return += i.find_all(identifier, *args, **kwargs)
elif kind in ("list", "formatting"):
if isinstance(getattr(self, key), ProxyList):
for i in getattr(self, key).node_list:
to_return += i.find_all(identifier, *args, **kwargs)
else:
for i in getattr(self, key):
to_return += i.find_all(identifier, *args, **kwargs)
else:
raise Exception()
def find(self, identifier, *args, **kwargs):
return next(self.find_iter(identifier, *args, **kwargs), None)
 
return to_return
def find_all(self, identifier, *args, **kwargs):
return NodeList(list(self.find_iter(identifier, *args, **kwargs)))
 
findAll = find_all
__call__ = find_all
Loading
Loading
@@ -888,6 +851,7 @@ class Node(GenericNodesUtils):
 
def _get_helpers(self):
not_helpers = set([
'at',
'copy',
'decrease_indentation',
'dumps',
Loading
Loading
@@ -897,25 +861,33 @@ class Node(GenericNodesUtils):
'findAll',
'find_by_path',
'find_by_position',
'at',
'find_iter',
'from_fst',
'fst',
'fst',
'generate_identifiers',
'get_absolute_bounding_box_of_attribute',
'get_indentation_node',
'get_indentation_node',
'has_render_key',
'help',
'help',
'increase_indentation',
'indentation_node_is_direct',
'indentation_node_is_direct',
'index_on_parent',
'index_on_parent_raw',
'insert_after',
'insert_before',
'next_generator',
'next_generator',
'parent_find',
'parent_find',
'parse_code_block',
'parse_decorators',
'path',
'path',
'previous_generator',
'previous_generator',
'replace',
'to_python',
Loading
Loading
Loading
Loading
@@ -958,6 +958,13 @@ def test_default_test_value_find_all():
red = RedBaron("badger\nmushroom\nsnake")
assert red("name", "snake") == red("name", value="snake")
 
def test_find_comment_node():
red = RedBaron("def f():\n #a\n pass\n#b")
assert red.find('comment').value == '#a'
def test_find_all_comment_nodes():
red = RedBaron("def f():\n #a\n pass\n#b")
assert [x.value for x in red.find_all('comment')] == ['#a', '#b']
 
def test_default_test_value_find_def():
red = RedBaron("def a(): pass\ndef b(): pass")
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment