Simple flow:

Relax → DistIR (global view) → DistIR (local view) → Disco runtime

Advanced flow:

Relax → DistIR (global view) → DistIR (intermediate view) → DistIR (local view) → Disco runtime

Example: MLP

Relax

@I.ir_module
class MLP:
    # 10 logical devices are declared; the meshes below use only a subset
    # (devices 0-1 for mesh[0], device 4 for mesh[1]).
    I.module_attrs({"device_num": 10})
    I.module_global_infos(
        {
            "mesh": [
                R.device_mesh((2,), I.Range(0, 2)),  # mesh[0]
                R.device_mesh((1,), I.Range(4, 5)),  # mesh[1]
            ]
        }
    )

    @R.function
    def foo(
        x: R.Tensor((128, 128), "float32"),
        weight1: R.Tensor((128, 128), "float32"),
        weight2: R.Tensor((128, 128), "float32"),
    ) -> R.Tensor((128, 128), "float32"):
        # Two-layer MLP: matmul -> gelu -> matmul.
        lv0 = R.matmul(x, weight1)
        lv1 = R.nn.gelu(lv0)
        # User-supplied sharding hint: place lv1 on mesh[0], sharded along
        # axis 1 ("S[1]"). The sharding-propagation pass infers the specs of
        # all other tensors from this single annotation.
        lv2 = R.dist.annotate_sharding(lv1, device_mesh="mesh[0]", placement="S[1]")
        lv3 = R.matmul(lv2, weight2)
        return lv3

After propagate sharding spec, becomes DistIR (global view TensorIR)

DistIR means the computation graph is built on distributed tensors

Global view means every TensorIR function represents the computation performed by the whole device mesh

# global view - every op represents the computation performed by the whole device mesh

# Global-view DistIR after sharding propagation. Illustrative TVMScript:
# NOTE(review): the T.allreduce pseudo-op below sketches the intent and may
# not match the exact upstream TVM API — confirm against tvm.script docs.
@I.ir_module
class Module:
    I.module_attrs({"device_num": 10})
    I.module_global_infos({"mesh": [R.device_mesh((2,), I.Range(0, 2)), R.device_mesh((1,), I.Range(4, 5))]})

    @T.prim_func
    def matmul(A: T.Buffer((T.int64(128), T.int64(128)), "float32"), B: T.Buffer((T.int64(128), T.int64(128)), "float32"), matmul_1: T.Buffer((T.int64(128), T.int64(128)), "float32")):
        # Global-view matmul: full 128x128x128 computation over the mesh.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(128)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(A[v_i0, v_k], B[v_k, v_i1])
                T.writes(matmul_1[v_i0, v_i1])
                with T.init():
                    matmul_1[v_i0, v_i1] = T.float32(0)
                matmul_1[v_i0, v_i1] = matmul_1[v_i0, v_i1] + A[v_i0, v_k] * B[v_k, v_i1]

    @T.prim_func
    # FIX: B's shape annotation was missing a closing parenthesis:
    # `T.Buffer((T.int64(64), T.int64(128), "float32")` could not parse.
    def matmul2(A: T.Buffer((T.int64(128), T.int64(64)), "float32"), B: T.Buffer((T.int64(64), T.int64(128)), "float32"), matmul_1: T.Buffer((T.int64(128), T.int64(128)), "float32")):
        # Per-shard matmul (k reduced only over the local 64 slice) followed
        # by a cross-device allreduce to combine partial sums.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(64)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(A[v_i0, v_k], B[v_k, v_i1])
                T.writes(matmul_1[v_i0, v_i1])
                with T.init():
                    matmul_1[v_i0, v_i1] = T.float32(0)
                matmul_1[v_i0, v_i1] = matmul_1[v_i0, v_i1] + A[v_i0, v_k] * B[v_k, v_i1]
        # FIX: the all_reduce block belongs after the reduction loop; it was
        # tab-indented inside the matmul block and used invalid call syntax.
        with T.block("all_reduce"):
            # pseudo-op: in-place allreduce of the partial sums over mesh[0]
            T.evaluate(T.allreduce(src_buffer=matmul_1.data, dst_buffer=matmul_1.data, size=128 * 128, mesh="mesh[0]", dims=[0]))

    @R.function
    def foo(x: R.DTensor((128, 128), "float32", "mesh[0]", "R"), weight1: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]"), weight2: R.DTensor((128, 128), "float32", "mesh[0]", "S[0]")) -> R.DTensor((128, 128), "float32", "mesh[0]", "R"):
        cls = Module
        lv0 = R.dist.call_tir(cls.matmul, (x, weight1), out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "S[1]"))
        lv1: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = R.nn.gelu(lv0)
        lv2: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = lv1
        # FIX typo: `call_tir_local_vi` -> `call_tir_local_view`, matching the
        # operator name used in the local-view module below.
        lv3 = R.dist.call_tir_local_view(cls.matmul2, (lv2, weight2), out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "R"))
        return lv3

