8.2 Iterators and Generators

Estimated time for this notebook: 25 minutes

In Python, anything which can be iterated over is called an iterable:

bowl = {"apple": 5, "banana": 3, "orange": 7}

for fruit in bowl:
    print(fruit.upper())
APPLE
BANANA
ORANGE

Surprisingly often, we want to iterate over something that takes a moderately large amount of memory to store, for example the map images in our green-graph example.

Our green-graph example involved making an array of all the maps between London and Birmingham. This kept them all in memory at the same time: first we downloaded all the maps, then we counted the green pixels in each of them.

This would NOT work if we used more points: eventually, we would run out of memory. We need to use a generator instead. This chapter will look at iterators and generators in more detail: how they work, when to use them, and how to create our own.

Iterators

Consider Python’s built-in range function:

range(10)
range(0, 10)
total = 0
for x in range(int(1e6)):
    total += x

total
499999500000

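In order to avoid allocating a million integers, range doesn’t store its values at all: it only needs its start, stop and step, and a for loop gets each value in turn from an iterator over the range.

We don’t actually need a million integers at once, just each integer in turn up to a million.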
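To make the memory saving concrete, here is a short check using the standard library’s sys.getsizeof. It reports the size of an object itself, not of the objects it refers to, so it understates the list’s real footprint; the exact numbers depend on the Python version and platform, and the variable names are only illustrative.

```python
import sys

lazy = range(10**6)          # stores only start, stop and step
eager = list(range(10**6))   # holds a reference to every one of the million integers

print(sys.getsizeof(lazy))   # a few dozen bytes, whatever the length of the range
print(sys.getsizeof(eager))  # megabytes, just for the list's array of references

# Both produce the same values when iterated over
print(sum(lazy) == sum(eager))  # True
```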

Because we can get an iterator from it, we say that a range is an iterable.

So we can for-loop over it:

for i in range(3):
    print(i)
0
1
2

There are two important Python built-in functions for working with iterables. The first is iter, which creates an iterator from any iterable object:

a = iter(range(3))

Once we have an iterator object, we can pass it to the next function, which advances the iterator and returns its next element:

next(a)
0
next(a)
1
next(a)
2

When we are out of elements, a StopIteration exception is raised:

next(a)
---------------------------------------------------------------------------
StopIteration                             Traceback (most recent call last)
Cell In[9], line 1
----> 1 next(a)

StopIteration: 

This tells Python that the iteration is over. For example, in a for i in range(3) loop, StopIteration is how Python knows when to exit the loop.
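Roughly speaking, a for loop calls iter once and then calls next repeatedly until StopIteration is raised. The sketch below spells this out by hand with a while loop; it illustrates the idea rather than how the interpreter is literally implemented.

```python
collected = []

iterator = iter(range(3))   # what `for i in range(3):` does first
while True:
    try:
        i = next(iterator)  # fetch the next element
    except StopIteration:
        break               # the iterator is exhausted, so leave the loop
    collected.append(i)     # the body of the loop

print(collected)  # [0, 1, 2]
```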

We can turn an iterable or iterator into a list with the list constructor function:

list(range(5))
[0, 1, 2, 3, 4]
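One consequence worth knowing: an iterator can only be consumed once. list() keeps calling next() until StopIteration, so it uses up whatever the iterator had left, and converting the same iterator again gives an empty list.

```python
numbers = iter(range(5))

print(next(numbers))  # 0
print(list(numbers))  # [1, 2, 3, 4]: only the elements that were left
print(list(numbers))  # []: the iterator is now exhausted
```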

Defining Our Own Iterable

When we write next(a), under the hood Python tries to call the __next__() method of a. Similarly, iter(a) calls a.__iter__().
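We can see this correspondence by calling the special methods directly on an iterator over a range. This is only for illustration; in normal code you would use the built-in functions.

```python
r = range(3)
a = iter(r)              # calls r.__iter__(), which returns a new iterator object

print(next(a))           # 0: next(a) calls a.__next__()
print(a.__next__())      # 1: the same operation, spelled out
print(iter(a) is a)      # True: an iterator's __iter__ returns the iterator itself
print(type(a).__name__)  # range_iterator
```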

We can make our own iterators by defining classes that can be used with the next() and iter() functions: this is the iterator protocol.

For many concepts in Python, such as sequence, container and iterable, the language defines a protocol: a set of methods a class must implement in order to be treated as a member of that concept.

To define an iterator, the methods that must be supported are __next__() and __iter__().

__next__() must advance the iterator and return its next value, raising StopIteration when there are no values left.

We’ll see why we need to define __iter__ in a moment.

Here is an example of defining a custom iterator class:

class fib_iterator:
    """An iterator over part of the Fibonacci sequence."""

    def __init__(self, limit, seed1=1, seed2=1):
        self.limit = limit
        self.previous = seed1
        self.current = seed2

    def __iter__(self):
        return self

    def __next__(self):
        (self.previous, self.current) = (self.current, self.previous + self.current)
        self.limit -= 1
        if self.limit < 0:
            raise StopIteration()
        return self.current
x = fib_iterator(5)
next(x)
2
next(x)
3
next(x)
5
next(x)
8
for x in fib_iterator(5):
    print(x)
2
3
5
8
13
sum(fib_iterator(1000))
297924218508143360336882819981631900915673130543819759032778173440536722190488904520034508163846345539055096533885943242814978469042830417586260359446115245634668393210192357419233828310479227982326069668668250
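fib_iterator defines __iter__ as well as __next__. To see why the second method matters, here is a hypothetical class with only __next__: next() works on it, but a for loop does not, because the loop first calls iter() on the object.

```python
class NextOnly:
    """A hypothetical class that defines __next__ but not __iter__."""

    def __init__(self, limit):
        self.limit = limit

    def __next__(self):
        self.limit -= 1
        if self.limit < 0:
            raise StopIteration()
        return self.limit


counter = NextOnly(3)
print(next(counter))  # 2: next() only needs __next__

try:
    for value in NextOnly(3):  # the for loop calls iter() first...
        print(value)
except TypeError as error:
    print("TypeError:", error)  # ...which fails, because there is no __iter__
```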

A shortcut to iterables: the __iter__ method

In fact, we don’t always have to define both __iter__ and __next__!

If a class just wants to behave like some other iterable when it is iterated over, it can implement only __iter__ and return iter(some_other_iterable), without implementing __next__. For example, an image class might store some metadata, but behave like a flat array of pixels when it is iterated over:

from matplotlib import pyplot as plt
from numpy import array


class MyImage:
    def __init__(self, pixels):
        self.pixels = array(pixels, dtype="uint8")
        self.channels = self.pixels.shape[2]

    def __iter__(self):
        # return an iterator over just the pixel values
        return iter(self.pixels.reshape(-1, self.channels))

    def show(self):
        plt.imshow(self.pixels, interpolation="None")


x = [[[255, 255, 0], [0, 255, 0]], [[0, 0, 255], [255, 255, 255]]]
image = MyImage(x)
%matplotlib inline
image.show()
[Figure: the 2-by-2 image, with yellow and lime pixels in the top row and blue and white pixels in the bottom row]
image.channels
3
from webcolors import rgb_to_name

for pixel in image:
    print(rgb_to_name(pixel))
yellow
lime
blue
white

See how we used image in a for loop, even though it doesn’t satisfy the iterator protocol (we didn’t define both __iter__ and __next__ for it)?

The key here is that we can use any iterable object (like image) in a for loop, not just iterators! Internally, Python creates an iterator from the iterable by calling its __iter__ method, so we don’t need to define a __next__ method ourselves.

The iterator protocol is to implement both __iter__ and __next__, while the iterable protocol is to implement __iter__ and return an iterator.
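The difference between the two protocols shows up when we loop over the same object twice. In this sketch (the Countdown class is purely illustrative), the iterable hands out a fresh iterator from each call to __iter__, so it can be looped over repeatedly, while an iterator obtained from it is used up after one pass.

```python
class Countdown:
    """An illustrative iterable: __iter__ returns a new iterator every time."""

    def __init__(self, start):
        self.start = start

    def __iter__(self):
        return iter(range(self.start, 0, -1))


countdown = Countdown(3)
print(list(countdown))  # [3, 2, 1]
print(list(countdown))  # [3, 2, 1]: a fresh iterator for each pass

one_pass = iter(countdown)
print(list(one_pass))   # [3, 2, 1]
print(list(one_pass))   # []: this iterator is already exhausted
```

This is also why a fib_iterator object can only be looped over once: its __iter__ returns self, so every loop shares the same, eventually exhausted, state.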

Generators

There’s a fair amount of “boilerplate” in the class-based definition of an iterator above.

Python provides another way to specify something which meets the iterator protocol: generators.

def my_generator():
    yield 5
    yield 10


x = my_generator()
next(x)
5
next(x)
10
next(x)
---------------------------------------------------------------------------
StopIteration                             Traceback (most recent call last)
Cell In[26], line 1
----> 1 next(x)

StopIteration: 
for a in my_generator():
    print(a)
5
10
sum(my_generator())
15

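A function which contains yield statements is a generator function. Calling it doesn’t run its body straight away: it returns a generator object, which automagically implements __next__ (and __iter__). Each yield hands a value back to the caller and pauses the function, rather than ending it as return would.

Each call of next() resumes the function where it left off and runs it until the next yield.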
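To watch control passing back and forth, the sketch below records each step of a small generator in a list (the names are illustrative). It also uses the optional second argument of next(): a default value that is returned instead of raising StopIteration once the generator is finished.

```python
events = []


def traced_generator():
    """An illustrative generator that records when each part of its body runs."""
    events.append("started")
    yield 1
    events.append("resumed")
    yield 2
    events.append("finished")


gen = traced_generator()
print(events)             # []: calling the function has not run any of its body yet
print(next(gen))          # 1
print(events)             # ['started']
print(next(gen))          # 2
print(events)             # ['started', 'resumed']
print(next(gen, "done"))  # done: the default, returned instead of raising StopIteration
print(events)             # ['started', 'resumed', 'finished']
```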

Control passes back and forth between the generator and the caller. Our Fibonacci example therefore becomes a function rather than a class.

def yield_fibs(limit, seed1=1, seed2=1):
    # use the seeds in the same order as fib_iterator
    previous = seed1
    current = seed2

    while limit > 0:
        limit -= 1
        previous, current = current, previous + current
        yield current

We can now use the output of the function like a normal iterable:

sum(yield_fibs(5))
31
for a in yield_fibs(10):
    if a % 2 == 0:
        print(a)
2
8
34
144

Sometimes we may need to gather all values from a generator into a list, such as before passing them to a function that expects a list:

list(yield_fibs(10))
[2, 3, 5, 8, 13, 21, 34, 55, 89, 144]
plt.plot(list(yield_fibs(20)))
[<matplotlib.lines.Line2D at 0x7fd634e98a30>]
[Figure: line plot of the first 20 values from yield_fibs]
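Returning to the memory problem from the green-graph example: with a generator, each map is only produced when the loop asks for it, so we never need a list holding every map at once. This is a minimal sketch under stated assumptions: fake_map and generate_maps are hypothetical stand-ins (no real downloading happens), and each “map” is just a list of pixel labels.

```python
def fake_map(index, size=10_000):
    """Hypothetical stand-in for downloading one map.

    Returns a list of pixel labels in which the first index * 1000 are "green".
    """
    green = min(index * 1000, size)
    return ["green"] * green + ["grey"] * (size - green)


def generate_maps(count):
    """Yield maps one at a time instead of building a list of all of them."""
    for index in range(count):
        yield fake_map(index)  # this map is only created when the caller asks for it


green_counts = []
for pixels in generate_maps(5):
    green_counts.append(pixels.count("green"))  # process this map, then move on

print(green_counts)  # [0, 1000, 2000, 3000, 4000]
```

Because the loop rebinds pixels on every pass, each earlier map can be freed once nothing refers to it any more.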

Supplementary material

The remainder of this page contains an example of the flexibility of the features discussed above. Specifically, it shows how generators, context managers and decorators can be combined to create a testing framework like the one previously seen in the course.

Test generators

Earlier in the course we saw a test which loaded its test cases from a YAML file and checked each input against its expected output. This was nice and concise, but had one flaw: there was just one test covering all the fixtures, so we got just one . in the test output, and if any fixture failed, the remaining ones were not checked. We can do a nicer job with a test generator:

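import os

import yaml


def assert_exemplar(fixture):
    # every key except "answer" is an argument to greet,
    # the function under test from the earlier example
    fixture = dict(fixture)
    answer = fixture.pop("answer")
    assert greet(**fixture) == answer


def test_greeter():
    with open(
        os.path.join(os.path.dirname(__file__), "fixtures", "samples.yaml")
    ) as fixtures_file:
        fixtures = yaml.safe_load(fixtures_file)

    for fixture in fixtures:
        # yield the check and its argument instead of calling it here,
        # so that the runner can run each fixture as a separate test
        yield assert_exemplar, fixture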

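With a test runner that supports test generators, such as nose, each yield from a function whose name begins with test_ results in another test: the runner calls the yielded function with the yielded arguments. (pytest deprecated these “yield tests” in version 3.0 and removed them in version 4.0; there, @pytest.mark.parametrize plays the same role.)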
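To see how a runner might turn each yield into a separate test, here is a hypothetical, stripped-down runner. run_generated_tests is not a real library function, and double stands in for greet so the example needs no YAML file. The third fixture is deliberately wrong, to show that one failure does not stop the remaining checks.

```python
def double(x):
    """Hypothetical function under test."""
    return 2 * x


def assert_exemplar(fixture):
    fixture = dict(fixture)  # copy, so pop() leaves the caller's data alone
    answer = fixture.pop("answer")
    assert double(**fixture) == answer


def test_double():
    fixtures = [{"x": 1, "answer": 2}, {"x": 5, "answer": 10}, {"x": 3, "answer": 7}]
    for fixture in fixtures:
        yield assert_exemplar, fixture  # one (check, argument) pair per test


def run_generated_tests(test_function):
    """Hypothetical mini-runner: run each yielded check as a separate test."""
    results = ""
    for check, *args in test_function():
        try:
            check(*args)
        except AssertionError:
            results += "F"  # record the failure and keep going
        else:
            results += "."
    return results


print(run_generated_tests(test_double))  # ..F
```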

Negative test context managers

We have already seen pytest’s raises context manager, which checks that a block of code raises a given exception:

from pytest import raises

with raises(AttributeError):
    x = 2
    x.foo()

We can now see how pytest might have implemented this:

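from contextlib import contextmanager


@contextmanager
def reimplement_raises(exception):
    try:
        # the body of the with block runs here
        yield
    except exception:
        # the expected exception was raised, so swallow it
        pass
    else:
        # the body finished without raising, so the check fails
        raise AssertionError(f"Expected {exception.__name__} to be raised, but nothing was.")
with reimplement_raises(AttributeError):
    x = 2
    x.foo()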
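This works because contextlib.contextmanager runs the body of the with block at the point where the generator yields, and if the body raises, the exception is re-raised inside the generator at that yield. This illustrative sketch (traced and log are just for demonstration) records the order of events:

```python
from contextlib import contextmanager

log = []


@contextmanager
def traced():
    log.append("enter")                # runs when the with block starts
    try:
        yield                          # the body of the with block runs here
    except ValueError as error:
        log.append(f"caught {error}")  # the body's exception reappears at the yield
    finally:
        log.append("exit")             # runs however the block ends


with traced():
    log.append("body")
    raise ValueError("oops")           # suppressed, because traced() handles it

print(log)  # ['enter', 'body', 'caught oops', 'exit']
```

Because the generator catches the exception and then finishes normally, the with statement suppresses it; an exception the generator lets escape, like the AssertionError in reimplement_raises, reaches the caller instead.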

Skip test decorators

Some frameworks also implement decorators for skipping tests or dealing with tests that are known to raise exceptions (due to known bugs or limitations). For example:

%%writefile test_skipped.py
import pytest
import sys


@pytest.mark.skipif(sys.version_info < (4, 0), reason="requires python 4")
def test_python_4():
    raise RuntimeError("something went wrong")
Overwriting test_skipped.py
! pytest test_skipped.py
============================= test session starts ==============================
platform linux -- Python 3.8.18, pytest-7.4.4, pluggy-1.5.0
rootdir: /home/runner/work/rse-course/rse-course/module08_advanced_programming_techniques
plugins: cov-4.1.0, anyio-4.4.0, pylama-8.4.1
collecting ... 
collected 1 item                                                               

test_skipped.py s                                                        [100%]

============================== 1 skipped in 0.01s ==============================
%%writefile test_not_skipped.py
import pytest
import sys


@pytest.mark.skipif(sys.version_info < (3, 0), reason="requires python 3")
def test_python_3():
    raise RuntimeError("something went wrong")
Overwriting test_not_skipped.py
! pytest test_not_skipped.py
============================= test session starts ==============================
platform linux -- Python 3.8.18, pytest-7.4.4, pluggy-1.5.0
rootdir: /home/runner/work/rse-course/rse-course/module08_advanced_programming_techniques
plugins: cov-4.1.0, anyio-4.4.0, pylama-8.4.1
collecting ... 
collected 1 item                                                               

test_not_skipped.py F                                                    [100%]

=================================== FAILURES ===================================
________________________________ test_python_3 _________________________________
    @pytest.mark.skipif(sys.version_info < (3, 0), reason="requires python 3")
    def test_python_3():
>       raise RuntimeError("something went wrong")
E       RuntimeError: something went wrong

test_not_skipped.py:7: RuntimeError
=========================== short test summary info ============================
FAILED test_not_skipped.py::test_python_3 - RuntimeError: something went wrong
============================== 1 failed in 0.10s ===============================

We could reimplement this ourselves now too:

def homemade_skip_decorator(skip):
    def wrap_function(func):
        if skip:
            # if the test should be skipped, return a function
            # that just prints a message
            def do_nothing(*args, **kwargs):
                print("test was skipped")

            return do_nothing
        # otherwise use the original function as normal
        return func

    return wrap_function
@homemade_skip_decorator(3.9 < 4.0)
def test_skipped():
    raise RuntimeError("This test is skipped")


test_skipped()
test was skipped
@homemade_skip_decorator(3.9 < 3.0)
def test_runs():
    raise RuntimeError("This test is run")


test_runs()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[53], line 6
      1 @homemade_skip_decorator(3.9 < 3.0)
      2 def test_runs():
      3     raise RuntimeError("This test is run")
----> 6 test_runs()

Cell In[53], line 3, in test_runs()
      1 @homemade_skip_decorator(3.9 < 3.0)
      2 def test_runs():
----> 3     raise RuntimeError("This test is run")

RuntimeError: This test is run