8.2 Iterators and Generators
Estimated time for this notebook: 25 minutes
In Python, anything which can be iterated over is called an iterable:
bowl = {"apple": 5, "banana": 3, "orange": 7}
for fruit in bowl:
    print(fruit.upper())
APPLE
BANANA
ORANGE
Surprisingly often, we want to iterate over something that takes a moderately large amount of memory to store: for example, our map images in the green-graph example.
Our green-graph example involved making an array of all the maps between London and Birmingham. This kept them all in memory at the same time: first we downloaded all the maps, then we counted the green pixels in each of them.
This would NOT work if we used more points: eventually, we would run out of memory. We need to use a generator instead. This chapter looks at iterators and generators in more detail: how they work, when to use them, and how to create our own.
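Here is a rough sketch of the difference. It previews the generator syntax covered later in this chapter. load_map(i) is a hypothetical stand-in for downloading one map image, and it simply builds a list of fake pixel values. The "green" test is arbitrary. The list-based version keeps every map in memory at once. The generator version keeps only one map at a time.

```python
def load_map(i, size=1000):
    """Hypothetical stand-in for downloading one map: returns fake pixel values."""
    return [(i + j) % 3 for j in range(size)]


def count_green(pixels):
    """Hypothetical green test: here a pixel counts as 'green' if it equals 1."""
    return sum(1 for p in pixels if p == 1)


def green_counts_list(n):
    # Every map stays in memory until all of them have been loaded
    maps = [load_map(i) for i in range(n)]
    return [count_green(m) for m in maps]


def green_counts_generator(n):
    # Each map is loaded, counted, then discarded before the next one is loaded
    for i in range(n):
        yield count_green(load_map(i))


print(green_counts_list(3))
print(list(green_counts_generator(3)))
```

Both versions give the same counts. Only the peak memory use differs, because the generator never holds more than one map.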
Iterators
Consider the basic Python range function:
range(10)
range(0, 10)
total = 0
for x in range(int(1e6)):
    total += x
total
499999500000
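To avoid allocating a million integers, range does not store its values. We don’t actually need a million integers at once, just each integer in turn up to a million. So when we loop over a range, Python gets an iterator from it that produces each integer only when it is asked for.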
Because we can get an iterator from it, we say that a range is an iterable. So we can for-loop over it:
for i in range(3):
    print(i)
0
1
2
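The memory saving can be checked with sys.getsizeof from the standard library. The exact numbers depend on the Python version and platform, so none are quoted here. Note also that getsizeof counts only the list's own array of references, not the integer objects it points to, so the true cost of the list is even higher.

```python
import sys

numbers_range = range(10**6)
numbers_list = list(numbers_range)

# A range object only stores its start, stop and step...
print(sys.getsizeof(numbers_range))
# ...while a list stores a reference to every one of its elements
print(sys.getsizeof(numbers_list))
```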
There are two important Python built-in functions for working with iterables. First is iter, which lets us create an iterator from any iterable object.
a = iter(range(3))
Once we have an iterator object, we can pass it to the next function. This moves the iterator forward and gives us its next element:
next(a)
0
next(a)
1
next(a)
2
When we are out of elements, a StopIteration exception is raised:
next(a)
---------------------------------------------------------------------------
StopIteration Traceback (most recent call last)
Cell In[9], line 1
----> 1 next(a)
StopIteration:
This tells Python that the iteration is over. For example, if we are in a for i in range(3) loop, this is how Python knows when to exit the loop.
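Roughly speaking, a for loop calls iter once and then calls next repeatedly until StopIteration is raised. The sketch below spells this out by hand with a hypothetical helper, manual_for. It is a simplified model of the behaviour, not CPython's actual implementation.

```python
def manual_for(iterable, body):
    """Hypothetical helper: call body(item) for each item, mimicking a for loop."""
    iterator = iter(iterable)  # a for loop calls iter() once at the start
    while True:
        try:
            item = next(iterator)  # ask the iterator for its next element
        except StopIteration:
            break  # the iterator is exhausted, so leave the loop
        body(item)


seen = []
manual_for(range(3), seen.append)
print(seen)  # [0, 1, 2]
```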
We can turn an iterable or iterator into a list with the list constructor function:
list(range(5))
[0, 1, 2, 3, 4]
Defining Our Own Iterable
When we write next(a), under the hood Python calls the __next__() method of a. Similarly, iter(a) calls a.__iter__().
We can make our own iterators by defining classes that can be used with the next() and iter() functions: this is the iterator protocol.
Python defines a protocol for each of its concepts, such as sequence, container and iterable. A protocol is a set of methods a class must implement in order to be treated as a member of that concept.
To define an iterator, the methods that must be supported are __next__() and __iter__().
__next__() must update the iterator's state and return the next element, raising StopIteration when there are no elements left.
We’ll see why we need to define __iter__ in a moment.
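To see why __iter__ matters, consider a hypothetical NextOnly class that defines __next__ but not __iter__. Calling next() on it works. A for loop, however, first calls iter() on the object, and that fails with a TypeError.

```python
class NextOnly:
    """Hypothetical class that defines __next__ but not __iter__."""

    def __init__(self, limit):
        self.count = 0
        self.limit = limit

    def __next__(self):
        if self.count >= self.limit:
            raise StopIteration()
        self.count += 1
        return self.count


counter = NextOnly(3)
print(next(counter))  # next() only needs __next__, so this prints 1

try:
    for value in NextOnly(3):
        print(value)
except TypeError as error:
    # The for loop calls iter() first, which needs __iter__
    print("TypeError:", error)
```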
Here is an example of defining a custom iterator class:
class fib_iterator:
    """An iterator over part of the Fibonacci sequence."""

    def __init__(self, limit, seed1=1, seed2=1):
        self.limit = limit
        self.previous = seed1
        self.current = seed2

    def __iter__(self):
        # An iterator returns itself from __iter__
        return self

    def __next__(self):
        # Advance one step along the sequence
        (self.previous, self.current) = (self.current, self.previous + self.current)
        self.limit -= 1
        if self.limit < 0:
            # No elements left: tell the caller the iteration is over
            raise StopIteration()
        return self.current
x = fib_iterator(5)
next(x)
2
next(x)
3
next(x)
5
next(x)
8
for x in fib_iterator(5):
    print(x)
2
3
5
8
13
sum(fib_iterator(1000))
297924218508143360336882819981631900915673130543819759032778173440536722190488904520034508163846345539055096533885943242814978469042830417586260359446115245634668393210192357419233828310479227982326069668668250
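Because fib_iterator returns self from __iter__, an instance can only be traversed once. After it raises StopIteration it stays exhausted. The sketch below repeats the class definition so that it runs on its own:

```python
class fib_iterator:
    """An iterator over part of the Fibonacci sequence (as defined above)."""

    def __init__(self, limit, seed1=1, seed2=1):
        self.limit = limit
        self.previous = seed1
        self.current = seed2

    def __iter__(self):
        return self

    def __next__(self):
        (self.previous, self.current) = (self.current, self.previous + self.current)
        self.limit -= 1
        if self.limit < 0:
            raise StopIteration()
        return self.current


fibs = fib_iterator(5)
print(list(fibs))  # [2, 3, 5, 8, 13]
print(list(fibs))  # []: the same iterator is already exhausted
```

To loop over the values again, create a new fib_iterator.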
A shortcut to iterables: the __iter__ method
In fact, we don’t always have to define both __iter__ and __next__!
If a class just wants to behave like some other iterable when it is iterated over, you can implement __iter__ to return iter(some_other_iterable), without implementing __next__. For example, an image class might want to store some metadata, but behave as if it were just a flat sequence of pixels when being iterated over:
from matplotlib import pyplot as plt
from numpy import array
class MyImage:
    def __init__(self, pixels):
        self.pixels = array(pixels, dtype="uint8")
        self.channels = self.pixels.shape[2]

    def __iter__(self):
        # return an iterator over just the pixel values
        return iter(self.pixels.reshape(-1, self.channels))

    def show(self):
        plt.imshow(self.pixels, interpolation="None")
x = [[[255, 255, 0], [0, 255, 0]], [[0, 0, 255], [255, 255, 255]]]
image = MyImage(x)
%matplotlib inline
image.show()
image.channels
3
from webcolors import rgb_to_name
for pixel in image:
    print(rgb_to_name(pixel))
yellow
lime
blue
white
See how we used image in a for loop, even though it doesn’t satisfy the iterator protocol (we didn’t define both __iter__ and __next__ for it)?
The key here is that we can use any iterable object (like image) in a for loop, not just iterators! Internally, Python creates an iterator from the iterable by calling its __iter__ method, so we don’t need to define a __next__ method explicitly.
The iterator protocol is to implement both __iter__ (which returns the iterator itself) and __next__, while the iterable protocol is to implement __iter__ and return an iterator.
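To make the difference between the two protocols concrete, here is a sketch of a hypothetical Playlist class. Like MyImage, it implements only __iter__. Each call to __iter__ returns a fresh iterator, so the same object can be looped over repeatedly, unlike fib_iterator. The standard library's collections.abc.Iterable and collections.abc.Iterator classes let isinstance check which protocol an object supports.

```python
from collections.abc import Iterable, Iterator


class Playlist:
    """Hypothetical iterable: __iter__ returns a fresh iterator each time."""

    def __init__(self, songs):
        self.songs = list(songs)

    def __iter__(self):
        return iter(self.songs)


playlist = Playlist(["intro", "verse", "outro"])

# Each loop calls __iter__ again, so the playlist can be traversed repeatedly
print(list(playlist))  # ['intro', 'verse', 'outro']
print(list(playlist))  # ['intro', 'verse', 'outro']

print(isinstance(playlist, Iterable))  # True: it defines __iter__
print(isinstance(playlist, Iterator))  # False: it has no __next__
print(isinstance(iter(playlist), Iterator))  # True: a list iterator has both
```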
Generators
There’s a fair amount of “boilerplate” in the above class-based definition of an iterator.
Python provides another way to specify something which meets the iterator protocol: generators.
def my_generator():
    yield 5
    yield 10
x = my_generator()
next(x)
5
next(x)
10
next(x)
---------------------------------------------------------------------------
StopIteration Traceback (most recent call last)
Cell In[26], line 1
----> 1 next(x)
StopIteration:
for a in my_generator():
    print(a)
5
10
sum(my_generator())
15
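A function which contains yield statements is a generator function. Calling it does not run its body straight away. Instead it returns a generator object, which automagically implements both __iter__ and __next__. Each yield hands a value back temporarily, pausing the function rather than ending it.
Each call of next() returns control to the function where it left off.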
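Here is a sketch of this back-and-forth using a hypothetical traced_generator that records each step in a list called log. Creating the generator runs none of its body. Each next() runs the body only up to the following yield, and the remaining code runs only when more values are requested.

```python
log = []


def traced_generator():
    """Hypothetical generator that logs when its body runs."""
    log.append("generator: started")
    yield 1
    log.append("generator: resumed after first yield")
    yield 2
    log.append("generator: finishing")


gen = traced_generator()
log.append("caller: generator created")  # nothing in the body has run yet
log.append(f"caller: got {next(gen)}")
log.append(f"caller: got {next(gen)}")
log.append(f"caller: remaining {list(gen)}")  # runs the body to the end

for entry in log:
    print(entry)
```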
Control passes back-and-forth between the generator and the caller. Our Fibonacci example therefore becomes a function rather than a class.
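def yield_fibs(limit, seed1=1, seed2=1):
    # Same starting state as fib_iterator: seed1 is the first term, seed2 the second
    previous = seed1
    current = seed2
    while limit > 0:
        limit -= 1
        # Advance one step along the sequence, then hand the value to the caller
        previous, current = current, previous + current
        yield current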
We can now use the output of the function like a normal iterable:
sum(yield_fibs(5))
31
for a in yield_fibs(10):
if a % 2 == 0:
print(a)
2
8
34
144
Sometimes we may need to gather all values from a generator into a list, such as before passing them to a function that expects a list:
list(yield_fibs(10))
[2, 3, 5, 8, 13, 21, 34, 55, 89, 144]
plt.plot(list(yield_fibs(20)))
[<matplotlib.lines.Line2D at 0x7fd634e98a30>]
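A generator only computes values when asked, so it does not strictly need a limit argument. The sketch below is a hypothetical variant, all_fibs, that yields forever. Calling list() on it directly would never finish, so it uses itertools.islice from the standard library to take just the first ten values.

```python
from itertools import islice


def all_fibs(seed1=1, seed2=1):
    """Hypothetical generator: yields Fibonacci numbers with no limit."""
    previous, current = seed1, seed2
    while True:
        previous, current = current, previous + current
        yield current


# list(all_fibs()) would run forever; islice stops after 10 values
print(list(islice(all_fibs(), 10)))  # [2, 3, 5, 8, 13, 21, 34, 55, 89, 144]
```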
Supplementary material
The remainder of this page contains an example of the flexibility of the features discussed above. Specifically, it shows how generators and context managers can be combined to create a testing framework like the one previously seen in the course.
Test generators
Earlier in the course we saw a test which loaded its test cases from a YAML file and checked each input against its expected output.
This was nice and concise, but had one flaw: we had just one test covering all the fixtures. So we got just one . in the test output when we ran the tests, and if any case failed, the rest were not run.
We can do a nicer job with a test generator:
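import os

import yaml


def assert_exemplar(fixture):
    # Each fixture holds the keyword arguments for greet plus the expected answer;
    # greet is the function under test from earlier in the course
    answer = fixture.pop("answer")
    assert greet(**fixture) == answer


def test_greeter():
    with open(
        os.path.join(os.path.dirname(__file__), "fixtures", "samples.yaml")
    ) as fixtures_file:
        fixtures = yaml.safe_load(fixtures_file)
        for fixture in fixtures:
            # Yield the check and its argument rather than calling it here,
            # so that the test runner runs each one as a separate test
            yield assert_exemplar, fixture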
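In test runners that support this style of test generator (such as nose), each time a function whose name begins with test_ yields a function and its arguments, it results in another test.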
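Recent versions of pytest no longer collect yield-based test generators: they were deprecated in pytest 3.0 and removed in 4.0. A rough pytest equivalent uses pytest.mark.parametrize so that each fixture becomes its own test. This sketch assumes a hypothetical greet that takes a single name argument, plus an inline FIXTURES list in place of samples.yaml. The real greet and fixtures from earlier in the course may differ.

```python
import pytest


def greet(name):
    """Hypothetical stand-in for the greet function under test."""
    return f"Hello, {name}!"


# Hypothetical fixtures; in practice these could be loaded from samples.yaml
FIXTURES = [
    {"name": "James", "answer": "Hello, James!"},
    {"name": "Ada", "answer": "Hello, Ada!"},
]


@pytest.mark.parametrize("case", FIXTURES)
def test_greeter(case):
    # Copy the dict so that popping "answer" does not change the shared fixtures
    case = dict(case)
    answer = case.pop("answer")
    assert greet(**case) == answer
```

Running pytest on this file should report one test per fixture, so a failing case no longer hides the others.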
Negative test context managers
We have seen this:
from pytest import raises

with raises(AttributeError):
    x = 2
    x.foo()
We can now see how pytest might have implemented this:
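from contextlib import contextmanager


@contextmanager
def reimplement_raises(exception):
    try:
        # The body of the with block runs at this point
        yield
    except exception:
        # The expected exception was raised: swallow it so the test passes
        pass
    else:
        # The body finished without raising: report a test failure
        raise AssertionError(f"Expected {exception.__name__} to be raised, nothing was.")


with reimplement_raises(AttributeError):
    x = 2
    x.foo()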
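To check that this context manager behaves like a negative test, the sketch below repeats its definition and exercises three cases. First, the expected exception is raised and then swallowed. Second, no exception is raised, so the context manager reports a failure. Third, a different exception is raised and propagates unchanged, because the except clause does not match it.

```python
from contextlib import contextmanager


@contextmanager
def reimplement_raises(exception):
    try:
        yield
    except exception:
        pass
    else:
        raise AssertionError(f"Expected {exception.__name__} to be raised, nothing was.")


# 1. The expected exception is swallowed, so execution carries on
with reimplement_raises(ZeroDivisionError):
    1 / 0
print("ZeroDivisionError was caught")

# 2. No exception at all: the context manager reports a failure
try:
    with reimplement_raises(ZeroDivisionError):
        1 / 1
except AssertionError as error:
    print(error)  # Expected ZeroDivisionError to be raised, nothing was.

# 3. A different exception is not caught, so it propagates
try:
    with reimplement_raises(ZeroDivisionError):
        int("not a number")
except ValueError:
    print("ValueError propagated")
```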
Skip test decorators
Some frameworks also implement decorators for skipping tests or dealing with tests that are known to raise exceptions (due to known bugs or limitations). For example:
%%writefile test_skipped.py
import pytest
import sys


@pytest.mark.skipif(sys.version_info < (4, 0), reason="requires python 4")
def test_python_4():
    raise RuntimeError("something went wrong")
Overwriting test_skipped.py
! pytest test_skipped.py
============================= test session starts ==============================
platform linux -- Python 3.8.18, pytest-7.4.4, pluggy-1.5.0
rootdir: /home/runner/work/rse-course/rse-course/module08_advanced_programming_techniques
plugins: cov-4.1.0, anyio-4.4.0, pylama-8.4.1
collecting ...
collected 1 item
test_skipped.py s [100%]
============================== 1 skipped in 0.01s ==============================
%%writefile test_not_skipped.py
import pytest
import sys


@pytest.mark.skipif(sys.version_info < (3, 0), reason="requires python 3")
def test_python_3():
    raise RuntimeError("something went wrong")
Overwriting test_not_skipped.py
! pytest test_not_skipped.py
============================= test session starts ==============================
platform linux -- Python 3.8.18, pytest-7.4.4, pluggy-1.5.0
rootdir: /home/runner/work/rse-course/rse-course/module08_advanced_programming_techniques
plugins: cov-4.1.0, anyio-4.4.0, pylama-8.4.1
collecting ...
collected 1 item
test_not_skipped.py F [100%]
=================================== FAILURES ===================================
________________________________ test_python_3 _________________________________
@pytest.mark.skipif(sys.version_info < (3, 0), reason="requires python 3")
def test_python_3():
> raise RuntimeError("something went wrong")
E RuntimeError: something went wrong
test_not_skipped.py:7: RuntimeError
=========================== short test summary info ============================
FAILED test_not_skipped.py::test_python_3 - RuntimeError: something went wrong
============================== 1 failed in 0.10s ===============================
We could reimplement this ourselves now too:
def homemade_skip_decorator(skip):
    def wrap_function(func):
        if skip:
            # if the test should be skipped, return a function
            # that just prints a message
            def do_nothing(*args):
                print("test was skipped")

            return do_nothing
        # otherwise use the original function as normal
        return func

    return wrap_function
@homemade_skip_decorator(3.9 < 4.0)
def test_skipped():
    raise RuntimeError("This test is skipped")


test_skipped()
test was skipped
@homemade_skip_decorator(3.9 < 3.0)
def test_runs():
    raise RuntimeError("This test is run")


test_runs()
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[53], line 6
1 @homemade_skip_decorator(3.9 < 3.0)
2 def test_runs():
3 raise RuntimeError("This test is run")
----> 6 test_runs()
Cell In[53], line 3, in test_runs()
1 @homemade_skip_decorator(3.9 < 3.0)
2 def test_runs():
----> 3 raise RuntimeError("This test is run")
RuntimeError: This test is run
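The homemade decorator can be extended in the spirit of pytest.mark.skipif. The sketch below is a hypothetical skip_if(condition, reason) decorator, not pytest's implementation. It checks the real interpreter version via sys.version_info and reports why a test was skipped. It also uses functools.wraps so that the stub keeps the original test's name, which is what a test runner would display.

```python
import functools
import sys


def skip_if(condition, reason):
    """Hypothetical decorator: replace the test with a stub when condition is true."""

    def wrap_function(func):
        if not condition:
            return func  # run the test as normal

        @functools.wraps(func)  # keep the original name and docstring
        def skipped(*args, **kwargs):
            print(f"{func.__name__} was skipped: {reason}")

        return skipped

    return wrap_function


@skip_if(sys.version_info < (4, 0), reason="requires python 4")
def test_python_4():
    raise RuntimeError("something went wrong")


test_python_4()  # prints: test_python_4 was skipped: requires python 4
```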