Skip to content
Snippets Groups Projects
Commit 8a8e1a4e authored by Bryce Guinta's avatar Bryce Guinta
Browse files

Fix contextmanager transform for nested contextmanagers

Close PyCQA/pylint#1746
parent ffb98c4b
No related branches found
No related tags found
No related merge requests found
Loading
Loading
@@ -18,6 +18,11 @@ Change log for the astroid package (used to be astng)
 
Close PyCQA/pylint#1884
 
* Fix ``contextlib.contextmanager`` inference for nested
context managers
Close #1699
 
2018-01-23 -- 1.6.1
 
Loading
Loading
Loading
Loading
@@ -445,7 +445,10 @@ def _infer_context_manager(self, mgr, context):
# Get the first yield point. If it has multiple yields,
# then a RuntimeError will be raised.
# TODO(cpopa): Handle flows.
yield_point = next(func.nodes_of_class(nodes.Yield), None)
possible_yield_points = func.nodes_of_class(nodes.Yield)
# Ignore yields in nested functions
yield_point = next((node for node in possible_yield_points
if node.scope() == func), None)
if yield_point:
if not yield_point.value:
# TODO(cpopa): an empty yield. Should be wrapped to Const.
Loading
Loading
Loading
Loading
@@ -2154,6 +2154,36 @@ class InferenceTest(resources.SysPathSetup, unittest.TestCase):
self.assertRaises(InferenceError, next, module['other_decorators'].infer())
self.assertRaises(InferenceError, next, module['no_yield'].infer())
 
def test_nested_contextmanager(self):
"""Make sure contextmanager works with nested functions
Previously contextmanager would retrieve
the first yield instead of the yield in the
proper scope
Fixes https://github.com/PyCQA/pylint/issues/1746
"""
code = """
from contextlib import contextmanager
@contextmanager
def outer():
@contextmanager
def inner():
yield 2
yield inner
with outer() as ctx:
ctx #@
with ctx() as val:
val #@
"""
context_node, value_node = extract_node(code)
value = next(value_node.infer())
context = next(context_node.infer())
assert isinstance(context, nodes.FunctionDef)
assert isinstance(value, nodes.Const)
def test_unary_op_leaks_stop_iteration(self):
node = extract_node('+[] #@')
self.assertEqual(util.Uninferable, next(node.infer()))
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment