13 min read
Published: 16 June 2025
Updated: 28 March 2026
python · jax · coding

Hashing functions by what they select

It is impossible in general to verify the equivalence of functions in terms of their input-output behaviour, or semantics. For different practical reasons, many programming languages, including Python, also don't evaluate functions' equivalence in terms of their code structure, or syntax. Instead, Python computes hashes and equivalence of functions based solely on their memory address. So it doesn't make sense to use a dict to map directly from subtree selectors to some respective node data, and expect the uniqueness of keys to be decided by what subtree they select. But what if I very wisely insist I must build such a dict anyway? However misguided, it turns out solving this problem helps us with a more practical one: serialising hyperparameters for training runs, when those hyperparameters happen to be represented at runtime as selectors.

For the past two years, I've developed my machine learning projects with JAX and Equinox. The basis of JAX's power is how flexibly we can compose our computation graphs by writing in a functional paradigm, using transformations like grad, vmap, and jit. But the substance of its power is PyTrees, or the ability to transform over different types of tree-structured inputs in a unified way.

To be clear: "a PyTree" is code for "some nested composition of nodes, which JAX knows how to treat like any other tree because it's been told how to flatten and unflatten the nodes". Technically every Python object is a PyTree, even if it can only be treated as just a single, unstructured leaf.

Subtree selectors

A subtree selector function¹ takes a PyTree, selects one or more nodes, and composes them into some tree to return.

One use case for subtree selectors is to edit the nodes of a PyTree out-of-place. Suppose we have an instance of some data type:

import jax
import equinox as eqx
from equinox import Module

class Foo(Module):
    bar: dict[str, int]
    baz: tuple[float, float]

    def __call__(self, x: jax.Array):
        # If `Foo` were part of a model, we could transform `x` here,
        # using `self.bar` and `self.baz` as parameters.
        ...

tree = Foo(
    dict(a=1, b=2),
    (3.14, 2.718)
)

An Equinox Module is a type of Python dataclass, which JAX can manipulate as a PyTree.

A selector lets us specify the nodes whose values will be replaced:

where_nodes_to_update = lambda tree: (tree.bar['a'], tree.baz)

# Update some parts of the tree
updated_tree = eqx.tree_at(
    where_nodes_to_update,
    tree,
    (5, (6.28, 1.618)),
)

updated_tree
>> Foo(bar={'a': 5, 'b': 2}, baz=(6.28, 1.618))

Similarly, we can use selectors to specify partial initializations of model states. When working with JAX, we typically compose models as trees of Equinox Module objects. Each module node may have leaves which are its parameters; it may also be callable, and specify a transformation of the model state. But to keep the model purely functional, state cannot live inside the modules themselves: it must be passed around separately.

This is the approach I used when designing Feedbax.

Note: as of March 2026, I no longer use hierarchical state objects/passing, but instead use a dict-like state object, similar to what Equinox uses for stateful operations.

States typically have default initializations, but users may want to provide custom initializations for some part(s) of the state, without needing to construct the entire PyTree themselves. We might provide the custom initializations as a mapping from selectors, which specify the part(s) of the state to initialize, to the respective data to initialize with.

Suppose we are given a PyTree of default model states (i.e. arrays), and we'd like to initialize some of its leaves using custom arrays. For each leaf we want to initialize, we could pair its selector with appropriate leaf values to assign to the selected subtree. But maybe we don't know beforehand exactly which leaves we'll want to initialize, so we'd like a general tool that can perform the initialization for us in any case:

from collections.abc import Callable, Sequence
from typing import Any

import equinox as eqx
from jaxtyping import PyTree

def replace_nodes(tree: PyTree, spec: Sequence[tuple[Callable, Any]]):
    for where_func, node_value in spec:
        tree = eqx.tree_at(where_func, tree, node_value)
    return tree

This works, so long as the structure of each node value matches the structure of its respective selector's return value. However, some of our selectors might pick out trees containing overlapping leaves of tree, in which case we'll end up pointlessly replacing the same leaf more than once. Perhaps it's better to assume that our selectors specify individual, unique leaves (or at least, trees of non-intersecting subsets of leaves). So, suppose instead we use the function:

def replace_leaves(tree: PyTree, spec: dict[Callable, Any]):
    for where_func, leaf_value in spec.items():
        tree = eqx.tree_at(where_func, tree, leaf_value)
    return tree

Now we can specify a mapping between selectors that pick out unique leaves, and the arrays we should replace those leaves with:

# Assume these have already been assigned
default_state: PyTree[Array]
some_part_init_data: Array
some_other_init_data: Array

# Construct the mapping from parts of the state PyTree, to the data to initialize them with
init_state_spec: dict[Callable, Array] = {
    (lambda state: state.some_substate.part): some_part_init_data,
    (lambda state: state.some_other_substate.other_part): some_other_init_data,
}

# Update the default state with the custom init states
initial_state = replace_leaves(default_state, init_state_spec)

This does work, but all is not as it seems. We'll encounter some strange behaviour if we assume that init_state_spec treats two selector keys as the same key whenever they pick out the same leaves.

For example, if in some other context we want to determine which data will be used to initialize state.some_substate.part, we might try to access it like so:

init_state_spec[lambda state: state.some_substate.part]

But this doesn't return some_part_init_data! It actually raises a KeyError.

Likewise, if we already have some init_state_spec and want to update it to use some different data for part of the state, we might try this:

init_state_spec[lambda state: state.some_substate.part] = some_new_init_data

But this actually adds a new entry to the mapping, rather than replacing the old one.

What is responsible for this weirdness? Well, Python accesses and assigns dict entries based on the uniqueness of a key's hash... and this doesn't work like you might think, or like I once naively hoped, when the key happens to be a function.
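To see the behaviour concretely, here's a small pure-Python demonstration (the key and value names are made up for illustration). Reusing the *same* function object as a key works fine; only a freshly created, identically written lambda misses:

```python
sel = lambda state: state["part"]
spec = {sel: "init_data"}

# The same function object hashes consistently, so lookup succeeds:
print(spec[sel])  # init_data

# An identically-written lambda is a distinct object with a distinct hash:
try:
    spec[lambda state: state["part"]]
except KeyError:
    print("KeyError: new lambda, new hash")
```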

What's in a function?

Any Python object that can be hashed² can be used as a key in a dict. While a Python function is hashable, its hash is based on its object identity, which in CPython (i.e. for almost all Python users) is derived from its memory address. For example, all of the following expressions evaluate to False:

f = lambda x: x
g = lambda x: x

# Explicitly compare by object identity
id(f) == id(g)
f is g

# Implicitly compare by object identity
f == g

# Explicitly compare by hash
hash(f) == hash(g)

(Binding the lambdas to names keeps both objects alive. Comparing two anonymous lambdas inline, like id(lambda x: x) == id(lambda x: x), can misleadingly evaluate True in CPython: the first object may be freed before the second is allocated, and the second is then free to reuse the same memory address.)

Since these evaluate to False for any pair of distinct function objects, and since selectors are of course "any function", this example is also False:

where_a = lambda model: model.some_layer
where_b = lambda model: model.some_layer
hash(where_a) == hash(where_b)

At first this might seem kind of weird. Not only is lambda model: model.some_layer obviously identical to lambda model: model.some_layer in how it's written (i.e. its structure, or syntax), but the two are also obviously identical in how we should expect them to behave (i.e. their meaning, or semantics). Actually, in the case of a subtree selector, syntax and semantics are pretty closely aligned in general: once we know its syntax then we can be pretty certain about what it is going to return, and vice versa, since it's just performing some simple accesses of parts of a data structure.

If functions were hashed in those terms, then our selectors' uniqueness would be determined just like we want it to be. Seems pretty useful! So why wouldn't Python evaluate hashes or equivalence, in terms of either syntax or semantics?

Semantics? That's easy. And by "easy", of course I mean impossible: it's a well-known mathematical fact (a consequence of Rice's theorem) that there is no general method to test whether any two functions will always behave the same way, for all inputs. So if programming language designers want tools like hash and the == operator to be general-purpose, and useful for comparing any two functions their users might write, then those tools can't be based on semantic comparisons, which are necessarily incomplete.

We could also try to evaluate equivalence in terms of syntax, which is the structure of the function's code. This is at least possible, when we have access to the code. But for arbitrary functions it can get expensive: parsing source, building ASTs, and normalizing for trivial differences such as variable names. Perhaps it's enough most of the time to know that two functions aren't literally identical in memory, which is what we test with a direct hash comparison.

Since dict uses hash to determine the uniqueness of keys, and the hash of a selector is not determined by what it picks out, dict cannot see our selectors the way we see them...

Using selectors as dict keys

Once upon a time I decided that I really wanted to be able to use subtree selectors as dict keys. But I knew selectors don't hash in terms of the part(s) of a PyTree they access. So I decided to make a custom type of dict that did not directly (mis)treat them as keys, but instead converted them to an intermediate value which would hash on my terms. For that, I needed a function that takes a selector as input, and returns a hashable representation of its syntax as output.

I grasped for a function where_func_to_str that would convert a selector to a string. I expected it to behave like so:

where_func_to_str(lambda state: state.layer2)
>> "layer2"

where_func_to_str(lambda s: s.some_substate.layer1)
>> "some_substate.layer1"

And maybe I could also get it to work for slightly more complex cases:

where_func_to_str(lambda state: (state.layer2, state.layer3))
>> ("layer2", "layer3")

Note that the name of the bound variable ("state", "s", whatever) does not matter here. What is important is that given some PyTree, whatever its name, we are accessing some part(s) of it, whose names are given.

I hadn't yet imagined how I would handle all possible selectors. For example, this would be no good:

where_func_to_str(lambda state: [state.layer2, state.layer3])
>> ["layer2", "layer3"]

That's because lists, including ["layer2", "layer3"], aren't hashable. They can't be used as keys! Nor can dicts or sets. Tuples are hashable, so ("layer2", "layer3") is fine as a dict key. Even after we find a where_func_to_str that works like we want, we'll still need to make sure our selectors return PyTrees which are hashable when their leaves are replaced with strings, or else be careful to cast them as such.
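The hashability constraint is easy to check directly (the key and value here are arbitrary examples):

```python
# Tuples of strings are immutable and hashable, so they work as keys:
d = {("layer2", "layer3"): "some_data"}
print(d[("layer2", "layer3")])  # some_data

# Lists (like dicts and sets) are mutable and unhashable:
try:
    hash(["layer2", "layer3"])
except TypeError as e:
    print(e)  # unhashable type: 'list'
```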

On my quest for a selector-based dict, I made three attempts to write such a where_func_to_str. But before we look at those, let's start with an example that will not work.

Won't work: literal source code

That is, what if we hash the string literal containing the code as it appears in the source file? This is the most direct comparison we can make. It's pretty cheap, since we just need to compare the strings character-by-character until there is a single mismatch. However, it's pretty naive, and hardly useful. Consider that none of the following functions will test as equivalent if we compare them literally, though their differences are trivial:

def func(x):
    return x + 1

def func_with_comment(x):
    return x + 1  # ooowee

def func_with_parens(x):
    return (x + 1)

def func_without_spaces(x):
    return x+1

Also not great: abstract syntax tree

An abstract syntax tree (AST) is a data structure specifically intended to represent the syntax of a program. Python's ast module would say that all of the functions in the previous example have equivalent ASTs. However, like most AST implementations, it includes the names of bound variables as part of the syntax. For example:

import ast

def compare_ast(*code: str):
    tree_dumps = [ast.dump(ast.parse(s)) for s in code]
    return len(set(tree_dumps)) == 1

compare_ast("lambda x: x", "lambda x: (x)")  # True
compare_ast("lambda x: x", "lambda y: y")  # False; the only difference is a bound variable name

There are good reasons why we'd want to preserve bound variable names in ASTs, but when testing for the syntactic equivalence of functions, we generally admit alpha equivalence. Normalizing our ASTs to an alpha-equivalent form can require a little work when dealing with arbitrarily complex functions with nested scopes of bound variables.
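As a sketch of what that normalization involves, here's a hypothetical minimal normalizer. It handles only single-argument lambdas, renaming the bound variable to a fixed placeholder before dumping the AST; real functions with nested scopes, defaults, or closures would need much more care:

```python
import ast

def normalize_lambda(src: str) -> str:
    """Rename a single-argument lambda's bound variable to a fixed placeholder."""
    tree = ast.parse(src, mode="eval")
    lam = tree.body
    assert isinstance(lam, ast.Lambda) and len(lam.args.args) == 1
    old_name = lam.args.args[0].arg
    lam.args.args[0].arg = "_arg"
    for node in ast.walk(lam.body):
        if isinstance(node, ast.Name) and node.id == old_name:
            node.id = "_arg"
    return ast.dump(tree)

# Alpha-equivalent lambdas now normalize to identical dumps:
normalize_lambda("lambda x: x.layer1") == normalize_lambda("lambda s: s.layer1")  # True
```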

When a programming language designer implements an operator like == for arbitrary comparisons, they don't know how complex the objects to be compared will be. If they base the comparison on structure, then in certain cases a very large amount of work may be implied. Instead, it makes sense to base comparisons on unique ID (e.g. memory address) so that we can tell whether two references are to the same object, while leaving more complex structural comparisons to be implemented by developers on top of the language base.

First actual attempt: from bytecode

Once Python has parsed the source code, it compiles the program into a sequence of lower-level instructions (bytecode) for the interpreter to run. Bytecode is implementation-dependent: CPython's bytecode looks nothing like the bytecode of another interpreted language, and it can change between Python versions. Compilation involves some semantic interpretation, and unlike an AST, bytecode is not primarily intended as a representation of syntax. But we can treat the syntax of the resulting bytecode as a practical substitute for our source syntax.

So in some cases, it might make sense to compare the bytecode of two functions, to test for some kind of equivalence.

from collections.abc import Callable
import dis


def get_where_str(where_func: Callable) -> str:
    """
    Returns a string representation of the (nested) attributes accessed by a function.

    Only works for functions that take a single argument, and return that argument,
    or a single (nested) attribute accessed from the argument.
    """
    bytecode = dis.Bytecode(where_func)
    return ".".join(
        instr.argrepr for instr in bytecode
        if instr.opname == "LOAD_ATTR"
    )

But this is inflexible. It only works for a single attribute access chain. To make it work for arbitrarily complex subtrees would require more elaborate bytecode parsing.
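To see the limitation concretely, here's a condensed copy of the helper above, run on two selectors. (The expected outputs assume CPython, where the argrepr of a plain LOAD_ATTR instruction is just the attribute name.)

```python
import dis
from collections.abc import Callable

def get_where_str(where_func: Callable) -> str:
    # Join the names of all attribute loads in the function's bytecode
    return ".".join(
        instr.argrepr for instr in dis.Bytecode(where_func)
        if instr.opname == "LOAD_ATTR"
    )

print(get_where_str(lambda s: s.sub.layer1))  # sub.layer1
print(get_where_str(lambda s: (s.a, s.b)))    # a.b -- two separate chains flattened into one!
```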

Second attempt: from node paths via jax.tree

My second attempt used the PyTree facilities from JAX+Equinox. First, use the selector with tree_at to mark the selected leaves. Then, use JAX's leaves_with_path to get paths which encode the selector, and are certainly one-to-one with their string representations.

While this method is in some ways more flexible than using bytecode, it's even slower. Also it requires access to the tree we'll be using the selector on (or more technically, any tree that contains at least the selected node paths).

from collections.abc import Callable

import equinox as eqx
import jax.tree as jt
from jaxtyping import PyTree


class _NodeWrapper:
    def __init__(self, value):
        self.value = value


class _NodePath:
    def __init__(self, path):
        self.path = path

    def __iter__(self):
        return iter(self.path)


def where_func_to_paths(where: Callable, tree: PyTree):
    """
    Similar to `get_where_str`, but:

    - returns node paths, not strings;
    - works for `where` functions that return arbitrary PyTrees of nodes;
    - works for arbitrary node access (e.g. dict keys, seq. indices)
      and not just attribute access.

    Limitations:

    - requires a PyTree argument;
    - assumes the same object does not appear as multiple nodes in the tree;
    - if `where` specifies a node that is a subtree, it cannot also specify a node
      within that subtree.

    See [this issue](https://github.com/i-m-mll/feedbax/issues/14).
    """
    tree = eqx.tree_at(where, tree, replace_fn=lambda x: _NodeWrapper(x))
    id_tree = jt.map(id, tree, is_leaf=lambda x: isinstance(x, _NodeWrapper))
    node_ids = where(id_tree)

    paths_by_id = {node_id: path for path, node_id in jt.leaves_with_path(
        jt.map(
            lambda x: x if x in jt.leaves(node_ids) else None,
            id_tree,
        )
    )}

    paths = jt.map(lambda node_id: _NodePath(paths_by_id[node_id]), node_ids)

    return paths

Third attempt: representation tracer object

The third solution, and the best one I've found so far, relies on passing an instance of a special class to the selector.

By implementing special attribute-access and item-access behaviour for this class, we can easily construct a tree-of-strings representation of the structure of the selector's return value. This is the most flexible method, and also by far the fastest.

from collections.abc import Callable
from typing import Any

import jax.tree as jt
from jaxtyping import PyTree


class _WhereStrConstructor:

    def __init__(self, label: str = ""):
        self.label = label

    def __getitem__(self, key: Any):
        if isinstance(key, str):
            key = f"'{key}'"
        elif isinstance(key, type):
            key = key.__name__
        return _WhereStrConstructor("".join([self.label, f"[{key}]"]))

    def __getattr__(self, name: str):
        sep = "." if self.label else ""
        return _WhereStrConstructor(sep.join([self.label, name]))


def _get_where_str_constructor_label(x: _WhereStrConstructor) -> str:
    return x.label


def where_func_to_attr_str_tree(where: Callable) -> PyTree[str]:
    """Similar to `get_where_str` and `where_func_to_paths`, but:

    - Avoids complicated logic of parsing bytecode, or traversing pytrees;
    - Works for `where` functions that return arbitrary PyTrees of node references;
    - Runs significantly (10+ times) faster than the other solutions.
    """

    try:
        return jt.map(_get_where_str_constructor_label, where(_WhereStrConstructor()))
    except TypeError as e:
        raise TypeError("`where` must return a PyTree of node references") from e

This is about 10 µs per call, for the data I'm working with. That’s an acceptable number, given that I only need to do a small number of these operations every training iteration.
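The tracer trick doesn't depend on JAX at all. Here's a dependency-free sketch of the same idea, mapping over a plain tuple instead of using jt.map (class and attribute names are illustrative):

```python
class Tracer:
    """Records attribute and item accesses as a path string."""
    def __init__(self, label: str = ""):
        self.label = label

    def __getattr__(self, name: str) -> "Tracer":
        sep = "." if self.label else ""
        return Tracer(self.label + sep + name)

    def __getitem__(self, key) -> "Tracer":
        key = f"'{key}'" if isinstance(key, str) else key
        return Tracer(f"{self.label}[{key}]")

where = lambda state: (state.sub.layer1, state.weights['w0'])
labels = tuple(t.label for t in where(Tracer()))
print(labels)  # ('sub.layer1', "weights['w0']")
```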

Now that we have a decent where_func_to_str, we can construct the custom dict that effectively uses selectors as keys. The implementation is in the appendix, since it's less about selectors themselves and more about Python dict mechanics.

Serialisation of selectors

Another common use case for a subtree selector is to define the parts of a model which will be trained:

where_train = lambda model: (
    model.layer1,
    model.layer3,
)

Our model objects are typically PyTrees whose nodes are of type eqx.Module. Note that where_train assumes the model possesses whichever nodes it refers to.

Notice that where_train has the flavour of a hyperparameter: on different training runs/phases, we might want to train different parts of the model. So we might want to encode where_train in our hyperparameter data, similarly to how we would encode the number of training iterations, or the learning rate.

It can be costly to encode functions as data (e.g. an AST). And if we want our encoding to reflect the effect (i.e. semantics) of the function rather than simply the way it was coded, there is no general solution. But since we're working with selectors in particular, we can just use where_func_to_str to convert them to an equivalent representation which is easily serialisable:

where_func_to_str(where_train)
>> ('layer1', 'layer3')

Of course, we'll probably want to deserialise this representation and turn it back into a selector function, like when loading hyperparameters from disk to start a training run. For attribute accesses, the reverse transformation is straightforward:

from collections.abc import Callable
from operator import attrgetter
from typing import Any

import jax.tree as jt
from jaxtyping import PyTree

def attr_str_tree_to_where_func(tree: PyTree[str, 'T']) -> Callable[[PyTree], PyTree[Any, 'T']]:
    """Reverse transformation to `where_func_to_attr_str_tree`.

    Takes a PyTree of strings describing attribute accesses, and returns a function
    that returns a PyTree of attributes.

    Note: This handles attribute accesses only. Extending to handle item accesses
    (e.g. dict keys, list indices) would require parsing the bracket notation
    from the string representation.
    """
    getters = jt.map(lambda s: attrgetter(s), tree)

    def where_func(obj):
        return jt.map(lambda g: g(obj), getters)

    return where_func

This is sufficient for my purposes, but it would be better in the future to generalize it to handle other access operations (e.g. indexing) as well.
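For the simple attribute-only case, the round trip can be checked without any JAX machinery, since operator.attrgetter already understands dotted paths (the stand-in state object here is invented for illustration):

```python
from operator import attrgetter
from types import SimpleNamespace

# A stand-in "state" object with nested attributes
state = SimpleNamespace(
    some_substate=SimpleNamespace(layer1="weights_1"),
    layer2="weights_2",
)

getter = attrgetter("some_substate.layer1")
print(getter(state))  # weights_1
```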

Conclusion

While Python can't hash arbitrary functions by their behaviour (impossible) or their syntax (expensive), subtree selectors are constrained enough that we can trace their execution with a custom stand-in object and record exactly which paths they access. This lets us construct a hashable representation that matches our intuition about when two selectors are "the same".

I found this pretty interesting, and I learned a lot. If you've read this far and think I'm mistaken or confused, I'd be glad to hear about it.

Appendix: a dict that uses accessor functions as keys

We'll start by defining the general case: a dict which transforms its keys before trying to hash them.

from abc import abstractmethod
from collections import OrderedDict
from collections.abc import MutableMapping
from typing import Generic, TypeVar, overload


T = TypeVar("T")
KT1 = TypeVar("KT1")
KT2 = TypeVar("KT2")
VT = TypeVar("VT")

class AbstractTransformedOrderedDict(MutableMapping[KT2, VT], Generic[KT1, KT2, VT]):
    """Base for `OrderedDict`s which transform keys when getting and setting items.

    It stores the original keys, and otherwise behaves (e.g. when iterating)
    as though they are the true keys.

    This is useful when we want to use a certain type of object as a key, but
    it would not be hashed properly by `OrderedDict`, so we need to transform
    it into something else.

    Based on https://stackoverflow.com/a/3387975
    """
    store: OrderedDict[KT1, tuple[KT2, VT]]

    def __init__(self, *args, **kwargs):
        self.store = OrderedDict()
        self.update(OrderedDict(*args, **kwargs))

    def __getitem__(self, key: KT1 | KT2) -> VT:
        k = self._key_transform(key)
        return self.store[k][1]

    @overload
    def get(self, key: KT1 | KT2) -> VT | None: ...

    @overload
    def get(self, key: KT1 | KT2, default: T) -> VT | T: ...

    def get(self, key, default=None):
        k = self._key_transform(key)
        if k in self.store:
            return self.store[k][1]
        else:
            return default

    def __setitem__(self, key: KT2, value: VT):
        self.store[self._key_transform(key)] = (key, value)

    def __delitem__(self, key: KT2):
        del self.store[self._key_transform(key)]

    def __iter__(self):
        for key in self.store:
            yield self.store[key][0]

    def __len__(self) -> int:
        return len(self.store)

    def tree_flatten(self):
        """The same flatten function used by JAX for `dict`"""
        return tuple(self.values()), tuple(self.keys())

    @classmethod
    def tree_unflatten(cls, keys, values):
        return cls(zip(keys, values))

    @abstractmethod
    def _key_transform(self, key: KT1 | KT2) -> KT1:
        # TODO: Subclass and implement the desired transformation.
        ...
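The pattern is easier to digest in miniature. Here's a dependency-free toy with a lower-casing transform standing in for the selector-to-string one; the real class above differs mainly in its OrderedDict storage and PyTree registration:

```python
from collections.abc import MutableMapping

class KeyTransformDict(MutableMapping):
    """Toy version: stores (original_key, value) under a transformed key."""

    def __init__(self):
        self._store = {}

    def _key_transform(self, key):
        return key.lower()  # stand-in for the selector-to-string transform

    def __getitem__(self, key):
        return self._store[self._key_transform(key)][1]

    def __setitem__(self, key, value):
        self._store[self._key_transform(key)] = (key, value)

    def __delitem__(self, key):
        del self._store[self._key_transform(key)]

    def __iter__(self):
        return (orig_key for orig_key, _ in self._store.values())

    def __len__(self):
        return len(self._store)

d = KeyTransformDict()
d["Alpha"] = 1
print(d["ALPHA"])  # 1 -- any key with the same transform hits the same entry
print(list(d))     # ['Alpha'] -- but the originally-supplied key is preserved
```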

Finally, here is the dict that can use accessor function lambdas as keys:

from collections.abc import Callable
from typing import Any, TypeVar

import equinox as eqx
import jax.tree as jt
from jax.tree_util import register_pytree_node_class
from jaxtyping import Array, PyTree


T = TypeVar('T')


def _where_to_str(where: Callable) -> str:
    """Return a single string representing an accessor function."""
    terms = where_func_to_attr_str_tree(where)
    if isinstance(terms, str):
        where_str = terms
    else:
        where_str = ", ".join(jt.leaves(terms))
    return where_str

@register_pytree_node_class
class WhereDict(AbstractTransformedOrderedDict[str, Callable[[PyTree], Any], T]):
    """An `OrderedDict` that allows use of accessor functions as keys.

    In particular, keys can be callables (functions/lambdas) that take a single argument,
    and return a PyTree of leaves, as visited by attribute/item accesses.

    Functions are parsed to strings, which can be used interchangeably as keys.
    For example, the following return the same value when `init_spec` is a `WhereDict`:

    > init_spec[lambda state: state.mechanics.effector]
    > init_spec['mechanics.effector']

    Finally, a `tuple[Callable, str]` may also be provided as a key, for cases where
    different unique entries must be included for the same callable. For example,
    the following are equivalent:

    > init_spec[(lambda state: state.mechanics.effector, "first")]
    > init_spec['mechanics.effector#first']

    Note that the hash symbol `#` is used as a delimiter in the string representation.

    ??? Note "Performance"
        For typical initialization mappings (1-10 items) construction is on the order
        of 50x slower than `OrderedDict`. Access is about 2-20x slower, depending
        whether indexed by string or by callable.

        However, we only need to do a single construction and a single access of
        init_spec per batch/evaluation, so performance shouldn't matter too much in
        practice: the added overhead is <50 us/batch, and a batch normally takes
        at least 20,000 us to train.
    """

    def _key_transform(self, key: str | Callable | tuple[Callable, str]) -> str:
        return self.key_transform(key)

    @staticmethod
    def key_transform(key: str | Callable | tuple[Callable, str]) -> str:

        if isinstance(key, str):
            pass
        elif isinstance(key, Callable):
            where_str = _where_to_str(key)
            return where_str
        elif isinstance(key, tuple):
            if not isinstance(key[0], Callable) or not isinstance(key[1], str):
                raise ValueError("Each `WhereDict` key should be supplied as a string, "
                                 "a callable, or a tuple of a callable and a string")
            where_str = _where_to_str(key[0])
            return '#'.join([where_str, key[1]])
        else:
            raise ValueError("Each `WhereDict` key should be supplied as a string, "
                             "a callable, or a tuple of a callable and a string")
        return key

    def __repr__(self):
        return eqx.tree_pformat(self)
  1. Sometimes I call these where-functions for short, since “selector” might be misinterpreted: note that a subtree selector returns a tree, not a structureless container. However, I also use “selector” when I think it is clear.
  2. That is, any object the builtin function hash can take as an input, returning an integer. (Hashes are not guaranteed to be unique, but equal objects must hash equally; the dict handles collisions separately.)