
Mosaic GPU Dialect

Reactant.MLIR.Dialects.mosaic_gpu.async_load Function

async_load

Schedules an async copy of the contents of the source MemRef in GMEM to the destination MemRef in SMEM. The destination MemRef in SMEM must be contiguous.

Upon completion of the copy, the complete-tx(complete-count) operation will always be executed on the provided barrier.

The indices and slice_lengths inputs define which slice of the GMEM source corresponds to the SMEM destination. Both indices and slice_lengths must have a length equal to the rank of the source. The values in indices are the starting indices of each dimension, and the values in slice_lengths are the slice lengths. Providing -1 in slice_lengths indicates that the slice length along that dimension is 1 and that the dimension is collapsed: it does not appear in the destination MemRef.
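The slicing rules above can be sketched as a small helper. The function name is hypothetical and only mirrors the semantics described here; it is not part of the dialect's API:

```python
# Hypothetical helper illustrating how `indices` and `slice_lengths`
# describe the GMEM slice and the resulting SMEM destination shape.
# A value of -1 in slice_lengths means a length-1 slice whose
# dimension is collapsed away in the destination.
def destination_shape(source_shape, indices, slice_lengths):
    assert len(indices) == len(slice_lengths) == len(source_shape)
    for start, length, dim in zip(indices, slice_lengths, source_shape):
        span = 1 if length == -1 else length
        assert 0 <= start and start + span <= dim, "slice out of bounds"
    # Collapsed (-1) dimensions do not appear in the SMEM destination.
    return [length for length in slice_lengths if length != -1]

# Rows 4..7 of a [16, 128] source yield a [4, 128] destination.
print(destination_shape([16, 128], [4, 0], [4, 128]))            # [4, 128]
# Collapsing the leading dimension of a rank-3 source drops it.
print(destination_shape([2, 16, 128], [1, 4, 0], [-1, 4, 128]))  # [4, 128]
```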

The data is written in row-major order to the contiguous SMEM destination. The source data does not need to be contiguous, except for the last (and minor-most) dimension.

The collective attribute can be provided to use TMA multicast to more efficiently load the GMEM data in cases where multiple thread blocks are grouped together in a cluster and need to load the same data. Each block in a cluster will first load a slice from GMEM to SMEM and then the slices will be multicast to all other blocks in the cluster. In this way TMA multicast guarantees L2 cache hits. The collective attribute is the list of cluster dimensions along which to partition the input data loads.

The predicate allows scheduling the transfer conditionally. The async copy is always scheduled by at most a single lane in the warpgroup.

Reactant.MLIR.Dialects.mosaic_gpu.async_store Function

async_store

Schedules an async store of the contents of the source MemRef in SMEM to the destination MemRef in GMEM. The source MemRef in SMEM must be contiguous.

The indices and slice_lengths inputs define what slice of the GMEM destination corresponds to the SMEM source. Both indices and slice_lengths must have a length equal to the rank of the destination. The values in indices are the starting indices of each dimension and the values in slice_lengths are the lengths. Providing -1 in slice_lengths indicates that this dimension is collapsed in the source and needs to be expanded to a slice of size 1 in the destination.

The data is written in row-major order to the GMEM destination. The source data in SMEM needs to be contiguous, but the destination GMEM does not.

The predicate allows scheduling the transfer conditionally. The async copy is always scheduled by at most a single lane in the warpgroup.

Reactant.MLIR.Dialects.mosaic_gpu.broadcast_in_dim Method

broadcast_in_dim

broadcast_dimensions must have the same size as the rank of the input vector and for each input dimension, specifies which output dimension it corresponds to.
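A sketch of this mapping using NumPy: each input dimension i is placed at output dimension broadcast_dimensions[i], and the remaining output dimensions are broadcast. This mirrors the common StableHLO-style interpretation, which is an assumption here, not this dialect's verified specification:

```python
import numpy as np

# Hypothetical reference semantics for broadcast_in_dim: place each
# input dimension at the output dimension named by broadcast_dimensions,
# then broadcast the inserted size-1 axes to the output shape.
def broadcast_in_dim(x, out_shape, broadcast_dimensions):
    assert len(broadcast_dimensions) == x.ndim
    reshaped = [1] * len(out_shape)
    for in_dim, out_dim in enumerate(broadcast_dimensions):
        reshaped[out_dim] = x.shape[in_dim]
    return np.broadcast_to(x.reshape(reshaped), out_shape)

x = np.arange(3)                      # shape [3]
y = broadcast_in_dim(x, (2, 3), [1])  # input dim 0 -> output dim 1
print(y.shape)                        # (2, 3)
```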

Reactant.MLIR.Dialects.mosaic_gpu.custom_primitive Method

custom_primitive

Allows defining a custom Mosaic GPU primitive.

Custom primitives should carry input and output layouts for each of their vector operands and outputs, and input transforms for each of their memref operands that live in SMEM.

Custom primitives can only return vectors.

Reactant.MLIR.Dialects.mosaic_gpu.initialize_barrier Method

initialize_barrier

Initializes a memref of barriers, each meant to synchronize exactly arrival_count threads.

The base pointer of the result memref corresponds to base_pointer, which must be a pointer to a shared memory location.

Reactant.MLIR.Dialects.mosaic_gpu.layout_cast Method

layout_cast

Casts a vector value to a new strided or tiled layout.

Reactant.MLIR.Dialects.mosaic_gpu.return_ Method

return_

The return op is a terminator that indicates the end of execution within a CustomPrimitiveOp's region. It can optionally return some values, which become the results of the parent CustomPrimitiveOp.

The declared results of the parent CustomPrimitiveOp must match the operand types of this op.

Reactant.MLIR.Dialects.mosaic_gpu.tcgen05_mma Function

tcgen05_mma

Schedules tcgen05.mma instructions that perform the following matrix multiply and accumulate:

accumulator = a * b + accumulator

This operation supports larger inputs than the PTX-level MMA instruction and will schedule as many PTX-level MMA instructions as needed to accomplish the calculation.

The inputs should have the following shapes:

  • a: [groups_m * m, groups_k * s]

  • b: [groups_k * s, groups_n * s]

  • accumulator: [groups_m * m, groups_n * s]

where s == swizzle / element_bytewidth and m is specified according to https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-shape.
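The shape rules above can be written out as a small calculation. The helper name, the swizzle value, and m = 128 are assumptions chosen for the example (m must come from the PTX tcgen05 matrix-shape table):

```python
# Illustrative shape computation for tcgen05_mma operands, following
# a: [groups_m * m, groups_k * s], b: [groups_k * s, groups_n * s],
# accumulator: [groups_m * m, groups_n * s], s = swizzle / bytewidth.
def mma_shapes(swizzle, element_bytewidth, m, groups_m, groups_n, groups_k):
    s = swizzle // element_bytewidth
    a = (groups_m * m, groups_k * s)
    b = (groups_k * s, groups_n * s)
    acc = (groups_m * m, groups_n * s)
    return a, b, acc

# A 128-byte swizzle with bf16 elements (2 bytes) gives s = 64.
a, b, acc = mma_shapes(swizzle=128, element_bytewidth=2, m=128,
                       groups_m=1, groups_n=2, groups_k=4)
print(a, b, acc)  # (128, 256) (256, 128) (128, 128)
```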

The output has an identical shape and type as the input accumulator.

The accumulator, a, and b matrices need to be provided as 2-dimensional memrefs. The accumulator is always in TMEM and b is always in SMEM. a can be in TMEM or SMEM. a and b must have the same element type, and when a is in TMEM, only F16 or BF16 are supported.

a_scale and b_scale are optional scaling matrices that reside in TMEM. When they are set, the operation is defined as:

accumulator = (a * a_scale) * (b * b_scale) + accumulator

accumulate is a boolean that indicates whether to perform the accumulate step.

Reactant.MLIR.Dialects.mosaic_gpu.tmem_alloc Method

tmem_alloc

This op allocates a chunk of TMEM and stores the pointer to the memory in the provided SMEM memref.

The smem_ptr is a pointer in SMEM where a pointer to the allocated TMEM will be stored. The op returns a memref to the allocated TMEM. The result must have a shape with dimensions [rows, logical_columns]. If packing is 1, then the number of logical (unpacked) columns is equal to the number of allocated columns in TMEM. Otherwise, these equations must hold:

packing = 32 / bitwidth(element type of result)
unpacked_columns = allocated_columns * packing

The number of allocated columns in TMEM can be any power of two in the range [32, 512]. If exact is true, the calculated number of allocated columns must already satisfy that restriction. If exact is false and the calculated number of allocated columns is less than 32 or not a power of two, it is rounded up to the nearest power of two greater than or equal to 32.
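The packing and rounding rules above can be sketched as follows. The helper name is hypothetical; the arithmetic follows the text directly (packing = 32 / bitwidth, non-exact allocations round up to a power of two of at least 32):

```python
# Sketch of the tmem_alloc column accounting described above.
def allocated_columns(unpacked_columns, element_bitwidth, exact=False):
    packing = 32 // element_bitwidth
    cols = unpacked_columns // packing
    if exact:
        assert 32 <= cols <= 512 and cols & (cols - 1) == 0, \
            "exact allocation must be a power of two in [32, 512]"
        return cols
    # Round up to the nearest power of two that is at least 32.
    rounded = 32
    while rounded < cols:
        rounded *= 2
    assert rounded <= 512, "allocation exceeds the 512-column limit"
    return rounded

print(allocated_columns(128, 32))  # f32: packing = 1 -> 128 columns
print(allocated_columns(100, 16))  # bf16: packing = 2 -> 50 -> rounds to 64
```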

If collective is true, two CTAs perform the allocation collectively; otherwise, a single CTA performs the allocation.

Reactant.MLIR.Dialects.mosaic_gpu.wait Method

wait

All threads in the warpgroup will block, waiting on the provided barrier until:

  • all pending threads have arrived on the barrier

  • all expected byte transfers have been completed

  • the barrier's parity matches the provided parity

Reactant.MLIR.Dialects.mosaic_gpu.wgmma Method

wgmma

Schedules WGMMA operations that perform the following matrix multiply and accumulate:

accumulator = a * b + accumulator

This operation supports larger inputs than the PTX-level WGMMA operation and will schedule as many PTX-level WGMMA operations as needed to accomplish the calculation. The b matrix, and optionally a, need to be provided as 2-dimensional memrefs.

The inputs should have the following shapes:

  • a: [groups_m * 64, groups_k * s]

  • b: [groups_k * s, groups_n * s]

  • accumulator: [groups_m * 64, groups_n * s]

where s == swizzle / element_bytewidth.
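Given these constraints, the group counts can be recovered from concrete operand shapes. The checker below is a hypothetical sketch that only mirrors the shape rules stated above (m is fixed at 64 per group for WGMMA):

```python
# Hypothetical consistency check for wgmma operand shapes:
# a: [groups_m * 64, groups_k * s], b: [groups_k * s, groups_n * s],
# accumulator: [groups_m * 64, groups_n * s], s = swizzle / bytewidth.
def check_wgmma_shapes(a, b, acc, swizzle, element_bytewidth):
    s = swizzle // element_bytewidth
    groups_m, rem_m = divmod(a[0], 64)
    groups_k, rem_k = divmod(a[1], s)
    groups_n, rem_n = divmod(b[1], s)
    assert rem_m == rem_k == rem_n == 0, "shapes must be multiples of 64 and s"
    assert b == (groups_k * s, groups_n * s), "a/b contraction mismatch"
    assert acc == (groups_m * 64, groups_n * s), "accumulator shape mismatch"
    return groups_m, groups_n, groups_k

# 128-byte swizzle with f16 elements (2 bytes) -> s = 64.
print(check_wgmma_shapes(a=(128, 128), b=(128, 192), acc=(128, 192),
                         swizzle=128, element_bytewidth=2))  # (2, 3, 2)
```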

The output has an identical shape and type as the input accumulator.

The accumulator is always in registers and b is always in shared memory. a and b must have the same element type, and when a is in registers, only F16 or BF16 are supported.

The accumulator must be a vector with a FragmentedLayout. The WGMMA operation will be executed in the async proxy and any inputs in registers need to be synchronized with a memory fence.

Usually a is read from shared memory if it is used directly in the WGMMA operation. If a needs to be transformed before it is used in the WGMMA operation, it may be more convenient to read it directly from registers. This avoids the need to store the data and wait for a fence.

Reactant.MLIR.Dialects.mosaic_gpu.with_transforms Method

with_transforms

This op enforces the provided transforms on the parameter memref.
