Iterators and Generators

In Python, anything which can be iterated over is called an iterable:

bowl = {"apple": 5, "banana": 3, "orange": 7}

for fruit in bowl:
    print(fruit.upper())
APPLE
BANANA
ORANGE
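Iterating over a dictionary directly, as above, gives its keys. The short sketch below uses the same bowl to show two related cases: the items() view yields (key, value) pairs, and a string yields its characters one at a time.

```python
bowl = {"apple": 5, "banana": 3, "orange": 7}

# Looping over bowl.items() gives (key, value) pairs rather than just keys
for fruit, count in bowl.items():
    print(fruit, count)

# Strings are iterable too: each step yields one character
letters = list("kiwi")
print(letters)  # ['k', 'i', 'w', 'i']
```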

Surprisingly often, we want to iterate over something that takes a moderately large amount of memory to store, such as the map images in our green-graph example.

Our green-graph example involved making an array of all the maps between London and Birmingham. This kept them all in memory at the same time: first we downloaded all the maps, then we counted the green pixels in each of them.

This would NOT work if we used more points: eventually, we would run out of memory. We need to use a generator instead. This chapter looks at iterators and generators in more detail: how they work, when to use them, and how to create our own.

Iterators

Consider Python's built-in range:

range(10)
range(0, 10)
total = 0
for x in range(int(1e6)):
    total += x

total
499999500000

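To avoid allocating a million integers, range does not store its values. Instead, when we loop over it, Python gets an iterator from it, which produces each integer only when it is needed.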

We don’t actually need a million integers at once, just each integer in turn up to a million.

Because we can get an iterator from it, we say that a range is an iterable.

So we can for-loop over it:

for i in range(3):
    print(i)
0
1
2

There are two important built-in Python functions for working with iterables. The first is iter, which creates an iterator from any iterable object.

a = iter(range(3))

Once we have an iterator object, we can pass it to the next function. This moves the iterator forward, and gives us its next element:

next(a)
0
next(a)
1
next(a)
2

When we are out of elements, a StopIteration exception is raised:

next(a)
---------------------------------------------------------------------------
StopIteration                             Traceback (most recent call last)
Input In [9], in <cell line: 1>()
----> 1 next(a)

StopIteration: 

This tells Python that the iteration is over. For example, in a for i in range(3) loop, this is how Python knows when to exit the loop.
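To make this concrete, here is a rough sketch of what a for loop does behind the scenes, written with iter, next and an explicit StopIteration check. The helper name manual_for is hypothetical, and the real loop is implemented inside the interpreter rather than in Python code.

```python
def manual_for(iterable, body):
    """Call body(item) for each item, roughly as `for item in iterable` would."""
    # A for loop first asks the iterable for an iterator
    iterator = iter(iterable)
    while True:
        try:
            item = next(iterator)
        except StopIteration:
            # The iterator is exhausted, so the loop ends
            break
        body(item)


manual_for(range(3), print)  # prints 0, 1 and 2 on separate lines
```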

We can turn an iterable or iterator into a list with the list constructor function:

list(range(5))
[0, 1, 2, 3, 4]
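Putting iter and next together shows the difference between an iterable and an iterator. In this sketch, a range object stays small however long it is, because it stores only a few fixed fields such as its start, stop and step; next cannot be called on it directly, but the iterator returned by iter can be advanced. Exact byte counts from sys.getsizeof depend on the Python version and platform, so only their relative sizes matter here.

```python
import sys

r = range(1_000_000)

# The range object has a small, fixed size; the equivalent list holds every value
print(sys.getsizeof(r), sys.getsizeof(list(r)))

# A range is an iterable but not an iterator, so next() rejects it
try:
    next(r)
except TypeError as err:
    print("TypeError:", err)

# iter() returns an iterator over the range, which next() can advance
it = iter(r)
print(next(it), next(it))  # 0 1
```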

Defining Our Own Iterable

When we write next(a), under the hood Python tries to call the __next__() method of a. Similarly, iter(a) calls a.__iter__().

We can make our own iterators by defining classes that can be used with the next() and iter() functions: this is the iterator protocol.

For many concepts in Python, such as sequence, container and iterable, the language defines a protocol: a set of methods a class must implement in order to be treated as a member of that concept.

To define an iterator, the methods that must be supported are __next__() and __iter__().

__next__() must advance the iterator and return its next value, raising StopIteration when there are no values left.

We’ll see why we need to define __iter__ in a moment.

Here is an example of defining a custom iterator class:

class fib_iterator:
    """An iterator over part of the Fibonacci sequence."""

    def __init__(self, limit, seed1=1, seed2=1):
        self.limit = limit
        self.previous = seed1
        self.current = seed2

    def __iter__(self):
        return self

    def __next__(self):
        (self.previous, self.current) = (self.current, self.previous + self.current)
        self.limit -= 1
        if self.limit < 0:
            raise StopIteration()
        return self.current
x = fib_iterator(5)
next(x)
2
next(x)
3
next(x)
5
next(x)
8
for x in fib_iterator(5):
    print(x)
2
3
5
8
13
sum(fib_iterator(1000))
297924218508143360336882819981631900915673130543819759032778173440536722190488904520034508163846345539055096533885943242814978469042830417586260359446115245634668393210192357419233828310479227982326069668668250
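Because __iter__ returns self, an iterator like fib_iterator keeps track of its own position and can only be consumed once. The following sketch uses a hypothetical Countdown class that follows the same protocol, to show that a second pass over the same iterator object yields nothing:

```python
class Countdown:
    """Hypothetical iterator counting down from start to 1."""

    def __init__(self, start):
        self.remaining = start

    def __iter__(self):
        # An iterator returns itself, so for loops and iter() accept it
        return self

    def __next__(self):
        if self.remaining <= 0:
            raise StopIteration
        value = self.remaining
        self.remaining -= 1
        return value


countdown = Countdown(3)
print(list(countdown))  # [3, 2, 1]
print(list(countdown))  # [] because the iterator is already exhausted
```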

A shortcut to iterables: the __iter__ method

In fact, we don’t always have to define both __iter__ and __next__!

If a class only needs to behave like some other iterable when it is iterated over, it can implement just __iter__ and return iter(some_other_iterable), without implementing __next__. For example, an image class might store some metadata, but behave as if it were a 1-d array of pixels when iterated over:

from numpy import array
from matplotlib import pyplot as plt


class MyImage:
    def __init__(self, pixels):
        self.pixels = array(pixels, dtype="uint8")
        self.channels = self.pixels.shape[2]

    def __iter__(self):
        # return an iterator over just the pixel values
        return iter(self.pixels.reshape(-1, self.channels))

    def show(self):
        plt.imshow(self.pixels, interpolation="None")


x = [[[255, 255, 0], [0, 255, 0]], [[0, 0, 255], [255, 255, 255]]]
image = MyImage(x)
%matplotlib inline
image.show()
[Figure: the 2x2 image displayed by image.show()]
image.channels
3
from webcolors import rgb_to_name

for pixel in image:
    print(rgb_to_name(pixel))
yellow
lime
blue
white

See how we used image in a for loop, even though it doesn’t satisfy the iterator protocol (we didn’t define both __iter__ and __next__ for it)?

The key here is that we can use any iterable object (like image) in a for loop, not just iterators! Internally, Python creates an iterator from the iterable (by calling its __iter__ method), so we don't need to define a __next__ method ourselves.

In short: the iterator protocol requires both __iter__ and __next__, while the iterable protocol only requires __iter__, which must return an iterator.
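The collections.abc module provides Iterable and Iterator abstract base classes that check for exactly these methods. In this sketch, a hypothetical Playlist class follows only the iterable protocol, like MyImage: it is recognised as an Iterable but not an Iterator, and since every call to __iter__ returns a fresh iterator, it can be looped over more than once.

```python
from collections.abc import Iterable, Iterator


class Playlist:
    """Hypothetical iterable that delegates iteration to a list of titles."""

    def __init__(self, songs):
        self.songs = list(songs)

    def __iter__(self):
        # Each call returns a brand-new iterator over the list
        return iter(self.songs)


playlist = Playlist(["intro", "verse", "outro"])
print(isinstance(playlist, Iterable))  # True: it defines __iter__
print(isinstance(playlist, Iterator))  # False: it has no __next__
print(list(playlist))  # ['intro', 'verse', 'outro']
print(list(playlist))  # the same list again
```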

Generators

There’s a fair amount of “boilerplate” in the class-based definition of an iterator above.

Python provides another way to specify something which meets the iterator protocol: generators.

def my_generator():
    yield 5
    yield 10


x = my_generator()
next(x)
5
next(x)
10
next(x)
---------------------------------------------------------------------------
StopIteration                             Traceback (most recent call last)
Input In [26], in <cell line: 1>()
----> 1 next(x)

StopIteration: 
for a in my_generator():
    print(a)
5
10
sum(my_generator())
15

A function that contains yield statements, instead of a return statement, is a generator function. Calling it does not run its body straight away; it returns a generator object, which implements both __iter__ and __next__. Each yield hands a value back to the caller and pauses the function, rather than ending it as return would.

Each call of next() resumes the function where it left off.

Control passes back and forth between the generator and the caller. Our Fibonacci example therefore becomes a function rather than a class.
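Adding print calls to a small, hypothetical generator makes this hand-over visible. Nothing inside the function runs when it is called; each next() runs it only as far as the next yield. The two-argument form next(iterator, default) returns the default instead of raising StopIteration once the generator is finished.

```python
def chatty_generator():
    print("generator: started")
    yield 1
    print("generator: resumed after first yield")
    yield 2
    print("generator: reached the end")


gen = chatty_generator()  # prints nothing yet: the body has not started
print("caller: got", next(gen))
print("caller: got", next(gen))
print("caller: got", next(gen, "nothing left"))
```

Running this prints the generator's lines and the caller's lines alternately, starting with "generator: started" and ending with "caller: got nothing left".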

def yield_fibs(limit, seed1=1, seed2=1):
    # Same seed order as fib_iterator, so both give identical sequences
    previous = seed1
    current = seed2

    while limit > 0:
        limit -= 1
        current, previous = current + previous, current
        yield current

We can now use the output of the function like a normal iterable:

sum(yield_fibs(5))
31
for a in yield_fibs(10):
    if a % 2 == 0:
        print(a)
2
8
34
144

Sometimes we may need to gather all values from a generator into a list, such as before passing them to a function that expects a list:

list(yield_fibs(10))
[2, 3, 5, 8, 13, 21, 34, 55, 89, 144]
plt.plot(list(yield_fibs(20)))
[<matplotlib.lines.Line2D at 0x7f9ae3429e80>]
[Figure: line plot of the first 20 values of yield_fibs]
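Because a generator produces values only on demand, it does not even need a limit. The sketch below is a hypothetical unbounded variant of yield_fibs; itertools.islice takes just the first few values, so the infinite loop is never run to completion. Calling list() directly on such a generator, by contrast, would never return.

```python
from itertools import islice


def all_fibs(seed1=1, seed2=1):
    """Hypothetical unbounded version of yield_fibs."""
    previous, current = seed1, seed2
    while True:
        previous, current = current, previous + current
        yield current


print(list(islice(all_fibs(), 5)))  # [2, 3, 5, 8, 13]
```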

Supplementary material

The remainder of this page contains an example of the flexibility of the features discussed above. Specifically, it shows how generators and context managers can be combined to create a testing framework like the one previously seen in the course.

Test generators

Earlier in the course we saw a test which loaded its test cases from a YAML file and checked each input against its expected output. This was nice and concise, but had one flaw: we had just one test covering all the fixtures, so we got just one . in the test output when we ran the tests, and if any fixture failed, the rest were not checked. We can do a nicer job with a test generator:

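import os

import yaml


def assert_exemplar(fixture):
    # Work on a copy so the loaded fixture is not modified
    fixture = dict(fixture)
    answer = fixture.pop("answer")
    # greet is the function under test from the earlier greeter example
    assert greet(**fixture) == answer


def test_greeter():
    with open(
        os.path.join(os.path.dirname(__file__), "fixtures", "samples.yaml")
    ) as fixtures_file:
        fixtures = yaml.safe_load(fixtures_file)

    for fixture in fixtures:
        # Yield the check and its argument (not the result of calling it),
        # so the test runner can run each one as a separate test
        yield assert_exemplar, fixture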

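Under a test runner that supports this style, such as nose, each (function, argument) pair yielded by a function whose name begins with test_ is run as a separate test. Note that pytest removed support for yield tests in version 4.0; the modern pytest equivalent is the pytest.mark.parametrize decorator.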
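Since current versions of pytest do not collect yield tests, here is a sketch of the same idea with pytest.mark.parametrize, which generates one test per fixture. The greet function and the in-line FIXTURES list are hypothetical stand-ins for the course's greeter and its samples.yaml file.

```python
import pytest


def greet(name, polite=False):
    """Hypothetical stand-in for the greeter under test."""
    opening = "How do you do, " if polite else "Hey, "
    return opening + name + "."


# Hypothetical fixtures; the course loaded these from a YAML file instead
FIXTURES = [
    {"name": "Ada", "answer": "Hey, Ada."},
    {"name": "Alan", "polite": True, "answer": "How do you do, Alan."},
]


@pytest.mark.parametrize("fixture", FIXTURES)
def test_greeter(fixture):
    # pytest reports each entry in FIXTURES as a separate test
    fixture = dict(fixture)  # copy, so the shared fixture is not modified
    answer = fixture.pop("answer")
    assert greet(**fixture) == answer
```

Running pytest on a file containing this code should report two passing tests, one per fixture, so a failing fixture no longer hides the others.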

Negative test context managers

We have seen this:

from pytest import raises

with raises(AttributeError):
    x = 2
    x.foo()

We can now see how pytest might have implemented this:

from contextlib import contextmanager


@contextmanager
def reimplement_raises(exception):
    try:
        yield
    except exception:
        pass
    else:
        # The with block finished without raising the expected exception
        raise AssertionError(f"Expected {exception} to be raised, but nothing was.")

with reimplement_raises(AttributeError):
    x = 2
    x.foo()
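The trick in reimplement_raises is that @contextmanager turns a generator into a context manager: the code before yield runs when the with block is entered, the body of the with block runs while the generator is paused at yield, and the code after it runs on exit. A small hypothetical example makes the ordering visible:

```python
from contextlib import contextmanager

log = []


@contextmanager
def tagged(name):
    """Hypothetical context manager that records entry and exit."""
    log.append(f"<{name}>")  # runs when the with block is entered
    try:
        yield  # the with block's body runs while the generator is paused here
    finally:
        log.append(f"</{name}>")  # runs on exit, even if the body raised


with tagged("p"):
    log.append("hello")

print(log)  # ['<p>', 'hello', '</p>']
```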

Skip test decorators

Some frameworks also implement decorators for skipping tests or dealing with tests that are known to raise exceptions (due to known bugs or limitations). For example:

%%writefile test_skipped.py
import pytest
import sys


@pytest.mark.skipif(sys.version_info < (4, 0), reason="requires python 4")
def test_python_4():
    raise RuntimeError("something went wrong")
Overwriting test_skipped.py
! pytest test_skipped.py
============================= test session starts ==============================
platform linux -- Python 3.8.13, pytest-7.1.2, pluggy-1.0.0
rootdir: /home/runner/work/rse-course/rse-course/module08_advanced_programming_techniques
plugins: cov-3.0.0, anyio-3.6.1
collecting ... 
collected 1 item                                                               

test_skipped.py s                                                        [100%]

============================== 1 skipped in 0.02s ==============================
%%writefile test_not_skipped.py
import pytest
import sys


@pytest.mark.skipif(sys.version_info < (3, 0), reason="requires python 3")
def test_python_3():
    raise RuntimeError("something went wrong")
Overwriting test_not_skipped.py
! pytest test_not_skipped.py
============================= test session starts ==============================
platform linux -- Python 3.8.13, pytest-7.1.2, pluggy-1.0.0
rootdir: /home/runner/work/rse-course/rse-course/module08_advanced_programming_techniques
plugins: cov-3.0.0, anyio-3.6.1
collecting ... 
collected 1 item                                                               

test_not_skipped.py F                                                    [100%]

=================================== FAILURES ===================================
________________________________ test_python_3 _________________________________
    @pytest.mark.skipif(sys.version_info < (3, 0), reason="requires python 3")
    def test_python_3():
>       raise RuntimeError("something went wrong")
E       RuntimeError: something went wrong

test_not_skipped.py:7: RuntimeError
=========================== short test summary info ============================
FAILED test_not_skipped.py::test_python_3 - RuntimeError: something went wrong
============================== 1 failed in 0.11s ===============================

We could reimplement this ourselves now too:

def homemade_skip_decorator(skip):
    def wrap_function(func):
        if skip:
            # if the test should be skipped, return a function
            # that just prints a message
            def do_nothing(*args, **kwargs):
                print("test was skipped")
            return do_nothing
        else:
            # otherwise use the original function as normal
            return func

    return wrap_function
@homemade_skip_decorator(3.9 < 4.0)
def test_skipped():
    raise RuntimeError("This test is skipped")
    

test_skipped()
test was skipped
@homemade_skip_decorator(3.9 < 3.0)
def test_runs():
    raise RuntimeError("This test is run")


test_runs()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [53], in <cell line: 6>()
      1 @homemade_skip_decorator(3.9 < 3.0)
      2 def test_runs():
      3     raise RuntimeError("This test is run")
----> 6 test_runs()

Input In [53], in test_runs()
      1 @homemade_skip_decorator(3.9 < 3.0)
      2 def test_runs():
----> 3     raise RuntimeError("This test is run")

RuntimeError: This test is run
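One possible refinement, sketched below, mirrors pytest's skipif a little more closely by taking a reason, and uses functools.wraps so the replacement function keeps the original test's name and docstring. The name skip_if is hypothetical.

```python
import functools
import sys


def skip_if(condition, reason=""):
    """Hypothetical decorator loosely modelled on pytest.mark.skipif."""

    def wrap_function(func):
        if not condition:
            # Run the test unchanged
            return func

        @functools.wraps(func)
        def skipped(*args, **kwargs):
            # Replace the test body with a message explaining the skip
            print(f"{func.__name__} skipped: {reason}")

        return skipped

    return wrap_function


@skip_if(sys.version_info < (4, 0), reason="requires python 4")
def test_python_4():
    raise RuntimeError("something went wrong")


test_python_4()  # prints "test_python_4 skipped: requires python 4"
```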