Simple flow:
Relax → DistIR (global view) → DistIR (local view) → Disco runtime
Advanced flow:
Relax → DistIR (global view) → DistIR (intermediate view) → DistIR (local view) → Disco runtime
Example: MLP
Relax
# Relax (non-distributed) input program.
# The only user-supplied hint is `annotate_sharding` on lv1: it pins that value
# to mesh[0], sharded along dimension 1 ("S[1]"); the sharding-propagation pass
# infers placements for everything else.
@I.ir_module
class MLP:
    I.module_attrs({"device_num": 10})
    I.module_global_infos(
        {
            "mesh": [
                R.device_mesh((2,), I.Range(0, 2)),  # mesh[0]: devices 0..1
                R.device_mesh((1,), I.Range(4, 5)),  # mesh[1]: device 4
            ]
        }
    )

    @R.function
    def foo(
        x: R.Tensor((128, 128), "float32"),
        weight1: R.Tensor((128, 128), "float32"),
        weight2: R.Tensor((128, 128), "float32"),
    ) -> R.Tensor((128, 128), "float32"):
        lv0 = R.matmul(x, weight1)
        lv1 = R.nn.gelu(lv0)
        lv2 = R.dist.annotate_sharding(lv1, device_mesh="mesh[0]", placement="S[1]")
        lv3 = R.matmul(lv2, weight2)
        return lv3
