diff options
author | Richard Henderson <richard.henderson@linaro.org> | 2021-05-25 15:58:13 -0700 |
---|---|---|
committer | Peter Maydell <peter.maydell@linaro.org> | 2021-06-03 16:43:26 +0100 |
commit | 81266a1f58bf557280c6f7ce3cad1ba8ed8a56f1 (patch) | |
tree | e5e56230ae43c544d8d05ba36c910289ff9d1c40 /target/arm/vec_helper.c | |
parent | 839144784b613998edf7a7277ed2ed2015b0b4d7 (diff) |
target/arm: Implement bfloat16 matrix multiply accumulate
This is BFMMLA for both AArch64 AdvSIMD and SVE,
and VMMLA.BF16 for AArch32 NEON.
Reviewed-by: Peter Maydell <peter.maydell@linaro.org>
Signed-off-by: Richard Henderson <richard.henderson@linaro.org>
Message-id: 20210525225817.400336-9-richard.henderson@linaro.org
Signed-off-by: Peter Maydell <peter.maydell@linaro.org>
Diffstat (limited to 'target/arm/vec_helper.c')
-rw-r--r-- | target/arm/vec_helper.c | 42 |
1 files changed, 41 insertions, 1 deletions
diff --git a/target/arm/vec_helper.c b/target/arm/vec_helper.c index 74a497f38c..27e9bdd329 100644 --- a/target/arm/vec_helper.c +++ b/target/arm/vec_helper.c @@ -2385,7 +2385,7 @@ static void do_mmla_b(void *vd, void *vn, void *vm, void *va, uint32_t desc, * Process the entire segment at once, writing back the * results only after we've consumed all of the inputs. * - * Key to indicies by column: + * Key to indices by column: * i j i j */ sum0 = a[H4(0 + 0)]; @@ -2472,3 +2472,43 @@ void HELPER(gvec_bfdot_idx)(void *vd, void *vn, void *vm, } clear_tail(d, opr_sz, simd_maxsz(desc)); } + +void HELPER(gvec_bfmmla)(void *vd, void *vn, void *vm, void *va, uint32_t desc) +{ + intptr_t s, opr_sz = simd_oprsz(desc); + float32 *d = vd, *a = va; + uint32_t *n = vn, *m = vm; + + for (s = 0; s < opr_sz / 4; s += 4) { + float32 sum00, sum01, sum10, sum11; + + /* + * Process the entire segment at once, writing back the + * results only after we've consumed all of the inputs. + * + * Key to indicies by column: + * i j i k j k + */ + sum00 = a[s + H4(0 + 0)]; + sum00 = bfdotadd(sum00, n[s + H4(0 + 0)], m[s + H4(0 + 0)]); + sum00 = bfdotadd(sum00, n[s + H4(0 + 1)], m[s + H4(0 + 1)]); + + sum01 = a[s + H4(0 + 1)]; + sum01 = bfdotadd(sum01, n[s + H4(0 + 0)], m[s + H4(2 + 0)]); + sum01 = bfdotadd(sum01, n[s + H4(0 + 1)], m[s + H4(2 + 1)]); + + sum10 = a[s + H4(2 + 0)]; + sum10 = bfdotadd(sum10, n[s + H4(2 + 0)], m[s + H4(0 + 0)]); + sum10 = bfdotadd(sum10, n[s + H4(2 + 1)], m[s + H4(0 + 1)]); + + sum11 = a[s + H4(2 + 1)]; + sum11 = bfdotadd(sum11, n[s + H4(2 + 0)], m[s + H4(2 + 0)]); + sum11 = bfdotadd(sum11, n[s + H4(2 + 1)], m[s + H4(2 + 1)]); + + d[s + H4(0 + 0)] = sum00; + d[s + H4(0 + 1)] = sum01; + d[s + H4(2 + 0)] = sum10; + d[s + H4(2 + 1)] = sum11; + } + clear_tail(d, opr_sz, simd_maxsz(desc)); +} |