Parallelizing a model costs a lot of engineering effort.
DistIR aims to address these pain points.
To begin with, an abstraction of a parallelized program needs to represent two kinds of information: data sharding and computation sharding. At the current stage, DistIR only models data sharding and automatically infers simple computation sharding patterns.
DistIR uses DTensor, or distributed tensor, to represent how a tensor is sharded among devices. In DistIR, every tensor must be a DTensor instead of a Tensor.
R.DTensor(shape=(128, 128), dtype="float32", device_mesh=R.device_mesh((2, 2), I.Range(0, 4)), placement="R, S[0]")
This is the type annotation of a DTensor in Relax. It has 4 arguments:
shape
: The shape of the tensor (before sharding)
dtype
: The data type of the tensor
device_mesh
: The devices across which the tensor is sharded. A device mesh is defined with R.device_mesh(shape, device_ids)
. In the code example above, the device mesh has shape (2, 2), and the device ids are I.Range(0, 4), which stands for devices 0, 1, 2, 3.
placement
: an array with one entry per dimension of the device_mesh. placement[i] describes how the DTensor's data is distributed along the i-th dimension of the device_mesh. There are two possible values for placement[i]
Replica()
: The tensor (or a subregion of the tensor) is replicated along the i-th dimension of the device mesh.
Sharding(j)
: The j-th dimension of the tensor is sharded along the i-th dimension of the device mesh.
In the code example above, "R, S[0]" represents the placement array [Replica(), Sharding(0)]
, meaning the tensor is sharded along its row dimension into 2 parts: the upper part is replicated across devices 0 and 2, and the lower part is replicated across devices 1 and 3 (see the sketch below).
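To make this shard-to-device mapping concrete, here is a plain NumPy sketch (illustrative only, not DistIR code; the row-major mesh layout is an assumption made for this illustration) that enumerates which row shard each device holds under "R, S[0]":

```python
import numpy as np

# (2, 2) device mesh with ids 0..3, laid out row-major (an assumption
# for illustration; DistIR only specifies the mesh shape and id range).
mesh = np.arange(0, 4).reshape(2, 2)   # [[0, 1],
                                       #  [2, 3]]

tensor = np.arange(128 * 128, dtype="float32").reshape(128, 128)

# Placement "R, S[0]": mesh dim 0 is Replica(), mesh dim 1 is Sharding(0),
# so tensor dim 0 (rows) is split into mesh.shape[1] = 2 parts.
row_shards = np.split(tensor, mesh.shape[1], axis=0)

for i in range(mesh.shape[0]):       # mesh dim 0: Replica()
    for j in range(mesh.shape[1]):   # mesh dim 1: Sharding(0)
        lo, hi = j * 64, j * 64 + 63
        print(f"device {mesh[i, j]} holds row_shards[{j}] (rows {lo}..{hi})")
# device 0 holds row_shards[0] (rows 0..63)
# device 1 holds row_shards[1] (rows 64..127)
# device 2 holds row_shards[0] (rows 0..63)
# device 3 holds row_shards[1] (rows 64..127)
```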
DistIR can be considered a normal Relax program with additional sharding information attached to its tensors, so anything that can be done with a normal Relax IRModule can also be done with a DistIR IRModule. For example, we can implement a transformation pass, print TVMScript, parse from TVMScript, and so on.
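For instance, here is a minimal sketch of a TVMScript round trip (shown on a plain Relax module for brevity; the same utilities apply when the parameters are DTensors):

```python
import tvm
from tvm.script import ir as I, relax as R

@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"):
        y = R.add(x, x)
        return y

printed = Module.script()                   # print the module as TVMScript text
reparsed = tvm.script.from_source(printed)  # parse it back into an IRModule
assert tvm.ir.structural_equal(Module, reparsed)
```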
Suppose we have 2 devices and want to follow the Megatron-LM sharding strategy: we shard A along columns and shard B along rows.
Let the original expression be
Y = GeLU(XA), Z = YB
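Before writing the DistIR program, we can sanity-check the arithmetic of this strategy with a plain NumPy sketch (illustrative only; the tanh approximation of GeLU is an assumption, and any elementwise activation behaves the same here). Sharding A along columns lets each device compute a column slice of Y independently, because GeLU is elementwise; sharding B along rows makes each device produce a partial product Y_i B_i, and summing the partials, which is the final all-reduce, recovers Z:

```python
import numpy as np

def gelu(x):
    # tanh approximation of GeLU (an assumption for this sketch; the exact
    # erf-based form works identically for the sharding argument)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 8))
A = rng.standard_normal((8, 6))
B = rng.standard_normal((6, 4))

# Reference (unsharded): Y = GeLU(X A), Z = Y B
Z_ref = gelu(X @ A) @ B

# 2-way Megatron-LM style sharding: device i holds A_i (a column shard
# of A) and B_i (a row shard of B).
A_shards = np.split(A, 2, axis=1)   # shard A along columns
B_shards = np.split(B, 2, axis=0)   # shard B along rows
partials = [gelu(X @ A_i) @ B_i for A_i, B_i in zip(A_shards, B_shards)]

Z = partials[0] + partials[1]       # the final all-reduce (sum) over devices
assert np.allclose(Z, Z_ref)
```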