[Image: Young Woman Contemplating a Skull, by Alessandro Casolani, Statens Museum for Kunst]

A few months ago I had to solve a problem in PyMongo that is harder than it seems: how do you register for notifications when the current thread has died?

The circumstances are these: when you call start_request in PyMongo, it gets a socket from its pool and assigns the socket to the current thread. We need some way to know when the current thread dies so we can reclaim the socket and return it to the socket pool for future use, rather than wastefully allowing it to be closed.

PyMongo can assume nothing about what kind of thread this is: It could've been started from the threading module, or the more primitive thread module, or it could've been started outside Python entirely, in C, as when PyMongo is running under mod_wsgi.

Here's what I came up with:

import threading
import weakref
from functools import partial

class ThreadWatcher(object):
    class Vigil(object):
        pass

    def __init__(self):
        self._refs = {}
        self._local = threading.local()

    def _on_death(self, vigil_id, callback, ref):
        self._refs.pop(vigil_id)
        callback()

    def watch(self, callback):
        if not self.is_watching():
            self._local.vigil = v = ThreadWatcher.Vigil()
            on_death = partial(
                self._on_death, id(v), callback)

            ref = weakref.ref(v, on_death)
            self._refs[id(v)] = ref

    def is_watching(self):
        "Is the current thread being watched?"
        return hasattr(self._local, 'vigil')

    def unwatch(self):
        try:
            v = self._local.vigil
            del self._local.vigil
            self._refs.pop(id(v))
        except AttributeError:
            pass

The key lines are in watch(): the ones that create the weakref and store it in _refs. First, I make a weakref to an object stored in a thread local. Weakrefs are permitted on instances of subclasses of object but not on instances of object itself, so I use an inner class called Vigil. I initialize the weakref with a callback, which will be executed when the vigil is deleted.
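To see that restriction concretely, here is a minimal sketch. The names on_delete and fired are illustrative only and not part of PyMongo. It tries to weakly reference a bare object() instance, then does the same with an instance of a trivial subclass and drops the only strong reference so the callback fires.

```python
import gc
import weakref


class Vigil(object):
    pass


# A bare object() instance cannot be weakly referenced.
try:
    weakref.ref(object())
    plain_object_ok = True
except TypeError:
    plain_object_ok = False

fired = []


def on_delete(ref):
    # The callback receives the (now dead) weakref as its only argument.
    fired.append(ref)


v = Vigil()
r = weakref.ref(v, on_delete)
del v          # drop the only strong reference to the vigil
gc.collect()   # needed on PyPy; harmless under reference counting

print(plain_object_ok, len(fired), r())
```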

The callback only fires if the weakref outlives the vigil, so I keep the weakref alive by storing it as a value in the _refs dict. The key into _refs can't be the vigil itself, since then _refs would hold a strong reference to the vigil and it wouldn't be deleted when the thread dies. I use id(v) instead.
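A small sketch of why the weakref has to be kept somewhere, using made-up names (make_callback, calls, refs): the first weakref is discarded before its vigil dies, so its callback never runs; the second is parked in a dict keyed by id(), in the style of _refs, and its callback does run.

```python
import gc
import weakref


class Vigil(object):
    pass


calls = []


def make_callback(label):
    def on_delete(ref):
        calls.append(label)
    return on_delete


# Case 1: the weakref is thrown away first, so nothing is left to fire.
v1 = Vigil()
r1 = weakref.ref(v1, make_callback('discarded'))
del r1
del v1

# Case 2: the weakref is stored in a dict keyed by id(), which holds
# no strong reference to the vigil itself.
refs = {}
v2 = Vigil()
refs[id(v2)] = weakref.ref(v2, make_callback('kept'))
del v2

gc.collect()   # for PyPy; not required under reference counting
print(calls)
```

Unlike ThreadWatcher._on_death, this callback doesn't remove its entry from the dict, so the dead weakref stays in refs until something pops it.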

Let's step through this. When a thread calls watch(), the only strong reference to the vigil is a thread-local. When a thread dies its locals are cleaned up, the vigil is dereferenced, and _on_death runs. _on_death cleans up _refs and then voilà, it runs the original callback.

When exactly is the vigil deleted? This is a subtle point, as the sages among you know. First, PyPy uses occasional mark-and-sweep garbage collection instead of reference counting, so the vigil isn't deleted until some time after the thread dies. In unit tests, I force the issue with gc.collect().

Second, there's a bug in CPython 2.6 and earlier, fixed by Antoine Pitrou in CPython 2.7.1, where thread locals aren't cleaned up until the thread dies and some other thread accesses the local. I wrote about this in detail last year when I was struggling with it. gc.collect() won't help in this case.

Third, when is the local cleaned up in Python 2.7.1 and later? It happens as soon as the interpreter deletes the underlying PyThreadState, but that can actually come after Thread.join() returns: join() simply waits for a Condition to be set at the end of the thread's run, which comes before the locals are cleared. So in Python 2.7.1 and later we need to sleep a few milliseconds after joining the thread to be certain it's truly gone.

Thus a reliable test for my ThreadWatcher class might look like:

import gc
import threading
import time
import unittest

# ThreadWatcher is the class defined above.
class TestWatch(unittest.TestCase):
    def test_watch(self):
        watcher = ThreadWatcher()
        callback_ran = [False]

        def callback():
            callback_ran[0] = True

        def target():
            watcher.watch(callback)

        t = threading.Thread(target=target)
        t.start()
        t.join()

        # Trigger collection in Py 2.6
        # See http://bugs.python.org/issue1868
        watcher.is_watching()
        gc.collect()

        # Cleanup can take a few ms in
        # Python >= 2.7
        for _ in range(10):
            if callback_ran[0]:
                break
            else:
                time.sleep(.1)


        assert callback_ran[0]
        # id(v) removed from _refs?
        assert not watcher._refs

The is_watching() call accesses the local object from the main thread after the child has died, working around the Python 2.6 bug, and the gc.collect() call makes the test pass in PyPy. The sleep loop gives Python 2.7.1 a chance to finish tearing down the thread state, including locals.
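If the sleep loop shows up in several tests, it could be factored into a small polling helper. This is only a sketch: wait_until, timeout, and interval are names I'm making up here, not anything from PyMongo or unittest.

```python
import time


def wait_until(predicate, timeout=1.0, interval=0.01):
    """Poll predicate() until it is true or `timeout` seconds elapse.

    Returns True if the predicate became true, False otherwise.
    """
    deadline = time.time() + timeout
    while time.time() < deadline:
        if predicate():
            return True
        time.sleep(interval)
    # One last check so a success right at the deadline isn't missed.
    return bool(predicate())
```

In the test above, the loop would then shrink to something like wait_until(lambda: callback_ran[0]).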

Two final cautions. The first is, you can't predict which thread runs the callback. In Python 2.6 it's whichever thread accesses the local after the child dies. In later versions, with Pitrou's improved thread-local implementation, the callback is run on the dying child thread. In PyPy it's whichever thread is active when the garbage collector decides to run.
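You can check which thread runs the callback on your own interpreter with a sketch like this one. It assumes Python 3.3 or later, where threading.get_ident() exists (in Python 2 the equivalent lives in the thread module). It compares thread identifiers rather than Thread objects, since the threading module's bookkeeping for the child may already be gone by the time the callback fires. The idents dict and the function names are illustrative.

```python
import gc
import threading
import time
import weakref


class Vigil(object):
    pass


local = threading.local()
refs = {}
idents = {}   # thread identifiers recorded for comparison


def on_death(ref):
    # Record the identifier of whichever thread executes the callback.
    idents['callback'] = threading.get_ident()


def target():
    idents['child'] = threading.get_ident()
    local.vigil = v = Vigil()
    refs[id(v)] = weakref.ref(v, on_death)


t = threading.Thread(target=target)
t.start()
t.join()

gc.collect()   # lets PyPy notice the dead vigil
for _ in range(20):
    if 'callback' in idents:
        break
    time.sleep(0.05)

if idents.get('callback') == idents['child']:
    print('callback ran on the dying child thread')
else:
    print('callback ran on some other thread, or has not run yet')
```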

The second caution is, there's an unreported memory-leak bug in Python 2.6, which Pitrou fixed in Python 2.7.1 along with the other bug I linked to. If you access a thread-local from within the weakref callback, you're touching the local in an inconsistent state, and the next object stored in the local will never be dereferenced. So don't do that. Here's a demonstration:

class TestRefLeak(unittest.TestCase):
    def test_leak(self):
        watcher = ThreadWatcher()
        n_callbacks = [0]
        nthreads = 10

        def callback():
            # BAD, NO!:
            # Accessing thread-local in callback
            watcher.is_watching()
            n_callbacks[0] += 1

        def target():
            watcher.watch(callback)

        for _ in range(nthreads):
            t = threading.Thread(target=target)
            t.start()
            t.join()

        watcher.is_watching()
        gc.collect()
        for _ in range(10):
            if n_callbacks[0] == nthreads:
                break
            else:
                time.sleep(.1)

        self.assertEqual(nthreads, n_callbacks[0])

In Python 2.7.1 and later the test passes because all ten threads' locals are cleaned up, and the callback runs ten times. But in Python 2.6 only five locals are deleted.

I discovered this bug when I rewrote the connection pool in PyMongo 2.2 and a user reported that in Python 2.6 and mod_wsgi, every second request leaked one socket! I fixed PyMongo in version 2.2.1 by avoiding accessing thread locals while they're being torn down. (See bug PYTHON-353.)

Update: I've discovered that in Python 2.7.0 and earlier, you need to lock around the assignment to self._local.vigil; see "Another Thing About Threadlocals".
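A rough sketch of what that locking might look like follows. LockingThreadWatcher and its _lock attribute are hypothetical names, and I'm assuming a plain threading.Lock held only around the assignment is enough; the linked post is the place to check the details.

```python
import threading
import weakref
from functools import partial


class LockingThreadWatcher(object):
    """Hypothetical ThreadWatcher variant that guards the assignment."""

    class Vigil(object):
        pass

    def __init__(self):
        self._refs = {}
        self._local = threading.local()
        self._lock = threading.Lock()   # assumption: one shared lock

    def _on_death(self, vigil_id, callback, ref):
        self._refs.pop(vigil_id, None)
        callback()

    def is_watching(self):
        return hasattr(self._local, 'vigil')

    def watch(self, callback):
        if not self.is_watching():
            v = LockingThreadWatcher.Vigil()
            on_death = partial(self._on_death, id(v), callback)
            ref = weakref.ref(v, on_death)
            with self._lock:
                # Only one thread at a time assigns to the thread local.
                self._local.vigil = v
            self._refs[id(v)] = ref
```

The callback path deliberately doesn't take the lock, so a vigil dying while another thread is inside watch() can't deadlock on it.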

Post-script: The image up top is a memento mori, a "reminder you will die," by Alessandro Casolani from the 16th century. The memento mori genre is intended to offset a portrait subject's vanity: you look good now, but your beauty won't make a difference when you face your final judgment.

This was painted circa 1502 by Andrea Previtali:

[Image: Memento Mori, by Andrea Previtali]

The inscription is "Hic decor hec forma manet, hec lex omnibus unam," which my Latin-nerd friends translate as, "This beauty endures only in this form, this law is the same for everyone." It was painted upside-down on the back of this handsome guy:

[Image: Portrait of a Man, by Andrea Previtali]

The painting was mounted on an axle so the face and the skull could be rapidly alternated and compared. Think about that the next time you start a thread—it may be running now, but soon enough it will terminate and even its thread-id will be recycled.