aboutsummaryrefslogtreecommitdiff
path: root/target/arm/vec_helper.c
diff options
context:
space:
mode:
Diffstat (limited to 'target/arm/vec_helper.c')
-rw-r--r--target/arm/vec_helper.c805
1 files changed, 616 insertions, 189 deletions
diff --git a/target/arm/vec_helper.c b/target/arm/vec_helper.c
index 3fbeae87cb..e84b438340 100644
--- a/target/arm/vec_helper.c
+++ b/target/arm/vec_helper.c
@@ -22,33 +22,82 @@
#include "exec/helper-proto.h"
#include "tcg/tcg-gvec-desc.h"
#include "fpu/softfloat.h"
+#include "qemu/int128.h"
#include "vec_internal.h"
-/* Note that vector data is stored in host-endian 64-bit chunks,
- so addressing units smaller than that needs a host-endian fixup. */
-#ifdef HOST_WORDS_BIGENDIAN
-#define H1(x) ((x) ^ 7)
-#define H2(x) ((x) ^ 3)
-#define H4(x) ((x) ^ 1)
-#else
-#define H1(x) (x)
-#define H2(x) (x)
-#define H4(x) (x)
-#endif
-
-/* Signed saturating rounding doubling multiply-accumulate high half, 16-bit */
-static int16_t do_sqrdmlah_h(int16_t src1, int16_t src2, int16_t src3,
- bool neg, bool round, uint32_t *sat)
+/* Signed saturating rounding doubling multiply-accumulate high half, 8-bit */
+int8_t do_sqrdmlah_b(int8_t src1, int8_t src2, int8_t src3,
+ bool neg, bool round)
{
/*
* Simplify:
- * = ((a3 << 16) + ((e1 * e2) << 1) + (1 << 15)) >> 16
- * = ((a3 << 15) + (e1 * e2) + (1 << 14)) >> 15
+ * = ((a3 << 8) + ((e1 * e2) << 1) + (round << 7)) >> 8
+ * = ((a3 << 7) + (e1 * e2) + (round << 6)) >> 7
*/
int32_t ret = (int32_t)src1 * src2;
if (neg) {
ret = -ret;
}
+ ret += ((int32_t)src3 << 7) + (round << 6);
+ ret >>= 7;
+
+ if (ret != (int8_t)ret) {
+ ret = (ret < 0 ? INT8_MIN : INT8_MAX);
+ }
+ return ret;
+}
+
+void HELPER(sve2_sqrdmlah_b)(void *vd, void *vn, void *vm,
+ void *va, uint32_t desc)
+{
+ intptr_t i, opr_sz = simd_oprsz(desc);
+ int8_t *d = vd, *n = vn, *m = vm, *a = va;
+
+ for (i = 0; i < opr_sz; ++i) {
+ d[i] = do_sqrdmlah_b(n[i], m[i], a[i], false, true);
+ }
+}
+
+void HELPER(sve2_sqrdmlsh_b)(void *vd, void *vn, void *vm,
+ void *va, uint32_t desc)
+{
+ intptr_t i, opr_sz = simd_oprsz(desc);
+ int8_t *d = vd, *n = vn, *m = vm, *a = va;
+
+ for (i = 0; i < opr_sz; ++i) {
+ d[i] = do_sqrdmlah_b(n[i], m[i], a[i], true, true);
+ }
+}
+
+void HELPER(sve2_sqdmulh_b)(void *vd, void *vn, void *vm, uint32_t desc)
+{
+ intptr_t i, opr_sz = simd_oprsz(desc);
+ int8_t *d = vd, *n = vn, *m = vm;
+
+ for (i = 0; i < opr_sz; ++i) {
+ d[i] = do_sqrdmlah_b(n[i], m[i], 0, false, false);
+ }
+}
+
+void HELPER(sve2_sqrdmulh_b)(void *vd, void *vn, void *vm, uint32_t desc)
+{
+ intptr_t i, opr_sz = simd_oprsz(desc);
+ int8_t *d = vd, *n = vn, *m = vm;
+
+ for (i = 0; i < opr_sz; ++i) {
+ d[i] = do_sqrdmlah_b(n[i], m[i], 0, false, true);
+ }
+}
+
+/* Signed saturating rounding doubling multiply-accumulate high half, 16-bit */
+int16_t do_sqrdmlah_h(int16_t src1, int16_t src2, int16_t src3,
+ bool neg, bool round, uint32_t *sat)
+{
+ /* Simplify similarly to do_sqrdmlah_b above. */
+ int32_t ret = (int32_t)src1 * src2;
+ if (neg) {
+ ret = -ret;
+ }
ret += ((int32_t)src3 << 15) + (round << 14);
ret >>= 15;
@@ -133,11 +182,87 @@ void HELPER(neon_sqrdmulh_h)(void *vd, void *vn, void *vm,
clear_tail(d, opr_sz, simd_maxsz(desc));
}
+void HELPER(sve2_sqrdmlah_h)(void *vd, void *vn, void *vm,
+ void *va, uint32_t desc)
+{
+ intptr_t i, opr_sz = simd_oprsz(desc);
+ int16_t *d = vd, *n = vn, *m = vm, *a = va;
+ uint32_t discard;
+
+ for (i = 0; i < opr_sz / 2; ++i) {
+ d[i] = do_sqrdmlah_h(n[i], m[i], a[i], false, true, &discard);
+ }
+}
+
+void HELPER(sve2_sqrdmlsh_h)(void *vd, void *vn, void *vm,
+ void *va, uint32_t desc)
+{
+ intptr_t i, opr_sz = simd_oprsz(desc);
+ int16_t *d = vd, *n = vn, *m = vm, *a = va;
+ uint32_t discard;
+
+ for (i = 0; i < opr_sz / 2; ++i) {
+ d[i] = do_sqrdmlah_h(n[i], m[i], a[i], true, true, &discard);
+ }
+}
+
+void HELPER(sve2_sqdmulh_h)(void *vd, void *vn, void *vm, uint32_t desc)
+{
+ intptr_t i, opr_sz = simd_oprsz(desc);
+ int16_t *d = vd, *n = vn, *m = vm;
+ uint32_t discard;
+
+ for (i = 0; i < opr_sz / 2; ++i) {
+ d[i] = do_sqrdmlah_h(n[i], m[i], 0, false, false, &discard);
+ }
+}
+
+void HELPER(sve2_sqrdmulh_h)(void *vd, void *vn, void *vm, uint32_t desc)
+{
+ intptr_t i, opr_sz = simd_oprsz(desc);
+ int16_t *d = vd, *n = vn, *m = vm;
+ uint32_t discard;
+
+ for (i = 0; i < opr_sz / 2; ++i) {
+ d[i] = do_sqrdmlah_h(n[i], m[i], 0, false, true, &discard);
+ }
+}
+
+void HELPER(sve2_sqdmulh_idx_h)(void *vd, void *vn, void *vm, uint32_t desc)
+{
+ intptr_t i, j, opr_sz = simd_oprsz(desc);
+ int idx = simd_data(desc);
+ int16_t *d = vd, *n = vn, *m = (int16_t *)vm + H2(idx);
+ uint32_t discard;
+
+ for (i = 0; i < opr_sz / 2; i += 16 / 2) {
+ int16_t mm = m[i];
+ for (j = 0; j < 16 / 2; ++j) {
+ d[i + j] = do_sqrdmlah_h(n[i + j], mm, 0, false, false, &discard);
+ }
+ }
+}
+
+void HELPER(sve2_sqrdmulh_idx_h)(void *vd, void *vn, void *vm, uint32_t desc)
+{
+ intptr_t i, j, opr_sz = simd_oprsz(desc);
+ int idx = simd_data(desc);
+ int16_t *d = vd, *n = vn, *m = (int16_t *)vm + H2(idx);
+ uint32_t discard;
+
+ for (i = 0; i < opr_sz / 2; i += 16 / 2) {
+ int16_t mm = m[i];
+ for (j = 0; j < 16 / 2; ++j) {
+ d[i + j] = do_sqrdmlah_h(n[i + j], mm, 0, false, true, &discard);
+ }
+ }
+}
+
/* Signed saturating rounding doubling multiply-accumulate high half, 32-bit */
-static int32_t do_sqrdmlah_s(int32_t src1, int32_t src2, int32_t src3,
- bool neg, bool round, uint32_t *sat)
+int32_t do_sqrdmlah_s(int32_t src1, int32_t src2, int32_t src3,
+ bool neg, bool round, uint32_t *sat)
{
- /* Simplify similarly to int_qrdmlah_s16 above. */
+ /* Simplify similarly to do_sqrdmlah_b above. */
int64_t ret = (int64_t)src1 * src2;
if (neg) {
ret = -ret;
@@ -220,197 +345,253 @@ void HELPER(neon_sqrdmulh_s)(void *vd, void *vn, void *vm,
clear_tail(d, opr_sz, simd_maxsz(desc));
}
-/* Integer 8 and 16-bit dot-product.
- *
- * Note that for the loops herein, host endianness does not matter
- * with respect to the ordering of data within the 64-bit lanes.
- * All elements are treated equally, no matter where they are.
- */
-
-void HELPER(gvec_sdot_b)(void *vd, void *vn, void *vm, uint32_t desc)
+void HELPER(sve2_sqrdmlah_s)(void *vd, void *vn, void *vm,
+ void *va, uint32_t desc)
{
intptr_t i, opr_sz = simd_oprsz(desc);
- uint32_t *d = vd;
- int8_t *n = vn, *m = vm;
+ int32_t *d = vd, *n = vn, *m = vm, *a = va;
+ uint32_t discard;
for (i = 0; i < opr_sz / 4; ++i) {
- d[i] += n[i * 4 + 0] * m[i * 4 + 0]
- + n[i * 4 + 1] * m[i * 4 + 1]
- + n[i * 4 + 2] * m[i * 4 + 2]
- + n[i * 4 + 3] * m[i * 4 + 3];
+ d[i] = do_sqrdmlah_s(n[i], m[i], a[i], false, true, &discard);
}
- clear_tail(d, opr_sz, simd_maxsz(desc));
}
-void HELPER(gvec_udot_b)(void *vd, void *vn, void *vm, uint32_t desc)
+void HELPER(sve2_sqrdmlsh_s)(void *vd, void *vn, void *vm,
+ void *va, uint32_t desc)
{
intptr_t i, opr_sz = simd_oprsz(desc);
- uint32_t *d = vd;
- uint8_t *n = vn, *m = vm;
+ int32_t *d = vd, *n = vn, *m = vm, *a = va;
+ uint32_t discard;
for (i = 0; i < opr_sz / 4; ++i) {
- d[i] += n[i * 4 + 0] * m[i * 4 + 0]
- + n[i * 4 + 1] * m[i * 4 + 1]
- + n[i * 4 + 2] * m[i * 4 + 2]
- + n[i * 4 + 3] * m[i * 4 + 3];
+ d[i] = do_sqrdmlah_s(n[i], m[i], a[i], true, true, &discard);
}
- clear_tail(d, opr_sz, simd_maxsz(desc));
}
-void HELPER(gvec_sdot_h)(void *vd, void *vn, void *vm, uint32_t desc)
+void HELPER(sve2_sqdmulh_s)(void *vd, void *vn, void *vm, uint32_t desc)
{
intptr_t i, opr_sz = simd_oprsz(desc);
- uint64_t *d = vd;
- int16_t *n = vn, *m = vm;
+ int32_t *d = vd, *n = vn, *m = vm;
+ uint32_t discard;
- for (i = 0; i < opr_sz / 8; ++i) {
- d[i] += (int64_t)n[i * 4 + 0] * m[i * 4 + 0]
- + (int64_t)n[i * 4 + 1] * m[i * 4 + 1]
- + (int64_t)n[i * 4 + 2] * m[i * 4 + 2]
- + (int64_t)n[i * 4 + 3] * m[i * 4 + 3];
+ for (i = 0; i < opr_sz / 4; ++i) {
+ d[i] = do_sqrdmlah_s(n[i], m[i], 0, false, false, &discard);
}
- clear_tail(d, opr_sz, simd_maxsz(desc));
}
-void HELPER(gvec_udot_h)(void *vd, void *vn, void *vm, uint32_t desc)
+void HELPER(sve2_sqrdmulh_s)(void *vd, void *vn, void *vm, uint32_t desc)
{
intptr_t i, opr_sz = simd_oprsz(desc);
- uint64_t *d = vd;
- uint16_t *n = vn, *m = vm;
+ int32_t *d = vd, *n = vn, *m = vm;
+ uint32_t discard;
- for (i = 0; i < opr_sz / 8; ++i) {
- d[i] += (uint64_t)n[i * 4 + 0] * m[i * 4 + 0]
- + (uint64_t)n[i * 4 + 1] * m[i * 4 + 1]
- + (uint64_t)n[i * 4 + 2] * m[i * 4 + 2]
- + (uint64_t)n[i * 4 + 3] * m[i * 4 + 3];
+ for (i = 0; i < opr_sz / 4; ++i) {
+ d[i] = do_sqrdmlah_s(n[i], m[i], 0, false, true, &discard);
}
- clear_tail(d, opr_sz, simd_maxsz(desc));
}
-void HELPER(gvec_sdot_idx_b)(void *vd, void *vn, void *vm, uint32_t desc)
+void HELPER(sve2_sqdmulh_idx_s)(void *vd, void *vn, void *vm, uint32_t desc)
{
- intptr_t i, segend, opr_sz = simd_oprsz(desc), opr_sz_4 = opr_sz / 4;
- intptr_t index = simd_data(desc);
- uint32_t *d = vd;
- int8_t *n = vn;
- int8_t *m_indexed = (int8_t *)vm + H4(index) * 4;
+ intptr_t i, j, opr_sz = simd_oprsz(desc);
+ int idx = simd_data(desc);
+ int32_t *d = vd, *n = vn, *m = (int32_t *)vm + H4(idx);
+ uint32_t discard;
+
+ for (i = 0; i < opr_sz / 4; i += 16 / 4) {
+ int32_t mm = m[i];
+ for (j = 0; j < 16 / 4; ++j) {
+ d[i + j] = do_sqrdmlah_s(n[i + j], mm, 0, false, false, &discard);
+ }
+ }
+}
- /* Notice the special case of opr_sz == 8, from aa64/aa32 advsimd.
- * Otherwise opr_sz is a multiple of 16.
- */
- segend = MIN(4, opr_sz_4);
- i = 0;
- do {
- int8_t m0 = m_indexed[i * 4 + 0];
- int8_t m1 = m_indexed[i * 4 + 1];
- int8_t m2 = m_indexed[i * 4 + 2];
- int8_t m3 = m_indexed[i * 4 + 3];
-
- do {
- d[i] += n[i * 4 + 0] * m0
- + n[i * 4 + 1] * m1
- + n[i * 4 + 2] * m2
- + n[i * 4 + 3] * m3;
- } while (++i < segend);
- segend = i + 4;
- } while (i < opr_sz_4);
+void HELPER(sve2_sqrdmulh_idx_s)(void *vd, void *vn, void *vm, uint32_t desc)
+{
+ intptr_t i, j, opr_sz = simd_oprsz(desc);
+ int idx = simd_data(desc);
+ int32_t *d = vd, *n = vn, *m = (int32_t *)vm + H4(idx);
+ uint32_t discard;
+
+ for (i = 0; i < opr_sz / 4; i += 16 / 4) {
+ int32_t mm = m[i];
+ for (j = 0; j < 16 / 4; ++j) {
+ d[i + j] = do_sqrdmlah_s(n[i + j], mm, 0, false, true, &discard);
+ }
+ }
+}
- clear_tail(d, opr_sz, simd_maxsz(desc));
+/* Signed saturating rounding doubling multiply-accumulate high half, 64-bit */
+static int64_t do_sat128_d(Int128 r)
+{
+ int64_t ls = int128_getlo(r);
+ int64_t hs = int128_gethi(r);
+
+ if (unlikely(hs != (ls >> 63))) {
+ return hs < 0 ? INT64_MIN : INT64_MAX;
+ }
+ return ls;
}
-void HELPER(gvec_udot_idx_b)(void *vd, void *vn, void *vm, uint32_t desc)
+int64_t do_sqrdmlah_d(int64_t n, int64_t m, int64_t a, bool neg, bool round)
{
- intptr_t i, segend, opr_sz = simd_oprsz(desc), opr_sz_4 = opr_sz / 4;
- intptr_t index = simd_data(desc);
- uint32_t *d = vd;
- uint8_t *n = vn;
- uint8_t *m_indexed = (uint8_t *)vm + H4(index) * 4;
+ uint64_t l, h;
+ Int128 r, t;
- /* Notice the special case of opr_sz == 8, from aa64/aa32 advsimd.
- * Otherwise opr_sz is a multiple of 16.
- */
- segend = MIN(4, opr_sz_4);
- i = 0;
- do {
- uint8_t m0 = m_indexed[i * 4 + 0];
- uint8_t m1 = m_indexed[i * 4 + 1];
- uint8_t m2 = m_indexed[i * 4 + 2];
- uint8_t m3 = m_indexed[i * 4 + 3];
-
- do {
- d[i] += n[i * 4 + 0] * m0
- + n[i * 4 + 1] * m1
- + n[i * 4 + 2] * m2
- + n[i * 4 + 3] * m3;
- } while (++i < segend);
- segend = i + 4;
- } while (i < opr_sz_4);
+ /* As in do_sqrdmlah_b, but with 128-bit arithmetic. */
+ muls64(&l, &h, m, n);
+ r = int128_make128(l, h);
+ if (neg) {
+ r = int128_neg(r);
+ }
+ if (a) {
+ t = int128_exts64(a);
+ t = int128_lshift(t, 63);
+ r = int128_add(r, t);
+ }
+ if (round) {
+ t = int128_exts64(1ll << 62);
+ r = int128_add(r, t);
+ }
+ r = int128_rshift(r, 63);
- clear_tail(d, opr_sz, simd_maxsz(desc));
+ return do_sat128_d(r);
}
-void HELPER(gvec_sdot_idx_h)(void *vd, void *vn, void *vm, uint32_t desc)
+void HELPER(sve2_sqrdmlah_d)(void *vd, void *vn, void *vm,
+ void *va, uint32_t desc)
{
- intptr_t i, opr_sz = simd_oprsz(desc), opr_sz_8 = opr_sz / 8;
- intptr_t index = simd_data(desc);
- uint64_t *d = vd;
- int16_t *n = vn;
- int16_t *m_indexed = (int16_t *)vm + index * 4;
+ intptr_t i, opr_sz = simd_oprsz(desc);
+ int64_t *d = vd, *n = vn, *m = vm, *a = va;
- /* This is supported by SVE only, so opr_sz is always a multiple of 16.
- * Process the entire segment all at once, writing back the results
- * only after we've consumed all of the inputs.
- */
- for (i = 0; i < opr_sz_8 ; i += 2) {
- uint64_t d0, d1;
+ for (i = 0; i < opr_sz / 8; ++i) {
+ d[i] = do_sqrdmlah_d(n[i], m[i], a[i], false, true);
+ }
+}
- d0 = n[i * 4 + 0] * (int64_t)m_indexed[i * 4 + 0];
- d0 += n[i * 4 + 1] * (int64_t)m_indexed[i * 4 + 1];
- d0 += n[i * 4 + 2] * (int64_t)m_indexed[i * 4 + 2];
- d0 += n[i * 4 + 3] * (int64_t)m_indexed[i * 4 + 3];
- d1 = n[i * 4 + 4] * (int64_t)m_indexed[i * 4 + 0];
- d1 += n[i * 4 + 5] * (int64_t)m_indexed[i * 4 + 1];
- d1 += n[i * 4 + 6] * (int64_t)m_indexed[i * 4 + 2];
- d1 += n[i * 4 + 7] * (int64_t)m_indexed[i * 4 + 3];
+void HELPER(sve2_sqrdmlsh_d)(void *vd, void *vn, void *vm,
+ void *va, uint32_t desc)
+{
+ intptr_t i, opr_sz = simd_oprsz(desc);
+ int64_t *d = vd, *n = vn, *m = vm, *a = va;
- d[i + 0] += d0;
- d[i + 1] += d1;
+ for (i = 0; i < opr_sz / 8; ++i) {
+ d[i] = do_sqrdmlah_d(n[i], m[i], a[i], true, true);
}
+}
- clear_tail(d, opr_sz, simd_maxsz(desc));
+void HELPER(sve2_sqdmulh_d)(void *vd, void *vn, void *vm, uint32_t desc)
+{
+ intptr_t i, opr_sz = simd_oprsz(desc);
+ int64_t *d = vd, *n = vn, *m = vm;
+
+ for (i = 0; i < opr_sz / 8; ++i) {
+ d[i] = do_sqrdmlah_d(n[i], m[i], 0, false, false);
+ }
}
-void HELPER(gvec_udot_idx_h)(void *vd, void *vn, void *vm, uint32_t desc)
+void HELPER(sve2_sqrdmulh_d)(void *vd, void *vn, void *vm, uint32_t desc)
{
- intptr_t i, opr_sz = simd_oprsz(desc), opr_sz_8 = opr_sz / 8;
- intptr_t index = simd_data(desc);
- uint64_t *d = vd;
- uint16_t *n = vn;
- uint16_t *m_indexed = (uint16_t *)vm + index * 4;
+ intptr_t i, opr_sz = simd_oprsz(desc);
+ int64_t *d = vd, *n = vn, *m = vm;
- /* This is supported by SVE only, so opr_sz is always a multiple of 16.
- * Process the entire segment all at once, writing back the results
- * only after we've consumed all of the inputs.
- */
- for (i = 0; i < opr_sz_8 ; i += 2) {
- uint64_t d0, d1;
+ for (i = 0; i < opr_sz / 8; ++i) {
+ d[i] = do_sqrdmlah_d(n[i], m[i], 0, false, true);
+ }
+}
- d0 = n[i * 4 + 0] * (uint64_t)m_indexed[i * 4 + 0];
- d0 += n[i * 4 + 1] * (uint64_t)m_indexed[i * 4 + 1];
- d0 += n[i * 4 + 2] * (uint64_t)m_indexed[i * 4 + 2];
- d0 += n[i * 4 + 3] * (uint64_t)m_indexed[i * 4 + 3];
- d1 = n[i * 4 + 4] * (uint64_t)m_indexed[i * 4 + 0];
- d1 += n[i * 4 + 5] * (uint64_t)m_indexed[i * 4 + 1];
- d1 += n[i * 4 + 6] * (uint64_t)m_indexed[i * 4 + 2];
- d1 += n[i * 4 + 7] * (uint64_t)m_indexed[i * 4 + 3];
+void HELPER(sve2_sqdmulh_idx_d)(void *vd, void *vn, void *vm, uint32_t desc)
+{
+ intptr_t i, j, opr_sz = simd_oprsz(desc);
+ int idx = simd_data(desc);
+ int64_t *d = vd, *n = vn, *m = (int64_t *)vm + idx;
- d[i + 0] += d0;
- d[i + 1] += d1;
+ for (i = 0; i < opr_sz / 8; i += 16 / 8) {
+ int64_t mm = m[i];
+ for (j = 0; j < 16 / 8; ++j) {
+ d[i + j] = do_sqrdmlah_d(n[i + j], mm, 0, false, false);
+ }
}
+}
- clear_tail(d, opr_sz, simd_maxsz(desc));
+void HELPER(sve2_sqrdmulh_idx_d)(void *vd, void *vn, void *vm, uint32_t desc)
+{
+ intptr_t i, j, opr_sz = simd_oprsz(desc);
+ int idx = simd_data(desc);
+ int64_t *d = vd, *n = vn, *m = (int64_t *)vm + idx;
+
+ for (i = 0; i < opr_sz / 8; i += 16 / 8) {
+ int64_t mm = m[i];
+ for (j = 0; j < 16 / 8; ++j) {
+ d[i + j] = do_sqrdmlah_d(n[i + j], mm, 0, false, true);
+ }
+ }
}
+/* Integer 8 and 16-bit dot-product.
+ *
+ * Note that for the loops herein, host endianness does not matter
+ * with respect to the ordering of data within the quad-width lanes.
+ * All elements are treated equally, no matter where they are.
+ */
+
+#define DO_DOT(NAME, TYPED, TYPEN, TYPEM) \
+void HELPER(NAME)(void *vd, void *vn, void *vm, void *va, uint32_t desc) \
+{ \
+ intptr_t i, opr_sz = simd_oprsz(desc); \
+ TYPED *d = vd, *a = va; \
+ TYPEN *n = vn; \
+ TYPEM *m = vm; \
+ for (i = 0; i < opr_sz / sizeof(TYPED); ++i) { \
+ d[i] = (a[i] + \
+ (TYPED)n[i * 4 + 0] * m[i * 4 + 0] + \
+ (TYPED)n[i * 4 + 1] * m[i * 4 + 1] + \
+ (TYPED)n[i * 4 + 2] * m[i * 4 + 2] + \
+ (TYPED)n[i * 4 + 3] * m[i * 4 + 3]); \
+ } \
+ clear_tail(d, opr_sz, simd_maxsz(desc)); \
+}
+
+DO_DOT(gvec_sdot_b, int32_t, int8_t, int8_t)
+DO_DOT(gvec_udot_b, uint32_t, uint8_t, uint8_t)
+DO_DOT(gvec_usdot_b, uint32_t, uint8_t, int8_t)
+DO_DOT(gvec_sdot_h, int64_t, int16_t, int16_t)
+DO_DOT(gvec_udot_h, uint64_t, uint16_t, uint16_t)
+
+#define DO_DOT_IDX(NAME, TYPED, TYPEN, TYPEM, HD) \
+void HELPER(NAME)(void *vd, void *vn, void *vm, void *va, uint32_t desc) \
+{ \
+ intptr_t i = 0, opr_sz = simd_oprsz(desc); \
+ intptr_t opr_sz_n = opr_sz / sizeof(TYPED); \
+ intptr_t segend = MIN(16 / sizeof(TYPED), opr_sz_n); \
+ intptr_t index = simd_data(desc); \
+ TYPED *d = vd, *a = va; \
+ TYPEN *n = vn; \
+ TYPEM *m_indexed = (TYPEM *)vm + HD(index) * 4; \
+ do { \
+ TYPED m0 = m_indexed[i * 4 + 0]; \
+ TYPED m1 = m_indexed[i * 4 + 1]; \
+ TYPED m2 = m_indexed[i * 4 + 2]; \
+ TYPED m3 = m_indexed[i * 4 + 3]; \
+ do { \
+ d[i] = (a[i] + \
+ n[i * 4 + 0] * m0 + \
+ n[i * 4 + 1] * m1 + \
+ n[i * 4 + 2] * m2 + \
+ n[i * 4 + 3] * m3); \
+ } while (++i < segend); \
+ segend = i + 4; \
+ } while (i < opr_sz_n); \
+ clear_tail(d, opr_sz, simd_maxsz(desc)); \
+}
+
+DO_DOT_IDX(gvec_sdot_idx_b, int32_t, int8_t, int8_t, H4)
+DO_DOT_IDX(gvec_udot_idx_b, uint32_t, uint8_t, uint8_t, H4)
+DO_DOT_IDX(gvec_sudot_idx_b, int32_t, int8_t, uint8_t, H4)
+DO_DOT_IDX(gvec_usdot_idx_b, int32_t, uint8_t, int8_t, H4)
+DO_DOT_IDX(gvec_sdot_idx_h, int64_t, int16_t, int16_t, )
+DO_DOT_IDX(gvec_udot_idx_h, uint64_t, uint16_t, uint16_t, )
+
void HELPER(gvec_fcaddh)(void *vd, void *vn, void *vm,
void *vfpst, uint32_t desc)
{
@@ -495,13 +676,11 @@ void HELPER(gvec_fcaddd)(void *vd, void *vn, void *vm,
clear_tail(d, opr_sz, simd_maxsz(desc));
}
-void HELPER(gvec_fcmlah)(void *vd, void *vn, void *vm,
+void HELPER(gvec_fcmlah)(void *vd, void *vn, void *vm, void *va,
void *vfpst, uint32_t desc)
{
uintptr_t opr_sz = simd_oprsz(desc);
- float16 *d = vd;
- float16 *n = vn;
- float16 *m = vm;
+ float16 *d = vd, *n = vn, *m = vm, *a = va;
float_status *fpst = vfpst;
intptr_t flip = extract32(desc, SIMD_DATA_SHIFT, 1);
uint32_t neg_imag = extract32(desc, SIMD_DATA_SHIFT + 1, 1);
@@ -518,19 +697,17 @@ void HELPER(gvec_fcmlah)(void *vd, void *vn, void *vm,
float16 e4 = e2;
float16 e3 = m[H2(i + 1 - flip)] ^ neg_imag;
- d[H2(i)] = float16_muladd(e2, e1, d[H2(i)], 0, fpst);
- d[H2(i + 1)] = float16_muladd(e4, e3, d[H2(i + 1)], 0, fpst);
+ d[H2(i)] = float16_muladd(e2, e1, a[H2(i)], 0, fpst);
+ d[H2(i + 1)] = float16_muladd(e4, e3, a[H2(i + 1)], 0, fpst);
}
clear_tail(d, opr_sz, simd_maxsz(desc));
}
-void HELPER(gvec_fcmlah_idx)(void *vd, void *vn, void *vm,
+void HELPER(gvec_fcmlah_idx)(void *vd, void *vn, void *vm, void *va,
void *vfpst, uint32_t desc)
{
uintptr_t opr_sz = simd_oprsz(desc);
- float16 *d = vd;
- float16 *n = vn;
- float16 *m = vm;
+ float16 *d = vd, *n = vn, *m = vm, *a = va;
float_status *fpst = vfpst;
intptr_t flip = extract32(desc, SIMD_DATA_SHIFT, 1);
uint32_t neg_imag = extract32(desc, SIMD_DATA_SHIFT + 1, 1);
@@ -554,20 +731,18 @@ void HELPER(gvec_fcmlah_idx)(void *vd, void *vn, void *vm,
float16 e2 = n[H2(j + flip)];
float16 e4 = e2;
- d[H2(j)] = float16_muladd(e2, e1, d[H2(j)], 0, fpst);
- d[H2(j + 1)] = float16_muladd(e4, e3, d[H2(j + 1)], 0, fpst);
+ d[H2(j)] = float16_muladd(e2, e1, a[H2(j)], 0, fpst);
+ d[H2(j + 1)] = float16_muladd(e4, e3, a[H2(j + 1)], 0, fpst);
}
}
clear_tail(d, opr_sz, simd_maxsz(desc));
}
-void HELPER(gvec_fcmlas)(void *vd, void *vn, void *vm,
+void HELPER(gvec_fcmlas)(void *vd, void *vn, void *vm, void *va,
void *vfpst, uint32_t desc)
{
uintptr_t opr_sz = simd_oprsz(desc);
- float32 *d = vd;
- float32 *n = vn;
- float32 *m = vm;
+ float32 *d = vd, *n = vn, *m = vm, *a = va;
float_status *fpst = vfpst;
intptr_t flip = extract32(desc, SIMD_DATA_SHIFT, 1);
uint32_t neg_imag = extract32(desc, SIMD_DATA_SHIFT + 1, 1);
@@ -584,19 +759,17 @@ void HELPER(gvec_fcmlas)(void *vd, void *vn, void *vm,
float32 e4 = e2;
float32 e3 = m[H4(i + 1 - flip)] ^ neg_imag;
- d[H4(i)] = float32_muladd(e2, e1, d[H4(i)], 0, fpst);
- d[H4(i + 1)] = float32_muladd(e4, e3, d[H4(i + 1)], 0, fpst);
+ d[H4(i)] = float32_muladd(e2, e1, a[H4(i)], 0, fpst);
+ d[H4(i + 1)] = float32_muladd(e4, e3, a[H4(i + 1)], 0, fpst);
}
clear_tail(d, opr_sz, simd_maxsz(desc));
}
-void HELPER(gvec_fcmlas_idx)(void *vd, void *vn, void *vm,
+void HELPER(gvec_fcmlas_idx)(void *vd, void *vn, void *vm, void *va,
void *vfpst, uint32_t desc)
{
uintptr_t opr_sz = simd_oprsz(desc);
- float32 *d = vd;
- float32 *n = vn;
- float32 *m = vm;
+ float32 *d = vd, *n = vn, *m = vm, *a = va;
float_status *fpst = vfpst;
intptr_t flip = extract32(desc, SIMD_DATA_SHIFT, 1);
uint32_t neg_imag = extract32(desc, SIMD_DATA_SHIFT + 1, 1);
@@ -620,20 +793,18 @@ void HELPER(gvec_fcmlas_idx)(void *vd, void *vn, void *vm,
float32 e2 = n[H4(j + flip)];
float32 e4 = e2;
- d[H4(j)] = float32_muladd(e2, e1, d[H4(j)], 0, fpst);
- d[H4(j + 1)] = float32_muladd(e4, e3, d[H4(j + 1)], 0, fpst);
+ d[H4(j)] = float32_muladd(e2, e1, a[H4(j)], 0, fpst);
+ d[H4(j + 1)] = float32_muladd(e4, e3, a[H4(j + 1)], 0, fpst);
}
}
clear_tail(d, opr_sz, simd_maxsz(desc));
}
-void HELPER(gvec_fcmlad)(void *vd, void *vn, void *vm,
+void HELPER(gvec_fcmlad)(void *vd, void *vn, void *vm, void *va,
void *vfpst, uint32_t desc)
{
uintptr_t opr_sz = simd_oprsz(desc);
- float64 *d = vd;
- float64 *n = vn;
- float64 *m = vm;
+ float64 *d = vd, *n = vn, *m = vm, *a = va;
float_status *fpst = vfpst;
intptr_t flip = extract32(desc, SIMD_DATA_SHIFT, 1);
uint64_t neg_imag = extract32(desc, SIMD_DATA_SHIFT + 1, 1);
@@ -650,8 +821,8 @@ void HELPER(gvec_fcmlad)(void *vd, void *vn, void *vm,
float64 e4 = e2;
float64 e3 = m[i + 1 - flip] ^ neg_imag;
- d[i] = float64_muladd(e2, e1, d[i], 0, fpst);
- d[i + 1] = float64_muladd(e4, e3, d[i + 1], 0, fpst);
+ d[i] = float64_muladd(e2, e1, a[i], 0, fpst);
+ d[i + 1] = float64_muladd(e4, e3, a[i + 1], 0, fpst);
}
clear_tail(d, opr_sz, simd_maxsz(desc));
}
@@ -1497,6 +1668,27 @@ void HELPER(gvec_fmlal_a64)(void *vd, void *vn, void *vm,
get_flush_inputs_to_zero(&env->vfp.fp_status_f16));
}
+void HELPER(sve2_fmlal_zzzw_s)(void *vd, void *vn, void *vm, void *va,
+ void *venv, uint32_t desc)
+{
+ intptr_t i, oprsz = simd_oprsz(desc);
+ uint16_t negn = extract32(desc, SIMD_DATA_SHIFT, 1) << 15;
+ intptr_t sel = extract32(desc, SIMD_DATA_SHIFT + 1, 1) * sizeof(float16);
+ CPUARMState *env = venv;
+ float_status *status = &env->vfp.fp_status;
+ bool fz16 = get_flush_inputs_to_zero(&env->vfp.fp_status_f16);
+
+ for (i = 0; i < oprsz; i += sizeof(float32)) {
+ float16 nn_16 = *(float16 *)(vn + H1_2(i + sel)) ^ negn;
+ float16 mm_16 = *(float16 *)(vm + H1_2(i + sel));
+ float32 nn = float16_to_float32_by_bits(nn_16, fz16);
+ float32 mm = float16_to_float32_by_bits(mm_16, fz16);
+ float32 aa = *(float32 *)(va + H1_4(i));
+
+ *(float32 *)(vd + H1_4(i)) = float32_muladd(nn, mm, aa, 0, status);
+ }
+}
+
static void do_fmlal_idx(float32 *d, void *vn, void *vm, float_status *fpst,
uint32_t desc, bool fz16)
{
@@ -1541,6 +1733,32 @@ void HELPER(gvec_fmlal_idx_a64)(void *vd, void *vn, void *vm,
get_flush_inputs_to_zero(&env->vfp.fp_status_f16));
}
+void HELPER(sve2_fmlal_zzxw_s)(void *vd, void *vn, void *vm, void *va,
+ void *venv, uint32_t desc)
+{
+ intptr_t i, j, oprsz = simd_oprsz(desc);
+ uint16_t negn = extract32(desc, SIMD_DATA_SHIFT, 1) << 15;
+ intptr_t sel = extract32(desc, SIMD_DATA_SHIFT + 1, 1) * sizeof(float16);
+ intptr_t idx = extract32(desc, SIMD_DATA_SHIFT + 2, 3) * sizeof(float16);
+ CPUARMState *env = venv;
+ float_status *status = &env->vfp.fp_status;
+ bool fz16 = get_flush_inputs_to_zero(&env->vfp.fp_status_f16);
+
+ for (i = 0; i < oprsz; i += 16) {
+ float16 mm_16 = *(float16 *)(vm + i + idx);
+ float32 mm = float16_to_float32_by_bits(mm_16, fz16);
+
+ for (j = 0; j < 16; j += sizeof(float32)) {
+ float16 nn_16 = *(float16 *)(vn + H1_2(i + j + sel)) ^ negn;
+ float32 nn = float16_to_float32_by_bits(nn_16, fz16);
+ float32 aa = *(float32 *)(va + H1_4(i + j));
+
+ *(float32 *)(vd + H1_4(i + j)) =
+ float32_muladd(nn, mm, aa, 0, status);
+ }
+ }
+}
+
void HELPER(gvec_sshl_b)(void *vd, void *vn, void *vm, uint32_t desc)
{
intptr_t i, opr_sz = simd_oprsz(desc);
@@ -1750,6 +1968,30 @@ void HELPER(sve2_pmull_h)(void *vd, void *vn, void *vm, uint32_t desc)
d[i] = pmull_h(nn, mm);
}
}
+
+static uint64_t pmull_d(uint64_t op1, uint64_t op2)
+{
+ uint64_t result = 0;
+ int i;
+
+ for (i = 0; i < 32; ++i) {
+ uint64_t mask = -((op1 >> i) & 1);
+ result ^= (op2 << i) & mask;
+ }
+ return result;
+}
+
+void HELPER(sve2_pmull_d)(void *vd, void *vn, void *vm, uint32_t desc)
+{
+ intptr_t sel = H4(simd_data(desc));
+ intptr_t i, opr_sz = simd_oprsz(desc);
+ uint32_t *n = vn, *m = vm;
+ uint64_t *d = vd;
+
+ for (i = 0; i < opr_sz / 8; ++i) {
+ d[i] = pmull_d(n[2 * i + sel], m[2 * i + sel]);
+ }
+}
#endif
#define DO_CMP0(NAME, TYPE, OP) \
@@ -1985,3 +2227,188 @@ void HELPER(simd_tblx)(void *vd, void *vm, void *venv, uint32_t desc)
clear_tail(vd, oprsz, simd_maxsz(desc));
}
#endif
+
+/*
+ * NxN -> N highpart multiply
+ *
+ * TODO: expose this as a generic vector operation.
+ */
+
+void HELPER(gvec_smulh_b)(void *vd, void *vn, void *vm, uint32_t desc)
+{
+ intptr_t i, opr_sz = simd_oprsz(desc);
+ int8_t *d = vd, *n = vn, *m = vm;
+
+ for (i = 0; i < opr_sz; ++i) {
+ d[i] = ((int32_t)n[i] * m[i]) >> 8;
+ }
+ clear_tail(d, opr_sz, simd_maxsz(desc));
+}
+
+void HELPER(gvec_smulh_h)(void *vd, void *vn, void *vm, uint32_t desc)
+{
+ intptr_t i, opr_sz = simd_oprsz(desc);
+ int16_t *d = vd, *n = vn, *m = vm;
+
+ for (i = 0; i < opr_sz / 2; ++i) {
+ d[i] = ((int32_t)n[i] * m[i]) >> 16;
+ }
+ clear_tail(d, opr_sz, simd_maxsz(desc));
+}
+
+void HELPER(gvec_smulh_s)(void *vd, void *vn, void *vm, uint32_t desc)
+{
+ intptr_t i, opr_sz = simd_oprsz(desc);
+ int32_t *d = vd, *n = vn, *m = vm;
+
+ for (i = 0; i < opr_sz / 4; ++i) {
+ d[i] = ((int64_t)n[i] * m[i]) >> 32;
+ }
+ clear_tail(d, opr_sz, simd_maxsz(desc));
+}
+
+void HELPER(gvec_smulh_d)(void *vd, void *vn, void *vm, uint32_t desc)
+{
+ intptr_t i, opr_sz = simd_oprsz(desc);
+ uint64_t *d = vd, *n = vn, *m = vm;
+ uint64_t discard;
+
+ for (i = 0; i < opr_sz / 8; ++i) {
+ muls64(&discard, &d[i], n[i], m[i]);
+ }
+ clear_tail(d, opr_sz, simd_maxsz(desc));
+}
+
+void HELPER(gvec_umulh_b)(void *vd, void *vn, void *vm, uint32_t desc)
+{
+ intptr_t i, opr_sz = simd_oprsz(desc);
+ uint8_t *d = vd, *n = vn, *m = vm;
+
+ for (i = 0; i < opr_sz; ++i) {
+ d[i] = ((uint32_t)n[i] * m[i]) >> 8;
+ }
+ clear_tail(d, opr_sz, simd_maxsz(desc));
+}
+
+void HELPER(gvec_umulh_h)(void *vd, void *vn, void *vm, uint32_t desc)
+{
+ intptr_t i, opr_sz = simd_oprsz(desc);
+ uint16_t *d = vd, *n = vn, *m = vm;
+
+ for (i = 0; i < opr_sz / 2; ++i) {
+ d[i] = ((uint32_t)n[i] * m[i]) >> 16;
+ }
+ clear_tail(d, opr_sz, simd_maxsz(desc));
+}
+
+void HELPER(gvec_umulh_s)(void *vd, void *vn, void *vm, uint32_t desc)
+{
+ intptr_t i, opr_sz = simd_oprsz(desc);
+ uint32_t *d = vd, *n = vn, *m = vm;
+
+ for (i = 0; i < opr_sz / 4; ++i) {
+ d[i] = ((uint64_t)n[i] * m[i]) >> 32;
+ }
+ clear_tail(d, opr_sz, simd_maxsz(desc));
+}
+
+void HELPER(gvec_umulh_d)(void *vd, void *vn, void *vm, uint32_t desc)
+{
+ intptr_t i, opr_sz = simd_oprsz(desc);
+ uint64_t *d = vd, *n = vn, *m = vm;
+ uint64_t discard;
+
+ for (i = 0; i < opr_sz / 8; ++i) {
+ mulu64(&discard, &d[i], n[i], m[i]);
+ }
+ clear_tail(d, opr_sz, simd_maxsz(desc));
+}
+
+void HELPER(gvec_xar_d)(void *vd, void *vn, void *vm, uint32_t desc)
+{
+ intptr_t i, opr_sz = simd_oprsz(desc) / 8;
+ int shr = simd_data(desc);
+ uint64_t *d = vd, *n = vn, *m = vm;
+
+ for (i = 0; i < opr_sz; ++i) {
+ d[i] = ror64(n[i] ^ m[i], shr);
+ }
+ clear_tail(d, opr_sz * 8, simd_maxsz(desc));
+}
+
+/*
+ * Integer matrix-multiply accumulate
+ */
+
+static uint32_t do_smmla_b(uint32_t sum, void *vn, void *vm)
+{
+ int8_t *n = vn, *m = vm;
+
+ for (intptr_t k = 0; k < 8; ++k) {
+ sum += n[H1(k)] * m[H1(k)];
+ }
+ return sum;
+}
+
+static uint32_t do_ummla_b(uint32_t sum, void *vn, void *vm)
+{
+ uint8_t *n = vn, *m = vm;
+
+ for (intptr_t k = 0; k < 8; ++k) {
+ sum += n[H1(k)] * m[H1(k)];
+ }
+ return sum;
+}
+
+static uint32_t do_usmmla_b(uint32_t sum, void *vn, void *vm)
+{
+ uint8_t *n = vn;
+ int8_t *m = vm;
+
+ for (intptr_t k = 0; k < 8; ++k) {
+ sum += n[H1(k)] * m[H1(k)];
+ }
+ return sum;
+}
+
+static void do_mmla_b(void *vd, void *vn, void *vm, void *va, uint32_t desc,
+ uint32_t (*inner_loop)(uint32_t, void *, void *))
+{
+ intptr_t seg, opr_sz = simd_oprsz(desc);
+
+ for (seg = 0; seg < opr_sz; seg += 16) {
+ uint32_t *d = vd + seg;
+ uint32_t *a = va + seg;
+ uint32_t sum0, sum1, sum2, sum3;
+
+ /*
+ * Process the entire segment at once, writing back the
+ * results only after we've consumed all of the inputs.
+ *
+ * Key to indicies by column:
+ * i j i j
+ */
+ sum0 = a[H4(0 + 0)];
+ sum0 = inner_loop(sum0, vn + seg + 0, vm + seg + 0);
+ sum1 = a[H4(0 + 1)];
+ sum1 = inner_loop(sum1, vn + seg + 0, vm + seg + 8);
+ sum2 = a[H4(2 + 0)];
+ sum2 = inner_loop(sum2, vn + seg + 8, vm + seg + 0);
+ sum3 = a[H4(2 + 1)];
+ sum3 = inner_loop(sum3, vn + seg + 8, vm + seg + 8);
+
+ d[H4(0)] = sum0;
+ d[H4(1)] = sum1;
+ d[H4(2)] = sum2;
+ d[H4(3)] = sum3;
+ }
+ clear_tail(vd, opr_sz, simd_maxsz(desc));
+}
+
+#define DO_MMLA_B(NAME, INNER) \
+ void HELPER(NAME)(void *vd, void *vn, void *vm, void *va, uint32_t desc) \
+ { do_mmla_b(vd, vn, vm, va, desc, INNER); }
+
+DO_MMLA_B(gvec_smmla_b, do_smmla_b)
+DO_MMLA_B(gvec_ummla_b, do_ummla_b)
+DO_MMLA_B(gvec_usmmla_b, do_usmmla_b)