# Grouped GEMM tests
#
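# Format: --grouped=<dim_idx>:<num_groups>:<group_sizes>[:<max_group_size>] MxK:GxKxN
# Where:
#   dim_idx        - dimension index to group (0=M, 1=K)
#   num_groups     - number of groups/experts
#   group_sizes    - '+'-separated list of sizes for each group (sum must equal dimension size)
#   max_group_size - optional fourth field; used by the DNNL_ARG_HINT_MAX_GROUP_SIZE
#                    cases at the end of this file
#   M              - total M dimension size (e.g., sum of group sizes if dim_idx=0)
#   G              - number of groups/experts (must match num_groups above)
#   K, N           - GEMM reduction and output-column dimension sizes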
#
# Configurations below cover:
# - Data types: f32, bf16, f16, u8/s8 with s4/u4 weights
# - Bias: --bia-dt=f32/bf16/f16, --bia_mask=2 (per-expert bias [num_experts, N])
# - Empty groups (M_g = 0)
# - Row-wise src scales: --attr-scales=src:1
# - Grouped src scales: 3:f16:1x32 (grouped along K, mask = (1<<0) | (1<<1))
# - Column-wise wei scales: 5 (scales vary by expert and N, mask = (1<<0) | (1<<2))
# - Grouped wei scales: 7:f32:32x1 (grouped along K, mask = (1<<0) | (1<<1) | (1<<2))
# - Column-wise weight zero points: --attr-zero-points=wei:5:s8 (mask = (1<<0) | (1<<2))
# - Grouped weight zero points: --attr-zero-points=wei:7:s8:32x1 (grouped along K)
# - WOQ (weight-only quantization) with f16 activations and s4/u4/8-bit weights
# - Transposed weights: --wtag=acb (weights layout [num_experts, N, K] vs abc [num_experts, K, N])
# - FP8 row-wise: f8_e5m2/f8_e4m3 dtypes with row-wise src + column-wise wei f32 scales
# - MXFP8: f8_e4m3 dtypes, e8m0 block scales, block size 32
# - MXFP4: f4_e2m1 dtypes, e8m0 block scales, block size 32
# - NVFP4: f4_e2m1 dtypes, f8_e4m3 block scales (block size 16) + global f32 scale via
#           binary mul post-op (--attr-post-ops=mul:f32:0)
# - Max group size hint (DNNL_ARG_HINT_MAX_GROUP_SIZE): optional fourth --grouped field

--reset

# Basic correctness
# Plain grouped GEMM: no bias, scales, or zero points are set in this section.
# Three uniform data type configurations are listed, each with both weight
# layouts (abc = [num_experts, K, N], acb = [num_experts, N, K]).
# Problem descriptors come from the shapes_grouped batch file (not shown here).
--dt=f32:f32:f32,f16:f16:f16,bf16:bf16:bf16
--wtag=abc,acb
--batch=shapes_grouped

# Transposed weights with bias - f32
# Weights use the acb layout ([num_experts, N, K]). Bias is f32 with mask 2,
# i.e. a per-expert bias of shape [num_experts, N].
# Problem descriptors come from the shapes_grouped batch file (not shown here).
--reset
--dt=f32:f32:f32
--wtag=acb
--bia-dt=f32 --bia_mask=2
--batch=shapes_grouped

# Bias + row-wise src scales - f16
# src:1 selects row-wise src scales; the f16 bias uses mask 2 (per-expert
# bias [num_experts, N]). Both weight layouts (abc and acb) are listed.
--reset
--dt=f16:f16:f16
--wtag=abc,acb
--bia-dt=f16 --bia_mask=2
--attr-scales=src:1
--batch=shapes_grouped

# Bias + row-wise src scales - bf16
# Same option set as the f16 variant, but with bf16 data/bias and only the
# non-transposed abc weight layout ([num_experts, K, N]).
--reset
--dt=bf16:bf16:bf16
--wtag=abc
--bia-dt=bf16 --bia_mask=2
--attr-scales=src:1
--batch=shapes_grouped

# Column-wise weight scales - f32
# wei:5 is mask (1<<0) | (1<<2): weight scales vary by expert and by N.
# No src scales or bias are set in this section.
--reset
--dt=f32:f32:f32
--wtag=abc
--attr-scales=wei:5
--batch=shapes_grouped

# Combined src + weight scales - f16
# Applies row-wise src scales (src:1) together with column-wise weight scales
# (wei:5, varying by expert and N) in a single attribute; both weight layouts
# are listed.
--reset
--dt=f16:f16:f16
--wtag=abc,acb
--attr-scales=src:1+wei:5
--batch=shapes_grouped

# Blocked wei scales with different block sizes - u8/s8
# wei:7 scales are grouped along K; two bf16 scale variants are listed, with
# block sizes 128 and 32 along K. Integer src with s4 or u8 weights, bf16 dst.
# M=32 is split along dim 0 into 4 groups of 8 rows; K=512, N=256.
--reset
--dt=u8:s4:bf16,s8:s4:bf16,u8:u8:bf16
--wtag=abc
--grouped=0:4:8+8+8+8
--attr-scales=wei:7:bf16:128x1,wei:7:bf16:32x1
32x512:4x512x256

# Combined row-wise src + blocked weight scales - u8/s8
# Row-wise src scales (src:1) combined with K-grouped bf16 weight scales
# (wei:7) at block sizes 128 and 32. Both weight layouts are listed.
# M=32 is split along dim 0 into 4 groups of 8 rows; K=512, N=256.
# The data type list mirrors the blocked-wei-scales-only section (u8/s8 src
# with s4 weights, plus u8 src with u8 weights); the third entry previously
# repeated u8:s4:bf16, which only re-ran the first configuration.
--reset
--dt=u8:s4:bf16,s8:s4:bf16,u8:u8:bf16
--wtag=abc,acb
--grouped=0:4:8+8+8+8
--attr-scales=src:1+wei:7:bf16:128x1,src:1+wei:7:bf16:32x1
32x512:4x512x256

# Combined row-wise src + blocked weight scales - s8 src, s4 wei, f32 dst
# Row-wise src scales (src:1) with K-grouped f32 weight scales (wei:7) at
# block size 128 along K. Both weight layouts are listed.
# M=32 is split along dim 0 into 4 groups of 8 rows; K=512, N=256.
--reset
--dt=s8:s4:f32
--wtag=abc,acb
--grouped=0:4:8+8+8+8
--attr-scales=src:1+wei:7:f32:128x1
32x512:4x512x256

# int src + int wei with column-wise scales and ZPs
# Weight scales (f32) and weight zero points (s8) both use mask 5,
# (1<<0) | (1<<2), so they vary by expert and by N.
# M=32 is split along dim 0 into 4 groups of 8 rows; K=64, N=32.
--reset
--dt=u8:s8:f32,s8:s8:f32
--wtag=abc,acb
--grouped=0:4:8+8+8+8
--attr-scales=wei:5:f32
--attr-zero-points=wei:5:s8
32x64:4x64x32

# int src + int wei with blocked scales and ZPs
# Weight scales (bf16) and weight zero points (s8) both use mask 7 and are
# grouped along K with block size 128. Only the abc weight layout is listed.
# M=32 is split along dim 0 into 4 groups of 8 rows; K=512, N=256.
--reset
--dt=u8:s8:bf16,s8:s8:bf16
--wtag=abc
--grouped=0:4:8+8+8+8
--attr-scales=wei:7:bf16:128x1
--attr-zero-points=wei:7:s8:128x1
32x512:4x512x256

# WOQ: f16 activation + s4/u4 weights with blocked scales and ZPs
# Weight-only quantization: f16 src/dst with 4-bit weights. Weight scales are
# f16, grouped along K with block size 32 (mask 7).
# The zero-point list starts with an empty entry followed by a u4 K-grouped
# entry; presumably the empty entry yields a run without zero points --
# confirm against benchdnn list parsing.
# M=32 is split along dim 0 into 4 groups of 8 rows; K=128, N=64.
--reset
--dt=f16:s4:f16,f16:u4:f16
--wtag=abc,acb
--grouped=0:4:8+8+8+8
--attr-fpmath=f16:true
--attr-scales=wei:7:f16:32x1
--attr-zero-points=,wei:7:u4:32x1
32x128:4x128x64

# WOQ: f16 activation + u8 weights with per-expert per-column scales and ZPs
# Weight-only quantization with 8-bit weights. Scales (f16) and zero points
# (u8) both use mask 5, (1<<0) | (1<<2): one value per expert and per N column.
# M=32 is split along dim 0 into 4 groups of 8 rows; K=64, N=32.
--reset
--dt=f16:u8:f16
--wtag=abc,acb
--grouped=0:4:8+8+8+8
--attr-fpmath=f16:true
--attr-scales=wei:5:f16
--attr-zero-points=wei:5:u8
32x64:4x64x32

# int src + int wei with K-grouped f16 src scales and K-grouped wei scales and ZPs
# Src scales use mask 3 with 1x128 blocks and weight scales use mask 7 with
# 128x1 blocks (both f16, grouped along K).
# Group sizes 8+16+0+8 sum to M=32 and include one empty group (M_g = 0).
# The zero-point list starts with an empty entry followed by an s8 K-grouped
# entry; presumably the empty entry yields a run without zero points --
# confirm against benchdnn list parsing.
--reset
--dt=u8:s8:f32,s8:s8:f32,u8:u8:f32
--wtag=abc,acb
--grouped=0:4:8+16+0+8
--attr-scales=src:3:f16:1x128+wei:7:f16:128x1
--attr-zero-points=,wei:7:s8:128x1
32x512:4x512x256

# f8_e5m2 with row-wise/block src scales + column-wise wei scales
# The scales list has three entries: an empty one (presumably a run without
# scales -- confirm against benchdnn list parsing), row-wise src + column-wise
# wei scales (src:1+wei:5), and e8m0 block scales with block size 32 along K
# for both src and wei.
# Group sizes 16+8+0+8 sum to M=32 and include one empty group (M_g = 0).
--reset
--dt=f8_e5m2:f8_e5m2:bf16
--wtag=abc,acb
--attr-scales=,src:1+wei:5,src:3:e8m0:1x32+wei:7:e8m0:32x1
--grouped=0:4:16+8+0+8
32x128:4x128x64

# f8_e4m3 with e8m0 block scales
# Three scale combinations are listed: e8m0 block scales (block size 32 along
# K) on both src and wei; block src scales with column-wise wei scales
# (wei:5); and row-wise src scales (src:1) with block wei scales.
# M=32 is split along dim 0 into 4 groups of 8 rows; K=128, N=64.
--reset
--dt=f8_e4m3:f8_e4m3:bf16
--wtag=abc,acb
--grouped=0:4:8+8+8+8
--attr-scales=src:3:e8m0:1x32+wei:7:e8m0:32x1,src:3:e8m0:1x32+wei:5,src:1+wei:7:e8m0:32x1
32x128:4x128x64

# f4_e2m1 with e8m0 block scales
# MXFP4 configuration: f4_e2m1 src and weights, bf16 dst, e8m0 scales with
# block size 32 along K on both src (1x32) and wei (32x1).
# M=32 is split along dim 0 into 4 groups of 8 rows; K=128, N=64.
--reset
--dt=f4_e2m1:f4_e2m1:bf16
--wtag=abc,acb
--grouped=0:4:8+8+8+8
--attr-scales=src:3:e8m0:1x32+wei:7:e8m0:32x1
32x128:4x128x64

# NVFP4: f4_e2m1 dt with f8_e4m3 block scales + global f32 scale via binary mul
# Block scales are f8_e4m3 with block size 16 along K on both src (1x16) and
# wei (16x1). The global f32 scale is expressed as a binary mul post-op
# (mul:f32:0).
# M=32 is split along dim 0 into 4 groups of 8 rows; K=128, N=64.
--reset
--dt=f4_e2m1:f4_e2m1:bf16
--wtag=abc,acb
--grouped=0:4:8+8+8+8
--attr-scales=src:3:f8_e4m3:1x16+wei:7:f8_e4m3:16x1
--attr-post-ops=mul:f32:0
32x128:4x128x64

# DNNL_ARG_HINT_MAX_GROUP_SIZE usage
# --grouped carries a fourth ':'-separated field here (8), which equals the
# size of every group in this case. Presumably this is the value passed as
# the max group size hint -- confirm against the benchdnn --grouped parser.
# M=32 is split along dim 0 into 4 groups of 8 rows; K=64, N=32.
--reset
--dt=f32:f32:f32,bf16:bf16:bf16
--wtag=abc
--grouped=0:4:8+8+8+8:8
32x64:4x64x32

# DNNL_ARG_HINT_MAX_GROUP_SIZE usage with uneven groups - f16
# Group sizes 10+20+30+40 sum to M=100; the fourth --grouped field (40) equals
# the largest group size. K=64, N=32.
--reset
--dt=f16:f16:f16
--wtag=abc
--grouped=0:4:10+20+30+40:40
100x64:4x64x32
