The Basic Knowledge of Torch Train Pipeline

4 minute read

Update:

Published:

这篇博客主要讲解 PyTorch 训练模型的整个流程的具体细节, 包括如何在前向过程中构建计算图;后向传播过程中如何计算并保存梯度;优化器如何根据梯度更新模型参数。(建议先阅读我之前关于 torch.autograd 的博客 The Basic Knowledge of PyTorch Autograd )

Torch 训练的整体流程

我们以最简单的乘法为例:两个标量 $x_1$ 和 $x_2$ 相乘得到 $v$;然后使用v.backward()函数反向计算 $x_1$ 和 $x_2$ 的梯度;最后使用 SGD 优化器更新 $x_1$ 和 $x_2$。代码如下:

simple torch mul pipeline

接着我们使用torchvizmake_dot函数获取 PyTorch 构建的计算图:

simple torch DAG

可以看到,计算图的方向与前向计算过程刚好相反。这里,我们将简单描述 Torch 训练的整体流程:在执行乘法过程中,Torch 分别为 $x_1$ 和 $x_2$ 构建一个AccumulateGrad节点,并将 $x_1$ / $x_2$ 存储在对应的AccumulateGrad节点的variable属性中;然后根据*的乘法操作为 $v$ 构建一个MulBackwrad0节点,并将其存储在 $v$ 的grad_fn属性中。

而在后向传播计算梯度的过程中,执行v.backward()函数时,Torch 首先会获取到存储在 $v$ 的grad_fn属性中MulBackwrad0节点,然后将初始梯度gradient作为输入传递给其.backward()函数计算该节点的输入的梯度,即 $x_1$ 和 $x_2$ 的梯度;接着将 $x_1$ 和 $x_2$ 的梯度作为输入传递给各自对应的AccumulateGrad节点的.backward()函数实现将梯度累加到 $x_1$ 和 $x_2$ 的.grad属性中。

在 SGD 优化器更新 $x_1$ 和 $x_2$ 的过程中,SGD 的step()函数遍历初始化时传入的params参数,判断其required_grad属性是否为True,若为True,则取出其data属性和grad属性,将data减去grad,得到更新后的参数params

前向过程构建计算图

介绍基本知识:Node $\rightarrow$ Edge $\rightarrow$ MulBackward0

叙述过程:Tensor.mul $\rightarrow$ torch._C._TensorBase.__mul__ $\rightarrow$ mul_Tensor $\rightarrow$ collect_next_edges $\rightarrow$ gradient_edge $\rightarrow$ set_next_edges $\rightarrow$ set_history $\rightarrow$ set_gradient_edge

不知道你们有没有这样的疑惑:在我们的代码中,只是简单的编写了两个tensor的矩阵相乘:tensor = tensor1 * tensor2;而 PyTorch 便自动为我们构建了一个计算图(可以看到tensor.grad_fn属性为MulBackward0;如果tensor1 / tensor2.required_grad属性为True)。这是如何实现的?虽然我们在前面的博客 The Basic Knowledge of PyTorch Autograd 中讲了关于 PyTorch 自动求导的过程,知道了每个节点是在初等函数执行时立即创建的,但并没有涉及到具体的代码对应过程(即在tensor = tensor1 * tensor2背后究竟是哪些代码实现了计算图的创建)。实际上,PyTorch 在Tensor类中实现了对每个初等函数的重载,使得每个初等函数操作并不只是简单的实现初等函数而已。例如对于mul操作,Tensor类内的重载实现为:

tensor mul operation

可以看到,其内部实现是使用 C++ 语言来编写的,继续追溯到 C++ 源代码中,可以看到mul操作的具体实现为:

tensor mul operation in C++

这个代码有点吓人,让我们一步步来。其中,self, other分别是mul操作的第一个tensor和第二个tensor。首先,第 $4$ 行代码的compute_requires_grad()函数判断self/otherrequired_grad属性是否为True,只要有一个为True,则_any_requires_gradTrue,表示此时的mul操作需要生成节点,同时其生成的输出的required_grad也为True。在得到_any_requires_gradTrue后(第 $6$ 行代码),代码会创建一个MulBackward0作为该mul操作在计算图上的节点(第 $8$ 行代码),同时将其赋值给grad_fn;而set_next_edges()则是设置当前的MulBackward0节点与之前操作生成的节点的连接。