Lower from global view TensorIR to local view TensorIR:

Local view means that each TensorIR function represents the computation performed by a single device node

# Local-view DistIR: every prim_func now describes the work of ONE device
# (shapes are per-shard, e.g. 128x64 instead of 128x128 on a 2-way mesh).
# NOTE(review): T.allreduce below is a pseudo-op sketching the intent.
@I.ir_module
class Module:
    I.module_attrs({"device_num": 10})
    I.module_global_infos({"mesh": [R.device_mesh((2,), I.Range(0, 2)), R.device_mesh((1,), I.Range(4, 5))]})

    @T.prim_func
    def matmul1(A: T.Buffer((T.int64(128), T.int64(128)), "float32"), B: T.Buffer((T.int64(128), T.int64(64)), "float32"), matmul_1: T.Buffer((T.int64(128), T.int64(64)), "float32")):
        # Each device multiplies the full A by its 128x64 column slice of B.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, k in T.grid(T.int64(128), T.int64(64), T.int64(128)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(A[v_i0, v_k], B[v_k, v_i1])
                T.writes(matmul_1[v_i0, v_i1])
                with T.init():
                    matmul_1[v_i0, v_i1] = T.float32(0)
                matmul_1[v_i0, v_i1] = matmul_1[v_i0, v_i1] + A[v_i0, v_k] * B[v_k, v_i1]

    @T.prim_func
    # FIX: B's shape annotation was missing a closing parenthesis:
    # `T.Buffer((T.int64(64), T.int64(128), "float32")` could not parse.
    def matmul2(A: T.Buffer((T.int64(128), T.int64(64)), "float32"), B: T.Buffer((T.int64(64), T.int64(128)), "float32"), matmul_1: T.Buffer((T.int64(128), T.int64(128)), "float32")):
        # Per-shard matmul over the local 64-wide k slice, then allreduce the
        # partial sums across mesh[0] to produce the replicated result.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(64)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(A[v_i0, v_k], B[v_k, v_i1])
                T.writes(matmul_1[v_i0, v_i1])
                with T.init():
                    matmul_1[v_i0, v_i1] = T.float32(0)
                matmul_1[v_i0, v_i1] = matmul_1[v_i0, v_i1] + A[v_i0, v_k] * B[v_k, v_i1]
        # FIX: the all_reduce block belongs after the reduction loop; it was
        # tab-indented inside the matmul block and used invalid call syntax.
        with T.block("all_reduce"):
            # pseudo-op: in-place allreduce of the partial sums over mesh[0]
            T.evaluate(T.allreduce(src_buffer=matmul_1.data, dst_buffer=matmul_1.data, size=128 * 128, mesh="mesh[0]", dims=[0]))

    @R.function
    def foo(x: R.DTensor((128, 128), "float32", "mesh[0]", "R"), weight1: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]"), weight2: R.DTensor((128, 128), "float32", "mesh[0]", "S[0]")) -> R.DTensor((128, 128), "float32", "mesh[0]", "R"):
        cls = Module
        # sinfo stays global-view (DTensor) while the callee is local-view.
        lv0 = R.dist.call_tir_local_view(cls.matmul1, (x, weight1), out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "S[1]"))
        lv1: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = R.nn.gelu(lv0)
        lv2: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = lv1
        lv3 = R.dist.call_tir_local_view(cls.matmul2, (lv2, weight2), out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "R"))
        return lv3

