Motivation Recap:

Parallelizing a model costs a lot of engineering effort:

DistIR aims to address these pain points.

Basic Concepts of DistIR: DTensor

To begin with, an abstraction of a parallelized program needs to represent two kinds of information: data sharding and computation sharding. At the current stage, DistIR only models data sharding explicitly and automatically infers simple computation-sharding patterns.

DistIR uses DTensor, or distributed tensor, to represent how a tensor is sharded among devices. In DistIR, all tensors must be DTensors instead of plain Tensors.

R.DTensor(shape=(128, 128), dtype="float32", device_mesh=R.device_mesh((2, 2), I.Range(0, 4)), placement="R, S[0]")

This is the type annotation of a DTensor in Relax. It has 4 arguments: the shape, the dtype, the device_mesh describing the logical organization of devices (here a 2x2 mesh over devices 0-3), and the placement string, which gives one specification per device-mesh dimension: "R" means the tensor is replicated along that mesh dimension, and "S[i]" means it is sharded along tensor dimension i.
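To make the placement concrete, here is a small NumPy sketch (not part of DistIR; the device numbering and my reading of the placement string are assumptions) of which slice each of the 4 devices would hold for a (2, 2) mesh with placement "R, S[0]": replicated along the first mesh axis, sharded along tensor dimension 0 across the second mesh axis.

```python
import numpy as np

full = np.random.rand(128, 128).astype("float32")
mesh = np.arange(4).reshape(2, 2)           # device mesh (2, 2) over devices 0..3

shards = {}
for i in range(2):                          # first mesh axis:  "R"    -> replicate
    for j in range(2):                      # second mesh axis: "S[0]" -> shard tensor dim 0
        rows = slice(j * 64, (j + 1) * 64)  # 128 rows split into 2 chunks of 64
        shards[mesh[i, j]] = full[rows, :]

# Devices 0 and 2 hold rows 0..63; devices 1 and 3 hold rows 64..127.
```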

DistIR can be considered a normal Relax program with additional sharding information attached to its tensors, so anything that can be done with a normal Relax IRModule can also be done with a DistIR IRModule. For example, we can implement transformation passes, print TVMScript, parse from TVMScript, and so on.
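For instance, assuming `mod` is a DistIR IRModule we have already built (hypothetical variable), the ordinary IRModule utilities apply; whether a particular existing pass understands DTensor types is a separate question.

```python
from tvm import relax

# Print the module back as TVMScript.
mod.show()

# Get the TVMScript text, which can be parsed back into an IRModule.
text = mod.script()

# Apply an ordinary Relax transformation pass, just as on a
# non-distributed module.
mod = relax.transform.DeadCodeElimination()(mod)
```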

Suppose we have 2 devices and want to follow the Megatron-LM-style sharding strategy: we shard A along its columns and B along its rows.

Let the original expression be

Y = GeLU(XA), Z = YB
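Below is a sketch of how this could be written as a DistIR function, reusing the annotation syntax above with a 1-D mesh of 2 devices. The shapes, the function name, and the output placement are illustrative assumptions; the intent is that DistIR infers the all-reduce needed to make Z replicated.

```python
from tvm.script import ir as I
from tvm.script import relax as R

@I.ir_module
class MegatronMLP:
    @R.function
    def mlp(
        # X is replicated on both devices.
        x: R.DTensor(shape=(128, 128), dtype="float32",
                     device_mesh=R.device_mesh((2,), I.Range(0, 2)), placement="R"),
        # A is sharded along its columns (tensor dimension 1).
        a: R.DTensor(shape=(128, 256), dtype="float32",
                     device_mesh=R.device_mesh((2,), I.Range(0, 2)), placement="S[1]"),
        # B is sharded along its rows (tensor dimension 0).
        b: R.DTensor(shape=(256, 128), dtype="float32",
                     device_mesh=R.device_mesh((2,), I.Range(0, 2)), placement="S[0]"),
    ) -> R.DTensor(shape=(128, 128), dtype="float32",
                   device_mesh=R.device_mesh((2,), I.Range(0, 2)), placement="R"):
        # XA is column-sharded; GeLU is elementwise, so Y = GeLU(XA) stays
        # column-sharded with no communication.
        xa = R.matmul(x, a)
        y = R.nn.gelu(xa)
        # Z = YB: each device produces a partial sum of Z; an all-reduce
        # (inferred by DistIR) makes the result replicated again.
        z = R.matmul(y, b)
        return z
```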