接下来,让我们继续深入每个部分。首先,self/other是一个Tensor,当设置其required_grad的属性为 True 时,会执行下面的set_requires_grad()函数:

set requires grad

其会为self/other创建一个新的属性autograd_meta_(AutogradMeta类),该属性用于存储self/other的梯度(grad_)和节点(grad_fn_)等),对应于 Python 代码里的.gradgrad_fn属性。(当然其还有梯度累加器(grad_accumulator_用于累加多个父节点传递的梯度)

其次,计算图的每个节点的类型均为Node结构体(对应于 Python 代码中的Function类)。下图是Node结构体的具体内容:

Node structure

其中,operator()apply()分别是节点的前向和反向计算函数(对应于 Python 代码中的forward()backward()函数),不同的节点可以重写它们以实现不同的计算过程。而next_edges_则是存储节点所连接的前向节点(对应于 Python 代码中的next_functions)。因此,next_edges_中的每条边都是Edge结构体,结构体中存储执行前向节点的指针。下图是Edge结构体的具体内容:

edge structure

所以,mul_Tensor()函数中的MulBackward0操作即是Node结构体的子结构体,其主要重写了apply()方法用于计算mul操作的反向过程(其没有重写operator()方法,因为mul操作的前向过程在mul_Tensor()函数中实现),如下图所示:

mulbackward structure

了解了各个变量的基本结构后,我们回到mul_Tensor()函数中。可以猜到,set_next_edges()应该是要将之前操作生成的节点赋值到当前的MulBackward0节点的next_edges_中。首先需要获取之前操作生成的节点,通过collect_next_edges()函数实现。如下图所示:

collect next edges

这个代码更吓人,还是让我们一步步来!首先,collect_next_edges()函数是通过输入mul操作的输入数据,即self, other;然后创建MakeNextFunctionList结构体的实例make,并调用其apply()方法(即MakeNextFunctionListoperator()方法)实现的获取之前操作生成的节点。而MakeNextFunctionListoperator()方法同样输入mul操作的输入数据,然后构建next_edges数组,接着通过调用gradient_edge()方法获取每个输入数据里保存的之前操作生成的节点(使用Edge结构体包装),并将其存储在next_edges数组中,最后将next_edges数组返回给collect_next_edges()函数。而gradient_edge()方法输入mul操作的输入数据,判断其是否保存的之前操作生成的节点gradient = self.grad_fn():若有,则说明该输入数据属于中间数据,则将其包装成Edge结构体后返回;若没有,则说明该输入数据属于最原始的输入数据,则将其保存的节点设置为AccumulateBackward节点(通过调用grad_accumulator()函数获得),并其包装成Edge结构体后返回(这就是为什么每个叶子节点前都有一个AccumulateBackward节点的原因)。从gradient_edge()方法返回到MakeNextFunctionListoperator()方法,再返回到collect_next_edges()函数,即可得到当前的MulBackward0节点的之前操作生成的节点。然后通过set_next_edges()将其赋值到next_edges_中。如下图所示:

set next edges

完成了在set_next_edges()后,接下来便需要计算前向过程(对应于 Python 代码中的Function类的forward()方法),获得计算结果result(mul_Tensor()函数中的第 $15 \sim 20$ 行)。

最后,需要将生成的MulBackward0节点保存到输出result中(对应 Python 代码的outputs.grad_fn = now_fn),其通过set_history()函数实现。如下图所示:

set_history_c++

首先,set_history()函数是通过输入前向过程的输出result和生成的MulBackward0节点,然后调用set_gradient_edge()方法实现的将MulBackward0节点保存在输出resultAutogradMeta属性的grad_fn_中。而set_gradient_edge()方法则是通过输入同样的前向过程的输出result和生成的MulBackward0节点,取出resultAutogradMeta属性meta,将MulBackward0赋值在其grad_fn_属性中。

至此,我们终于“稍微”搞懂了 PyTorch 自动化构建计算图的过程。原来在我们写了一个简单的tensor = tensor1 * tensor2代码背后,PyTorch 执行了如此多的额外代码操作来实现计算图的构建。

后向传播过程计算并保存梯度

敬请期待!

优化器根据梯度更新模型参数

不同于前向过程和后向过程,其代码需要深入到底层的 C++ 源代码进行理解,优化器利用计算得到的梯度更新模型参数的过程主要在 Python 源代码中实现。现在让我们以最简单的 SGD 优化器为例:首先我们需要初始化一个 SGD 优化器实例,它至少需要输入两个参数(模型参数params(即 $x_1$ 和 $x_2$)和初始学习率lr),如下图所示:

SGD optimizer

在经过前向过程(v = x1 * x2)和后向过程(v.backward())后,此时 $x_1$ 和 $x_2$ 的grad属性内已经存储了计算得到的梯度。因此,我们能想到的最直接的做法就是遍历params的每一个参数,判断每个参数的required_grad属性是否为True;若是,则取出其对应的grad属性内存储的梯度,并将该参数与其梯度(乘以学习率)进行相减即可实现参数更新。因此 SGD 类的简单实现应该如下图所示:

SDG class

但是这里有个问题,前面我们说过,PyTorch 重载了Tensor类的所有初等函数操作;因此,当我们执行param -= grad * self.lr操作时,我们实际上会在原有计算图的基础上再构建一个SubBackward0节点分支,如下图所示:

simple SGD problem

因此,为了不让 PyTorch 继续构建计算图,我们需要设置with torch.no_grad()来“告诉” PyTorch 下面的操作不需要构建计算图,此时Tensor类的所有初等函数操作就不会构建计算图。因此,改进的SGD 类代码如下图所示:

advance SGD class

而在 SGD 的源代码中,PyTorch 使用另一种方式来避免计算图的构建,通过使用torch._dynamo.graph_break()实现计算图的脱离来确保初等函数操作就不会继续构建计算图。

了解了如何简单实现 SGD 后,接下来让我们进入 SGD 的源代码来验证我们的实现是否正确。首先是 SGD 如何保存输入进来的params参数,下图为 SGD 的__init__()函数部分代码:

SGD source code init

可以看到,SGD 是通过调用其父类Optimizer__init__()函数将输入进来的参数保存在self.param_groups列表内。接下来就是 SGD 的step()函数,下图为 SGD 的step()函数部分代码:

SGD source code step

首先,step()函数对每个self.param_groups列表内的每个参数组group,调用self._init_group()判断其每个参数pgrad属性是否为None:如果不是,则表示需要更新该参数,则将其存储在params_with_grad列表中,同时使用d_p_list列表存储其对应的梯度p.grad

接着,对于那些需要更新的参数params_with_grad,调用sgd()函数进行参数更新。在sgd()函数中,进行一系列的检查后,调用_single_tensor_sgd()函数进行参数更新。 而_single_tensor_sgd()函数则是遍历params_with_grad列表中的所有参数,对于每个参数param列表,取出其在d_p_list列表中的对应的梯度d_p,并使用原地更新的方式进行参数更新:param.add_(d_p, alpha=-lr)。由于是原地更新,且传入优化器的参照即为模型参数,因此对应的模型中的参数也会同步进行更新。

至此,我们终于完成了 PyTorch 训练模型的整个流程的具体细节(深入到底层代码),包括如何在前向过程中构建计算图;后向传播过程中如何计算并保存梯度;优化器如何根据梯度更新模型参数。(灰底白字部分表示尚未完成部分🤪)

Reference

  1. Understanding PyTorch with an example: a step-by-step tutorial

  2. The SGD source code of PyTorch

  3. A lightweight package to create visualizations of PyTorch execution graphs and traces

  4. Overview of PyTorch Autograd Engine

  5. How Computational Graphs are Constructed in PyTorch

  6. How Computational Graphs are Executed in PyTorch

  7. PyTorch internals

  8. Ultimate guide to PyTorch Optimizers

  9. torch.optim

  10. What is a PyTorch optimizer?