Lower to disco runtime: (pipeline stage splitting happens here)

# Disco-runtime module: DTensor types are erased to per-device R.Tensor
# shapes, and a worker entry point drives the local function.
# NOTE(review): T.allreduce below is a pseudo-op sketching the intent.
@I.ir_module
class Module:
    I.module_attrs({"device_num": 10})
    I.module_global_infos({"mesh": [R.device_mesh((2,), I.Range(0, 2)), R.device_mesh((1,), I.Range(4, 5))]})

    @T.prim_func
    def matmul1(A: T.Buffer((T.int64(128), T.int64(128)), "float32"), B: T.Buffer((T.int64(128), T.int64(64)), "float32"), matmul_1: T.Buffer((T.int64(128), T.int64(64)), "float32")):
        # Per-device matmul: full A times the local 128x64 slice of B.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, k in T.grid(T.int64(128), T.int64(64), T.int64(128)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(A[v_i0, v_k], B[v_k, v_i1])
                T.writes(matmul_1[v_i0, v_i1])
                with T.init():
                    matmul_1[v_i0, v_i1] = T.float32(0)
                matmul_1[v_i0, v_i1] = matmul_1[v_i0, v_i1] + A[v_i0, v_k] * B[v_k, v_i1]

    @T.prim_func
    # FIX: B's shape annotation was missing a closing parenthesis:
    # `T.Buffer((T.int64(64), T.int64(128), "float32")` could not parse.
    def matmul2(A: T.Buffer((T.int64(128), T.int64(64)), "float32"), B: T.Buffer((T.int64(64), T.int64(128)), "float32"), matmul_1: T.Buffer((T.int64(128), T.int64(128)), "float32")):
        # Per-device partial matmul followed by an allreduce over mesh[0].
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(64)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(A[v_i0, v_k], B[v_k, v_i1])
                T.writes(matmul_1[v_i0, v_i1])
                with T.init():
                    matmul_1[v_i0, v_i1] = T.float32(0)
                matmul_1[v_i0, v_i1] = matmul_1[v_i0, v_i1] + A[v_i0, v_k] * B[v_k, v_i1]
        # FIX: the all_reduce block belongs after the reduction loop; it was
        # tab-indented inside the matmul block and used invalid call syntax.
        with T.block("all_reduce"):
            # pseudo-op: in-place allreduce of the partial sums over mesh[0]
            T.evaluate(T.allreduce(src_buffer=matmul_1.data, dst_buffer=matmul_1.data, size=128 * 128, mesh="mesh[0]", dims=[0]))

    @R.function
    def foo(x: R.Tensor((128, 128), "float32"), weight1: R.Tensor((128, 64), "float32"), weight2: R.Tensor((64, 128), "float32")) -> R.Tensor((128, 128), "float32"):
        # Plain Relax now: shapes are the per-device shard shapes.
        cls = Module
        lv0 = R.call_tir(cls.matmul1, (x, weight1), out_sinfo=R.Tensor((128, 64), "float32"))
        lv1: R.Tensor((128, 64), "float32") = R.nn.gelu(lv0)
        lv2: R.Tensor((128, 64), "float32") = lv1
        lv3 = R.call_tir(cls.matmul2, (lv2, weight2), out_sinfo=R.Tensor((128, 128), "float32"))
        return lv3

    # FIX: worker_func used mixed tab/space indentation that would not parse;
    # re-indented consistently as a method of the module.
    @R.function
    def worker_func(x: DObject, weight1: DObject, weight2: DObject) -> DObject:
        # NOTE(review): DObject is presumably Disco's distributed object
        # handle — confirm the actual type name in the disco runtime API.
        x = R.broadcast(x)  # broadcast the input from the controller to all workers
        y = foo(x, weight1, weight2)
        return y

Advanced flow: