Flow Matching Part 2. Solving ODEs with `tinygrad`
Posted on Mon 20 January 2025 in posts
Introduction
Welcome to the second part of the flow matching series of posts.
The first major update is that the code is live on GitHub: tinyflow.
Second, let's extend our flow matching setup with different ODE solvers.
In the previous post, we introduced flow matching in general terms. In this post, we will focus on solving ODEs with tinygrad, a library intended for deep learning but suitable for numerical and scientific computing as well. This is especially useful since we can utilize GPUs almost seamlessly.
To add more versatility, I added a simple collection of solvers such as the Runge-Kutta-4 and Euler method solvers. Unlike more mature libraries such as torchdiffeq, my approach is simpler and probably less numerically stable for now.
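Each solver subclasses a small shared base class. As a rough sketch of what that base is assumed to look like (the actual ODESolver in tinyflow may store more state or differ in method signatures), it simply keeps the right-hand-side function and leaves the stepping to the subclasses:

from typing import Callable

class ODESolver:
    # Minimal sketch of the shared base class; the tinyflow version may differ.
    def __init__(self, rhs_fn: Callable):
        # rhs_fn evaluates the vector field: rhs_fn(t=..., x=...) -> dx/dt
        self.rhs = rhs_fn

    def step(self, h, t, rhs_prev):
        # One integration step of size h, implemented by each solver below.
        raise NotImplementedError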
Implemented solvers
Runge-Kutta-4
The first solver is the classical Runge-Kutta-4 solver. It is easy to code and is implemented in a backend-agnostic fashion.
class RK4(ODESolver):
    def __init__(self, rhs_fn: Callable | BaseNeuralNetwork):
        super().__init__(rhs_fn)

    def sample(self, h, t, rhs_prev):
        t = t.reshape((1, 1))
        t = t.repeat(rhs_prev.shape[0], 1)
        return self.step(h, t, rhs_prev)

    def step(self, h, t, rhs_prev):
        k1 = self.rhs(t=t, x=rhs_prev)
        k2 = self.rhs(t=t + h / 2, x=rhs_prev + k1 * h / 2)
        k3 = self.rhs(t=t + h / 2, x=rhs_prev + k2 * h / 2)
        k4 = self.rhs(t=t + h, x=rhs_prev + k3 * h)
        return rhs_prev + h / 6 * (k1 + k2 * 2 + k3 * 2 + k4)
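A quick way to sanity-check the step is to integrate a problem with a known solution. The snippet below is an illustrative sketch rather than the tinyflow test suite: it integrates dx/dt = -x from t = 0 to t = 1 with tinygrad tensors, which should land close to exp(-1) ≈ 0.3679.

from tinygrad import Tensor

# RK4 as defined above (import path omitted).
def rhs(t, x):
    return -x  # simple linear test problem with exact solution exp(-t)

solver = RK4(rhs)
h, t, x = 0.1, Tensor([[0.0]]), Tensor([[1.0]])
for _ in range(10):        # ten steps from t=0 to t=1
    x = solver.step(h, t, x)
    t = t + h

print(x.item())            # ~0.36788, very close to exp(-1)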
Euler Method
The second solver is the classical Euler method. It is easy to code and seems like a good one to include in any ODE-solving library (although this package does not aim to be a comprehensive ODE solver).
class Euler(ODESolver):
    def __init__(self, rhs_fn: Callable):
        super().__init__(rhs_fn)

    def step(self, h, t, rhs_prev):
        return rhs_prev + h * self.rhs(t=t, x=rhs_prev)
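On the same test problem as above, the explicit Euler method with the same step size is noticeably less accurate, which is the expected trade-off for its simplicity (again, an illustrative sketch only).

from tinygrad import Tensor

# Euler as defined above, on dx/dt = -x with h = 0.1.
solver = Euler(lambda t, x: -x)
h, t, x = 0.1, Tensor([[0.0]]), Tensor([[1.0]])
for _ in range(10):
    x = solver.step(h, t, x)
    t = t + h

print(x.item())  # ~0.3487 vs. the exact exp(-1) ≈ 0.3679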
Midpoint Method
Finally, let's include the midpoint method since we used it last time:
class MidpointSolver(ODESolver):
    def __init__(self, rhs_fn: Callable):
        super().__init__(rhs_fn)

    def step(self, h, t, rhs_prev):
        return rhs_prev + h * self.rhs(
            t=t + h / 2, x=rhs_prev + h / 2 * self.rhs(t=t, x=rhs_prev)
        )

    def sample(self, h, t, rhs_prev):
        t = t.reshape((1, 1))
        t = t.repeat(rhs_prev.shape[0], 1)
        return self.step(h, t, rhs_prev)
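To see how these solvers are used for flow matching itself, the sketch below pushes a batch of noise samples through a learned velocity field by integrating from t = 0 to t = 1. Here velocity_model is a hypothetical stand-in for a trained network with the rhs_fn(t=..., x=...) calling convention; the actual sampling code in tinyflow may look different.

from tinygrad import Tensor

n_steps = 100
h = 1.0 / n_steps

# velocity_model is a hypothetical trained network predicting dx/dt.
solver = MidpointSolver(velocity_model)

x = Tensor.randn(64, 2)           # 64 samples from the base distribution
for i in range(n_steps):
    t = Tensor([i * h])           # scalar time, broadcast inside sample()
    x = solver.sample(h, t, x)    # one midpoint step for the whole batch

samples = x.numpy()               # approximate draws from the target distribution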
Other modifications and future steps
Trainer Class
In my opinion, a training script belongs inside a dedicated training class that can contain any relevant methods and allows for callbacks. So I added a trainer class that does something like that, but it is still very much a work in progress.
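As a rough illustration of the idea rather than the actual tinyflow class, a trainer with callback hooks could look something like this (model, loss_fn, optimizer, and data_loader are all hypothetical placeholders):

from typing import Callable
from tinygrad import Tensor

class Trainer:
    # Sketch only; the work-in-progress class in tinyflow may differ.
    def __init__(self, model, loss_fn: Callable, optimizer, callbacks=None):
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.callbacks = callbacks or []   # e.g. logging or checkpointing hooks

    def fit(self, data_loader, epochs: int):
        Tensor.training = True             # tinygrad optimizers expect training mode
        for epoch in range(epochs):
            for batch in data_loader:
                loss = self.loss_fn(self.model, batch)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            for cb in self.callbacks:      # callbacks run once per epoch
                cb(epoch=epoch, model=self.model, loss=loss)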
Path
A big part of flow matching is the actual path between distributions. I am currently working on multiple path classes, but the basic path used in the previous post is already added.
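For reference, the basic path from the previous post is assumed here to be the standard linear interpolation between a base sample x0 and a data sample x1; a sketch of such a path class (class and method names are mine, not necessarily tinyflow's):

class LinearPath:
    # Sketch of the basic conditional path x_t = (1 - t) * x0 + t * x1;
    # names and interface are illustrative.
    def interpolate(self, t, x0, x1):
        # Point on the straight-line path at time t (t broadcasts over the batch).
        return (1 - t) * x0 + t * x1

    def target_velocity(self, t, x0, x1):
        # The conditional velocity the network is regressed onto; for the
        # linear path it is simply the displacement x1 - x0.
        return x1 - x0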