After propagating the sharding spec, the program becomes DistIR (global view TensorIR).
DistIR means the computation graph is built on distributed tensors.
Global view means every TensorIR function represents the computation performed by the whole device mesh.
# global view - every op represents the computation performed by the whole device mesh
@I.ir_module
class Module:
I.module_attrs({"device_num": 10})
I.module_global_infos({"mesh": [R.device_mesh((2,), I.Range(0, 2)), R.device_mesh((1,), I.Range(4, 5))]})
@T.prim_func
def matmul(A: T.Buffer((T.int64(128), T.int64(128)), "float32"), B: T.Buffer((T.int64(128), T.int64(128)), "float32"), matmul_1: T.Buffer((T.int64(128), T.int64(128)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(128)):
with T.block("matmul"):
v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
T.reads(A[v_i0, v_k], B[v_k, v_i1])
T.writes(matmul_1[v_i0, v_i1])
with T.init():
matmul_1[v_i0, v_i1] = T.float32(0)
matmul_1[v_i0, v_i1] = matmul_1[v_i0, v_i1] + A[v_i0, v_k] * B[v_k, v_i1]
@T.prim_func
def matmul2(A: T.Buffer((T.int64(128), T.int64(64)), "float32"), B: T.Buffer((T.int64(64), T.int64(128), "float32"), matmul_1: T.Buffer((T.int64(128), T.int64(128)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(64)):
with T.block("matmul"):
v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
T.reads(A[v_i0, v_k], B[v_k, v_i1])
T.writes(matmul_1[v_i0, v_i1])
with T.init():
matmul_1[v_i0, v_i1] = T.float32(0)
matmul_1[v_i0, v_i1] = matmul_1[v_i0, v_i1] + A[v_i0, v_k] * B[v_k, v_i1]
with T.block("all_reduce"):
T.evaluate(T.allreduce, src_buffer=matmul_1.data, dst_buffer=matmul_1.data, size=128*128, group=(mesh="mesh[0]", dims=[0]))
@R.function
def foo(x: R.DTensor((128, 128), "float32", "mesh[0]", "R"), weight1: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]"), weight2: R.DTensor((128, 128), "float32", "mesh[0]", "S[0]")) -> R.DTensor((128, 128), "float32", "mesh[0]", "R"):
cls = Module
lv0 = R.dist.call_tir(cls.matmul, (x, weight1), out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "S[1]"))
lv1: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = R.nn.gelu(lv0)
lv2: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = lv1
lv3 = R.dist.call_tir_local_vi(cls.matmul2, (lv2, weight2), out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "R"))
return lv3
Lower from global view TensorIR to local view TensorIR:
Local view means that each TensorIR function represents the computation performed by one device node.
@I.ir_module
class Module:
I.module_attrs({"device_num": 10})
I.module_global_infos({"mesh": [R.device_mesh((2,), I.Range(0, 2)), R.device_mesh((1,), I.Range(4, 5))]})
@T.prim_func
def matmul1(A: T.Buffer((T.int64(128), T.int64(128)), "float32"), B: T.Buffer((T.int64(128), T.int64(64)), "float32"), matmul_1: T.Buffer((T.int64(128), T.int64(64)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for i0, i1, k in T.grid(T.int64(128), T.int64(64), T.int64(128)):
with T.block("matmul"):
v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
T.reads(A[v_i0, v_k], B[v_k, v_i1])
T.writes(matmul_1[v_i0, v_i1])
with T.init():
matmul_1[v_i0, v_i1] = T.float32(0)
matmul_1[v_i0, v_i1] = matmul_1[v_i0, v_i1] + A[v_i0, v_k] * B[v_k, v_i1]
@T.prim_func
def matmul2(A: T.Buffer((T.int64(128), T.int64(64)), "float32"), B: T.Buffer((T.int64(64), T.int64(128), "float32"), matmul_1: T.Buffer((T.int64(128), T.int64(128)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(64)):
with T.block("matmul"):
v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
T.reads(A[v_i0, v_k], B[v_k, v_i1])
T.writes(matmul_1[v_i0, v_i1])
with T.init():
matmul_1[v_i0, v_i1] = T.float32(0)
matmul_1[v_i0, v_i1] = matmul_1[v_i0, v_i1] + A[v_i0, v_k] * B[v_k, v_i1]
with T.block("all_reduce"):
T.evaluate(T.allreduce, src_buffer=matmul_1.data, dst_buffer=matmul_1.data, size=128*128, group=(mesh="mesh[0]", dims=[0]))
@R.function
def foo(x: R.DTensor((128, 128), "float32", "mesh[0]", "R"), weight1: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]"), weight2: R.DTensor((128, 128), "float32", "mesh[0]", "S[0]")) -> R.DTensor((128, 128), "float32", "mesh[0]", "R"):
cls = Module
lv0 = R.dist.call_tir_local_view(cls.matmul1, (x, weight1), out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "S[1]"))
lv1: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = R.nn.gelu(lv0)
lv2: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = lv1
lv3 = R.dist.call_tir_local_view(cls.matmul2, (lv2, weight2), out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "R"))
return lv3
Lower to disco runtime: (pipeline stage splitting happens here)
@I.ir_module
class Module:
I.module_attrs({"device_num": 10})
I.module_global_infos({"mesh": [R.device_mesh((2,), I.Range(0, 2)), R.device_mesh((1,), I.Range(4, 5))]})
@T.prim_func
def matmul1(A: T.Buffer((T.int64(128), T.int64(128)), "float32"), B: T.Buffer((T.int64(128), T.int64(64)), "float32"), matmul_1: T.Buffer((T.int64(128), T.int64(64)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for i0, i1, k in T.grid(T.int64(128), T.int64(64), T.int64(128)):
with T.block("matmul"):
v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
T.reads(A[v_i0, v_k], B[v_k, v_i1])
T.writes(matmul_1[v_i0, v_i1])
with T.init():
matmul_1[v_i0, v_i1] = T.float32(0)
matmul_1[v_i0, v_i1] = matmul_1[v_i0, v_i1] + A[v_i0, v_k] * B[v_k, v_i1]
@T.prim_func
def matmul2(A: T.Buffer((T.int64(128), T.int64(64)), "float32"), B: T.Buffer((T.int64(64), T.int64(128), "float32"), matmul_1: T.Buffer((T.int64(128), T.int64(128)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(64)):
with T.block("matmul"):
v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
T.reads(A[v_i0, v_k], B[v_k, v_i1])
T.writes(matmul_1[v_i0, v_i1])
with T.init():
matmul_1[v_i0, v_i1] = T.float32(0)
matmul_1[v_i0, v_i1] = matmul_1[v_i0, v_i1] + A[v_i0, v_k] * B[v_k, v_i1]
with T.block("all_reduce"):
T.evaluate(T.allreduce, src_buffer=matmul_1.data, dst_buffer=matmul_1.data, size=128*128, group=(mesh="mesh[0]", dims=[0]))
@R.function
def foo(x: R.Tensor((128, 128), "float32"), weight1: R.Tensor((128, 64), "float32"), weight2: R.Tensor((64, 128), "float32")) -> R.Tensor((128, 128), "float32"):
cls = Module
lv0 = R.call_tir(cls.matmul1, (x, weight1), out_sinfo=R.Tensor((128, 64), "float32"))
lv1: R.Tensor((128, 64), "float32") = R.nn.gelu(lv0)
lv2: R.Tensor((128, 64), "float32") = lv1
lv3 = R.call_tir(cls.matmul2, (lv2, weight2), out_sinfo=R.Tensor((128, 128), "float32"))
return lv3
@R.function
def worker_func(x: DObject, weight1: DObject, weight2: DObject) -> DObject:
x = R.broadcast(x)
y = foo(x, weight1, weight2)
return y
Advanced flow: