Add INT8 GEMM support to the GEMM operator #94
Wire up existing INT8 matmul kernels (i8→i8, i8→i16, i8→i32) through the Python GEMM operator layer. Fix get_arg_spec() to pass correct dtype to AIERuntimeArgSpec (was defaulting to bfloat16 for all types). Closes issue amd#93
dtype_in and dtype_out have repr=False, so they are excluded from the auto-generated operator name. When a bf16 and an int8 GEMM share the same dimensions (M, K, N, tiles, columns), they produce identical xclbin filenames. The first to compile wins; the second silently reuses the wrong binary, producing garbage output. Override the name property to append the dtype suffix (e.g. _i8_i32) when dtype_in is not the default bf16. bf16 names are unchanged for backward compatibility.
    identical dimensions."""
    base = super().name
    if self.dtype_in != "bf16":
        base += f"_{self.dtype_in}_{self.dtype_out}"
I think it might make sense to always change it to include dtype in/out... thoughts @andrej ?
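The collision described in the commit message, and the two naming options under discussion, can be sketched as follows. This is a minimal illustration, not the operator's actual code: `GemmSketch`, `base_name`, and `name_always` are hypothetical names, and the assumption is that the auto-generated name is built only from fields with `repr=True`.

```python
from dataclasses import dataclass, field, fields


@dataclass
class GemmSketch:
    """Hypothetical stand-in for the GEMM operator (illustration only)."""

    M: int
    K: int
    N: int
    # repr=False keeps these out of the auto-generated name (assumed mechanism).
    dtype_in: str = field(default="bf16", repr=False)
    dtype_out: str = field(default="bf16", repr=False)

    @property
    def base_name(self) -> str:
        # Build the name from repr-visible fields only.
        parts = [f"{f.name}{getattr(self, f.name)}" for f in fields(self) if f.repr]
        return "gemm_" + "_".join(parts)

    @property
    def name(self) -> str:
        # Option from this PR: suffix only when dtype_in is not the default.
        base = self.base_name
        if self.dtype_in != "bf16":
            base += f"_{self.dtype_in}_{self.dtype_out}"
        return base

    @property
    def name_always(self) -> str:
        # Option raised in review: always include both dtypes.
        return f"{self.base_name}_{self.dtype_in}_{self.dtype_out}"


bf16_op = GemmSketch(64, 64, 64)
i8_op = GemmSketch(64, 64, 64, dtype_in="i8", dtype_out="i32")

print(bf16_op.base_name == i8_op.base_name)  # same dimensions, same base name
print(bf16_op.name, i8_op.name)
print(bf16_op.name_always, i8_op.name_always)
```

The conditional suffix keeps existing bf16 artifact names stable; the always-suffixed form is more uniform but would rename every existing bf16 binary.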
From review feedback: replace GEMM's private _np_dtype_map with a shared np_dtype_map in test_utils.py, derived from the existing torch_dtype_map so the two stay in sync.
design.py specified MAC dims (8,8,8) for NPU1 i8, but aie_kernels/aie2/mm.cc only provides matmul_vectorized_4x8x8_i8_* wrappers. Because of the mismatch, DMA stride/tile patterns were shaped for an 8x8x8 MAC while the kernel consumed them as 4x8x8. Linking succeeded (the symbol is matmul_i8_i32), but the kernel produced wrong results, which surfaced as an AssertionError in CI on Phoenix hardware. Also drop the duplicated "npu1" key in microkernel_mac_dim_map (the second entry silently overrode the first with identical values).
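Both failure modes in this commit (a duplicate dictionary key that overrides silently, and design-side MAC dims that disagree with the kernel wrapper) could be caught early with small checks. The sketch below is an assumption-laden illustration: `build_mac_dim_map`, `mac_dims_from_wrapper`, and `check_mac_dims` are hypothetical helpers, and the wrapper name `matmul_vectorized_4x8x8_i8_i32` is one assumed instance of the `matmul_vectorized_4x8x8_i8_*` pattern named above.

```python
import re


def build_mac_dim_map(entries):
    """Build a map from (key, dims) pairs, rejecting duplicate keys.

    A dict literal with a repeated key keeps only the last value and raises
    no error, so the map is built from pairs to make duplicates visible.
    """
    result = {}
    for key, dims in entries:
        if key in result:
            raise ValueError(f"duplicate MAC-dim entry for {key!r}")
        result[key] = dims
    return result


# Matches the MxKxN triple embedded in a wrapper name such as "..._4x8x8_...".
_DIMS_IN_NAME = re.compile(r"_(\d+)x(\d+)x(\d+)_")


def mac_dims_from_wrapper(wrapper_name):
    """Extract the MAC dims encoded in a kernel wrapper name."""
    match = _DIMS_IN_NAME.search(wrapper_name)
    if match is None:
        raise ValueError(f"no MAC dims found in {wrapper_name!r}")
    return tuple(int(group) for group in match.groups())


def check_mac_dims(design_dims, wrapper_name):
    """Raise if the design's MAC dims differ from the wrapper's."""
    kernel_dims = mac_dims_from_wrapper(wrapper_name)
    if tuple(design_dims) != kernel_dims:
        raise ValueError(
            f"design expects {tuple(design_dims)}, "
            f"but {wrapper_name} implements {kernel_dims}"
        )


# A dict literal with a repeated key silently keeps the last entry.
literal = {"npu1": (4, 8, 8), "npu1": (4, 8, 8)}
print(len(literal))

try:
    check_mac_dims((8, 8, 8), "matmul_vectorized_4x8x8_i8_i32")
except ValueError as err:
    print(err)
```

A check like this only works when the dims appear in a name the Python side can see, which is not the case for a link symbol such as matmul_i8_i32.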
Found the i8 GEMM failure. Long-term suggestion: encode MAC dims in the kernel symbol
Wire up existing INT8 matmul kernels (i8→i8, i8→i16, i8→i32) through the
Python GEMM operator layer. The C++ kernels already had the templates and
compile flags; this connects them to the Python API.
Also fixes a pre-existing bug in get_arg_spec() where AIERuntimeArgSpec
defaulted all buffers to bfloat16, causing silent data corruption for any
non-bf16 output type.
Closes #93
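One plausible way a bfloat16 default can corrupt non-bf16 output is a buffer-size mismatch. The sketch below illustrates that idea only; it is not the real AIERuntimeArgSpec API. `ArgSpecSketch`, `output_spec`, and the `pass_dtype` switch are hypothetical, and the exact corruption mechanism in the operator is an assumption.

```python
from dataclasses import dataclass

# Element sizes in bytes for the dtype names used in this PR.
ITEMSIZE_BYTES = {"bf16": 2, "i8": 1, "i16": 2, "i32": 4}


@dataclass
class ArgSpecSketch:
    """Hypothetical runtime-argument spec whose dtype defaults to bf16."""

    shape: tuple
    dtype: str = "bf16"

    def nbytes(self) -> int:
        # Total buffer size implied by the shape and the declared dtype.
        count = 1
        for dim in self.shape:
            count *= dim
        return count * ITEMSIZE_BYTES[self.dtype]


def output_spec(M, N, dtype_out, pass_dtype):
    """Build the output spec, with or without forwarding dtype_out."""
    if pass_dtype:
        return ArgSpecSketch((M, N), dtype_out)
    # Pre-fix behaviour (assumed): dtype is omitted, so the bf16 default applies.
    return ArgSpecSketch((M, N))


buggy = output_spec(64, 64, "i32", pass_dtype=False)
fixed = output_spec(64, 64, "i32", pass_dtype=True)
print(buggy.dtype, buggy.nbytes())
print(fixed.dtype, fixed.nbytes())
```

Under these assumptions the unfixed spec describes an i32 result with half the bytes it needs, and nothing fails loudly because the spec is still internally consistent.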
Added
- [...] dtype_in="i8") with i8, i16, i32 output types
- [...] microkernel_mac_dim_map
- [...] -Di8_i32_ONLY, etc.)
- [...] reference.py

Changed
- get_arg_spec() now passes correct dtype to AIERuntimeArgSpec
- [...] dtype_in/dtype_out (existing bf16 tests unchanged)

Removed