Flow Matching Part 2. Solving ODEs with `tinygrad`

Posted on Mon 20 January 2025 in posts

Introduction

Welcome to the second part of the flow matching series of posts.

The first major update is that the code is now live on GitHub: tinyflow.

Second, we extend our flow matching code with several ODE solvers.

In the previous post, we introduced flow matching in general terms. In this post, we will focus on solving ODEs with tinygrad, a library intended for deep learning but suitable for numerical and scientific computing as well.

This is especially useful because it lets us run on GPUs almost seamlessly.

To add more versatility, I added a small collection of solvers, such as the fourth-order Runge-Kutta (RK4) method and the Euler method. Compared with more mature libraries such as torchdiffeq, my approach is simpler and, for now, probably less numerically stable.
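The post does not show the shared `ODESolver` base class. The following is therefore a hypothetical sketch of how such a base class could drive any fixed-step solver from `t0` to `t1`. Subclasses only implement `step(h, t, x)`. The names `ODESolverSketch`, `EulerSketch` and `integrate` are illustrative and are not tinyflow's API. Plain Python floats stand in for tensors.

```python
from collections.abc import Callable


class ODESolverSketch:
    """Hypothetical stand-in for a solver base class (not tinyflow's actual code).

    Stores the right-hand side f(t, x) and marches equally spaced steps;
    concrete solvers only have to implement `step`.
    """

    def __init__(self, rhs_fn: Callable):
        self.rhs = rhs_fn

    def step(self, h, t, x):
        raise NotImplementedError

    def integrate(self, x0, t0: float, t1: float, n_steps: int):
        # Fixed step size h; each call to `step` advances the state by h.
        h = (t1 - t0) / n_steps
        x = x0
        for i in range(n_steps):
            x = self.step(h, t0 + i * h, x)
        return x


class EulerSketch(ODESolverSketch):
    def step(self, h, t, x):
        # Same keyword calling convention as the solvers in this post.
        return x + h * self.rhs(t=t, x=x)


# dx/dt = 2t from x(0) = 0 with four Euler steps; the exact answer x(1) is 1.
print(EulerSketch(lambda t, x: 2 * t).integrate(0.0, 0.0, 1.0, 4))  # 0.75
```

Keeping the time loop in the base class means that adding a new method only requires writing its single-step update.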

Implemented solvers

Runge-Kutta-4

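The first solver is the classical fourth-order Runge-Kutta (RK4) method. It is easy to code and is implemented in a backend-agnostic fashion: the `step` update uses only arithmetic operators, so it is not tied to a particular tensor library.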

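from collections.abc import Callable

# ODESolver and BaseNeuralNetwork are tinyflow classes that are not shown in this post.
class RK4(ODESolver):
    def __init__(self, rhs_fn: Callable | BaseNeuralNetwork):
        super().__init__(rhs_fn)

    def sample(self, h, t, rhs_prev):
        # Broadcast the scalar time to a (batch, 1) column, one entry per sample.
        t = t.reshape((1, 1))
        t = t.repeat(rhs_prev.shape[0], 1)
        return self.step(h, t, rhs_prev)

    def step(self, h, t, rhs_prev):
        # rhs_prev holds the current state x(t); four slope evaluations give x(t + h).
        k1 = self.rhs(t=t, x=rhs_prev)
        k2 = self.rhs(t=t + h / 2, x=rhs_prev + k1 * h / 2)
        k3 = self.rhs(t=t + h / 2, x=rhs_prev + k2 * h / 2)
        k4 = self.rhs(t=t + h, x=rhs_prev + k3 * h)
        # Weighted average of the slopes: (k1 + 2*k2 + 2*k3 + k4) / 6.
        return rhs_prev + h / 6 * (k1 + k2 * 2 + k3 * 2 + k4)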
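Because the update uses only `+`, `*` and `/`, the same formula runs unchanged on plain Python floats and complex numbers. The standalone sketch below is not the tinyflow class. It copies the RK4 update into a free function `rk4_step` and uses it for two checks. First, halving the step size on dx/dt = -x should shrink the error by roughly 2^4 = 16, as expected for a fourth-order method. Second, on dz/dt = iz the complex state should return close to 1 after one full turn.

```python
import math


def rk4_step(f, h, t, x):
    """One classical RK4 step, mirroring the update in RK4.step above."""
    k1 = f(t, x)
    k2 = f(t + h / 2, x + k1 * h / 2)
    k3 = f(t + h / 2, x + k2 * h / 2)
    k4 = f(t + h, x + k3 * h)
    return x + h / 6 * (k1 + k2 * 2 + k3 * 2 + k4)


def solve(f, x0, t0, t1, n_steps):
    """March n_steps fixed RK4 steps from t0 to t1."""
    h = (t1 - t0) / n_steps
    x = x0
    for i in range(n_steps):
        x = rk4_step(f, h, t0 + i * h, x)
    return x


def decay(t, x):
    # dx/dt = -x, exact solution x(1) = exp(-1) for x(0) = 1.
    return -x


def rotate(t, z):
    # dz/dt = i z traces the unit circle and returns to 1 at t = 2*pi.
    return 1j * z


err_10 = abs(solve(decay, 1.0, 0.0, 1.0, 10) - math.exp(-1))
err_20 = abs(solve(decay, 1.0, 0.0, 1.0, 20) - math.exp(-1))
print(f"error ratio when halving h: {err_10 / err_20:.1f}")  # roughly 16

z_end = solve(rotate, 1 + 0j, 0.0, 2 * math.pi, 100)
print(f"|z(2 pi) - 1| = {abs(z_end - 1):.1e}")
```

RK4 costs four right-hand-side evaluations per step. With a neural network as the right-hand side, that means four forward passes per step.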

Euler Method

The second solver is the classical (explicit) Euler method. It is easy to code and seems like a natural one to include in any ODE-solving library, although this package does not aim to be a comprehensive ODE solver.

class Euler(ODESolver):
    def __init__(self, rhs_fn: Callable):
        super().__init__(rhs_fn)

    def step(self, h, t, rhs_prev):
        # Explicit Euler: follow the slope at the start of the step, x(t + h) ≈ x(t) + h * f(t, x(t)).
        return rhs_prev + h * self.rhs(t=t, x=rhs_prev)
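Euler's simplicity comes at a price: it is only first-order accurate. The standalone sketch below (plain floats, not the tinyflow class) makes that visible on dx/dt = -x. Each time the number of steps doubles, the error at t = 1 only roughly halves.

```python
import math


def euler_solve(f, x0, t0, t1, n_steps):
    """Fixed-step explicit Euler: x_{n+1} = x_n + h * f(t_n, x_n)."""
    h = (t1 - t0) / n_steps
    x = x0
    for i in range(n_steps):
        x = x + h * f(t0 + i * h, x)
    return x


def decay(t, x):
    # dx/dt = -x, exact solution x(1) = exp(-1) for x(0) = 1.
    return -x


exact = math.exp(-1)
errors = {n: abs(euler_solve(decay, 1.0, 0.0, 1.0, n) - exact) for n in (10, 20, 40)}
for n, err in errors.items():
    print(f"n_steps={n:3d}  error={err:.2e}")
```

To reach the accuracy of a few RK4 steps, Euler typically needs many more steps. Even though each Euler step evaluates the right-hand side only once, it can end up more expensive overall.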

Midpoint Method

Finally, let's include the midpoint method, since we used it in the previous post:

class MidpointSolver(ODESolver):
    def __init__(self, rhs_fn: Callable):
        super().__init__(rhs_fn)

    def step(self, h, t, rhs_prev):
        # Half Euler step to t + h/2, then take the full step with the slope evaluated there.
        return rhs_prev + h * self.rhs(
            t=t + h / 2, x=rhs_prev + h / 2 * self.rhs(t=t, x=rhs_prev)
        )

    def sample(self, h, t, rhs_prev):
        # Broadcast the scalar time to a (batch, 1) column, one entry per sample.
        t = t.reshape((1, 1))
        t = t.repeat(rhs_prev.shape[0], 1)
        return self.step(h, t, rhs_prev)
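The midpoint method sits between the other two. It evaluates the right-hand side twice per step and is second-order accurate, so halving the step size should cut the error by roughly 2^2 = 4. The standalone float sketch below reuses the update from `MidpointSolver.step` on the same dx/dt = -x test problem. It is an illustration, not the tinyflow class.

```python
import math


def midpoint_solve(f, x0, t0, t1, n_steps):
    """Fixed-step explicit midpoint rule, same update as MidpointSolver.step."""
    h = (t1 - t0) / n_steps
    x = x0
    for i in range(n_steps):
        t = t0 + i * h
        x_half = x + h / 2 * f(t, x)        # half Euler step to the midpoint
        x = x + h * f(t + h / 2, x_half)    # full step using the midpoint slope
    return x


def decay(t, x):
    # dx/dt = -x, exact solution x(1) = exp(-1) for x(0) = 1.
    return -x


exact = math.exp(-1)
errors = {n: abs(midpoint_solve(decay, 1.0, 0.0, 1.0, n) - exact) for n in (10, 20)}
ratio = errors[10] / errors[20]
print(f"error ratio when halving h: {ratio:.1f}")  # roughly 4
```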

Other modifications and future steps

Trainer Class

In my opinion, a training script belongs inside a dedicated training class that can hold any relevant methods and supports callbacks. So I added a trainer class along these lines; it is still very much a work in progress.
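The tinyflow trainer is not shown here, so the following is only a hypothetical sketch of the callback design described above. The names `TrainerSketch`, `step_fn` and `fit` are illustrative. The class owns the loop, delegates each update to a user-supplied function, and calls every registered callback after each step.

```python
from collections.abc import Callable
from dataclasses import dataclass, field


@dataclass
class TrainerSketch:
    """Hypothetical trainer: owns the loop, delegates each update to `step_fn`."""

    step_fn: Callable[[int], float]  # runs one optimisation step and returns the loss
    callbacks: list[Callable[[int, float], None]] = field(default_factory=list)

    def fit(self, n_steps: int) -> list[float]:
        losses = []
        for step in range(n_steps):
            loss = self.step_fn(step)
            losses.append(loss)
            # Every callback sees (step, loss), e.g. for logging or checkpointing.
            for callback in self.callbacks:
                callback(step, loss)
        return losses


# Usage with a dummy step function whose "loss" decays as 1 / (step + 1).
logged = []
trainer = TrainerSketch(
    step_fn=lambda step: 1.0 / (step + 1),
    callbacks=[lambda step, loss: logged.append((step, loss))],
)
losses = trainer.fit(3)
```

With this split, new behaviour such as logging or checkpointing can be added by registering a callback rather than by editing the loop itself.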

Path

A big part of flow matching is the choice of path between distributions. I am currently working on several path classes, but the basic path used in the previous post is already included.
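This section does not restate the previous post's path. As an assumption for illustration, the sketch below uses the simplest common choice, the straight-line interpolation x_t = (1 - t) x0 + t x1. Here x0 is a noise sample and x1 a data sample, which is a common convention assumed in this sketch. Along this path the velocity is the constant x1 - x0, and that constant serves as the regression target for the network. The function name `linear_path_sample` is hypothetical, and plain lists stand in for tensors.

```python
def linear_path_sample(x0: list[float], x1: list[float], t: float):
    """Hypothetical helper: point on x_t = (1 - t) * x0 + t * x1 and its velocity.

    The velocity dx_t/dt = x1 - x0 does not depend on t for this path.
    """
    x_t = [(1 - t) * a + t * b for a, b in zip(x0, x1)]
    u_t = [b - a for a, b in zip(x0, x1)]
    return x_t, u_t


# One "noise" point, one "data" point, a quarter of the way along the path.
x_t, u_t = linear_path_sample([0.0, 2.0], [4.0, -2.0], 0.25)
print(x_t, u_t)  # [1.0, 1.0] [4.0, -4.0]
```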