Add various SPU patterns

This commit is contained in:
RipleyTom 2023-05-11 20:36:39 +02:00 committed by Elad.Ash
parent a92b8acba7
commit 65d93c97ea
2 changed files with 293 additions and 31 deletions

View file

@ -43,6 +43,9 @@
#include <functional>
#include <unordered_map>
// Helper function, declared here and defined elsewhere.
// The match() comparisons pass both the candidate value and the stored value through it
// before testing them for equality.
// NOTE(review): presumably returns the underlying value with bitcasts stripped - confirm at the definition.
llvm::Value* peek_through_bitcasts(llvm::Value*);
// Scoped enumeration with `char` as its underlying type and no enumerators.
// NOTE(review): presumably a tag type standing in for a 2-bit integer - confirm against its users.
enum class i2 : char
{
};
@ -147,7 +150,7 @@ struct llvm_value_t
std::tuple<> match(llvm::Value*& value, llvm::Module*) const
{
if (value != this->value)
if (peek_through_bitcasts(value) != peek_through_bitcasts(this->value))
{
value = nullptr;
}
@ -503,9 +506,6 @@ using llvm_common_t = std::enable_if_t<(is_llvm_expr_of<T, Types>::ok && ...), t
template <typename... Args>
using llvm_match_tuple = decltype(std::tuple_cat(std::declval<llvm_expr_t<Args>&>().match(std::declval<llvm::Value*&>(), nullptr)...));
// Helper function
llvm::Value* peek_through_bitcasts(llvm::Value*);
template <typename T, typename U = llvm_common_t<llvm_value_t<T>>>
struct llvm_match_t
{
@ -532,7 +532,7 @@ struct llvm_match_t
std::tuple<> match(llvm::Value*& value, llvm::Module*) const
{
if (value != this->value)
if (peek_through_bitcasts(value) != peek_through_bitcasts(this->value))
{
value = nullptr;
}

View file

@ -5615,14 +5615,35 @@ public:
}
});
const auto [a, b] = get_vrs<f32[4]>(op.ra, op.rb);
if (op.ra == op.rb && !m_interp_magn)
{
const auto a = get_vr<f32[4]>(op.ra);
set_vr(op.rt, fm(a, a));
return;
}
const auto [a, b] = get_vrs<f32[4]>(op.ra, op.rb);
// Resistance 2 doesn't like this
if (g_cfg.core.spu_xfloat_accuracy == xfloat_accuracy::relaxed)
{
// FM(a, re_accurate(div))
if (const auto [ok_re_acc, div] = match_expr(b, re_accurate(match<f32[4]>())); ok_re_acc)
{
erase_stores(b);
set_vr(op.rt, a / div);
return;
}
// FM(re_accurate(div), b)
if (const auto [ok_re_acc, div] = match_expr(a, re_accurate(match<f32[4]>())); ok_re_acc)
{
erase_stores(a);
set_vr(op.rt, b / div);
return;
}
}
set_vr(op.rt, fm(a, b));
}
@ -5950,6 +5971,12 @@ public:
return {"spu_fma", {std::forward<T>(a), std::forward<U>(b), std::forward<V>(c)}};
}
// Builds an llvm_calli descriptor, typed f32[4], that names the "spu_re_acc" intrinsic
// and carries `a` (perfectly forwarded) as its only argument.
// The handler registered under this name in FMA evaluates it as fsplat(1.0f) / operand.
template <typename T>
static llvm_calli<f32[4], T> re_accurate(T&& a)
{
return {"spu_re_acc", {std::forward<T>(a)}};
}
void FMA(spu_opcode_t op)
{
// Hardware FMA produces the same result as multiple + add on the limited double range (xfloat).
@ -5980,39 +6007,208 @@ public:
}
});
const auto [a, b, c] = get_vrs<f32[4]>(op.ra, op.rb, op.rc);
register_intrinsic("spu_re_acc", [&](llvm::CallInst* ci)
{
const auto div = value<f32[4]>(ci->getOperand(0));
return fsplat<f32[4]>(1.0f) / div;
});
const auto [a, b, c] = get_vrs<f32[4]>(op.ra, op.rb, op.rc);
static const auto MT = match<f32[4]>();
// Match sqrt
if (auto [ok_fnma, a1, b1] = match_expr(a, fnms(MT, MT, fsplat<f32[4]>(1.00000011920928955078125))); ok_fnma)
auto check_sqrt_pattern_for_float = [&](f32 float_value) -> bool
{
if (auto [ok_fm2, a2] = match_expr(b, fm(MT, fsplat<f32[4]>(0.5))); ok_fm2 && a2.eq(b1))
auto match_fnms = [&](f32 float_value)
{
if (auto [ok_fm1, a3, b3] = match_expr(c, fm(MT, MT)); ok_fm1 && a3.eq(a1))
auto res = match_expr(a, fnms(MT, MT, fsplat<f32[4]>(float_value)));
if (std::get<0>(res))
return res;
return match_expr(b, fnms(MT, MT, fsplat<f32[4]>(float_value)));
};
auto match_fm_half = [&]()
{
auto res = match_expr(a, fm(MT, fsplat<f32[4]>(0.5)));
if (std::get<0>(res))
return res;
res = match_expr(a, fm(fsplat<f32[4]>(0.5), MT));
if (std::get<0>(res))
return res;
res = match_expr(b, fm(MT, fsplat<f32[4]>(0.5)));
if (std::get<0>(res))
return res;
return match_expr(b, fm(fsplat<f32[4]>(0.5), MT));
};
if (auto [ok_fnma, a1, b1] = match_fnms(float_value); ok_fnma)
{
if (auto [ok_fm2, fm_half_mul] = match_fm_half(); ok_fm2 && fm_half_mul.eq(b1))
{
if (auto [ok_sqrte, src] = match_expr(a3, spu_rsqrte(MT)); ok_sqrte && src.eq(b3))
if (fm_half_mul.eq(b1))
{
erase_stores(a, b, c, a3);
set_vr(op.rt4, fsqrt(fabs(src)));
if (auto [ok_fm1, a3, b3] = match_expr(c, fm(MT, MT)); ok_fm1 && a3.eq(a1))
{
if (auto [ok_sqrte, src] = match_expr(a3, spu_rsqrte(MT)); ok_sqrte && src.eq(b3))
{
erase_stores(a, b, c, a3);
set_vr(op.rt4, fsqrt(fabs(src)));
return true;
}
}
else if (auto [ok_fm1, a3, b3] = match_expr(c, fm(MT, MT)); ok_fm1 && b3.eq(a1))
{
if (auto [ok_sqrte, src] = match_expr(b3, spu_rsqrte(MT)); ok_sqrte && src.eq(a3))
{
erase_stores(a, b, c, b3);
set_vr(op.rt4, fsqrt(fabs(src)));
return true;
}
}
}
else if (fm_half_mul.eq(a1))
{
if (auto [ok_fm1, a3, b3] = match_expr(c, fm(MT, MT)); ok_fm1 && a3.eq(b1))
{
if (auto [ok_sqrte, src] = match_expr(a3, spu_rsqrte(MT)); ok_sqrte && src.eq(b3))
{
erase_stores(a, b, c, a3);
set_vr(op.rt4, fsqrt(fabs(src)));
return true;
}
}
else if (auto [ok_fm1, a3, b3] = match_expr(c, fm(MT, MT)); ok_fm1 && b3.eq(b1))
{
if (auto [ok_sqrte, src] = match_expr(b3, spu_rsqrte(MT)); ok_sqrte && src.eq(a3))
{
erase_stores(a, b, c, b3);
set_vr(op.rt4, fsqrt(fabs(src)));
return true;
}
}
}
}
}
return false;
};
if (check_sqrt_pattern_for_float(1.0f))
return;
if (check_sqrt_pattern_for_float(std::bit_cast<f32>(std::bit_cast<u32>(1.0f) + 1)))
return;
// Tries to recognize an FMA whose two multiplicands are spu_re(div) and
// FNMS(div, spu_re(div), float_value) (in either operand order, with the FNMS
// factors in either order) and whose addend register is the same register as
// the spu_re operand. On a match the stores of the spu_re value are erased and
// rt4 is set to re_accurate(div).
// NOTE(review): presumably this is one refinement step of the reciprocal estimate - confirm.
//
// float_value: the constant expected as the third FNMS operand.
// Returns true when a pattern matched (rt4 has been written and the caller must
// not emit the generic FMA), false otherwise.
auto check_accurate_reciprocal_pattern_for_float = [&](f32 float_value) -> bool
{
	// FMA(FNMS(div, spu_re(div), float_value), spu_re(div), spu_re(div))
	if (auto [ok_fnms, div] = match_expr(a, fnms(MT, b, fsplat<f32[4]>(float_value))); ok_fnms && op.rb == op.rc)
	{
		if (auto [ok_re] = match_expr(b, spu_re(div)); ok_re)
		{
			erase_stores(b);
			set_vr(op.rt4, re_accurate(div));
			return true;
		}
	}

	// FMA(FNMS(spu_re(div), div, float_value), spu_re(div), spu_re(div))
	if (auto [ok_fnms, div] = match_expr(a, fnms(b, MT, fsplat<f32[4]>(float_value))); ok_fnms && op.rb == op.rc)
	{
		if (auto [ok_re] = match_expr(b, spu_re(div)); ok_re)
		{
			erase_stores(b);
			set_vr(op.rt4, re_accurate(div));
			return true;
		}
	}

	// FMA(spu_re(div), FNMS(div, spu_re(div), float_value), spu_re(div))
	// The FNMS is the second FMA operand here, so it is `b` that must be matched;
	// matching `a` against an expression that has `a` as an operand can never succeed.
	if (auto [ok_fnms, div] = match_expr(b, fnms(MT, a, fsplat<f32[4]>(float_value))); ok_fnms && op.ra == op.rc)
	{
		if (auto [ok_re] = match_expr(a, spu_re(div)); ok_re)
		{
			erase_stores(a);
			set_vr(op.rt4, re_accurate(div));
			return true;
		}
	}

	// FMA(spu_re(div), FNMS(spu_re(div), div, float_value), spu_re(div))
	// Same as above with the FNMS factors swapped: match `b`, not `a`.
	if (auto [ok_fnms, div] = match_expr(b, fnms(a, MT, fsplat<f32[4]>(float_value))); ok_fnms && op.ra == op.rc)
	{
		if (auto [ok_re] = match_expr(a, spu_re(div)); ok_re)
		{
			erase_stores(a);
			set_vr(op.rt4, re_accurate(div));
			return true;
		}
	}

	return false;
};
if (check_accurate_reciprocal_pattern_for_float(1.0f))
return;
if (check_accurate_reciprocal_pattern_for_float(std::bit_cast<f32>(std::bit_cast<u32>(1.0f) + 1)))
return;
// NFS Most Wanted doesn't like this
if (g_cfg.core.spu_xfloat_accuracy == xfloat_accuracy::relaxed)
{
// Those patterns are not safe vs non optimization as inaccuracy from spu_re will spread with early fm before the accuracy is improved
// Match division (fast)
// FMA(FNMS(fm(diva<*> spu_re(divb)), divb, diva), spu_re(divb), fm(diva<*> spu_re(divb)))
if (auto [ok_fnma, divb, diva] = match_expr(a, fnms(c, MT, MT)); ok_fnma)
{
if (auto [ok_fm, fm1, fm2] = match_expr(c, fm(MT, MT)); ok_fm && ((fm1.eq(diva) && fm2.eq(b)) || (fm1.eq(b) && fm2.eq(diva))))
{
if (auto [ok_re] = match_expr(b, spu_re(divb)); ok_re)
{
erase_stores(b, c);
set_vr(op.rt4, diva / divb);
return;
}
}
}
// FMA(spu_re(divb), FNMS(fm(diva <*> spu_re(divb)), divb, diva), fm(diva <*> spu_re(divb)))
if (auto [ok_fnma, divb, diva] = match_expr(b, fnms(c, MT, MT)); ok_fnma)
{
if (auto [ok_fm, fm1, fm2] = match_expr(c, fm(MT, MT)); ok_fm && ((fm1.eq(diva) && fm2.eq(a)) || (fm1.eq(a) && fm2.eq(diva))))
{
if (auto [ok_re] = match_expr(a, spu_re(divb)); ok_re)
{
erase_stores(a, c);
set_vr(op.rt4, diva / divb);
return;
}
}
}
}
// Match division (fast)
if (auto [ok_fnma, divb, diva] = match_expr(a, fnms(c, MT, MT)); ok_fnma)
// Not all patterns can be simplified because of block scope
// Those todos don't necessarily imply a missing pattern
if (auto [ok_re, mystery] = match_expr(a, spu_re(MT)); ok_re)
{
if (auto [ok_fm] = match_expr(c, fm(diva, b)); ok_fm)
{
if (auto [ok_re] = match_expr(b, spu_re(divb)); ok_re)
{
erase_stores(b, c);
set_vr(op.rt4, diva / divb);
return;
}
}
spu_log.todo("[%s:0x%05x] Unmatched spu_re(a) found in FMA", m_hash, m_pos);
}
if (auto [ok_re, mystery] = match_expr(b, spu_re(MT)); ok_re)
{
spu_log.todo("[%s:0x%05x] Unmatched spu_re(b) found in FMA", m_hash, m_pos);
}
if (auto [ok_resq, mystery] = match_expr(c, spu_rsqrte(MT)); ok_resq)
{
spu_log.todo("[%s:0x%05x] Unmatched spu_rsqrte(c) found in FMA", m_hash, m_pos);
}
set_vr(op.rt4, fma(a, b, c));
@ -6090,7 +6286,67 @@ public:
const auto [a, b] = get_vrs<f32[4]>(op.ra, op.rb);
if (g_cfg.core.spu_xfloat_accuracy == xfloat_accuracy::relaxed)
switch (g_cfg.core.spu_xfloat_accuracy)
{
case xfloat_accuracy::approximate:
{
// For approximate, create a pattern but do not optimize yet
register_intrinsic("spu_re", [&](llvm::CallInst* ci)
{
const auto a = bitcast<u32[4]>(value<f32[4]>(ci->getOperand(0)));
const auto a_fraction = (a >> splat<u32[4]>(18)) & splat<u32[4]>(0x1F);
const auto a_exponent = (a >> splat<u32[4]>(23)) & splat<u32[4]>(0xFF);
const auto a_sign = (a & splat<u32[4]>(0x80000000));
value_t<u32[4]> b = eval(splat<u32[4]>(0));
for (u32 i = 0; i < 4; i++)
{
const auto eval_fraction = eval(extract(a_fraction, i));
const auto eval_exponent = eval(extract(a_exponent, i));
const auto eval_sign = eval(extract(a_sign, i));
value_t<u32> r_fraction = load_const<u32>(m_spu_frest_fraction_lut, eval_fraction);
value_t<u32> r_exponent = load_const<u32>(m_spu_frest_exponent_lut, eval_exponent);
b = eval(insert(b, i, eval(r_fraction | eval_sign | r_exponent)));
}
const auto base = (b & 0x007ffc00u) << 9; // Base fraction
const auto ymul = (b & 0x3ff) * (a & 0x7ffff); // Step fraction * Y fraction (fixed point at 2^-32)
const auto comparison = (ymul > base); // Should exponent be adjusted?
const auto bnew = (base - ymul) >> (zext<u32[4]>(comparison) ^ 9); // Shift one less bit if exponent is adjusted
const auto base_result = (b & 0xff800000u) | (bnew & ~0xff800000u); // Inject old sign and exponent
const auto adjustment = bitcast<u32[4]>(sext<s32[4]>(comparison)) & (1 << 23); // exponent adjustement for negative bnew
return bitcast<f32[4]>(base_result - adjustment);
});
register_intrinsic("spu_rsqrte", [&](llvm::CallInst* ci)
{
const auto a = bitcast<u32[4]>(value<f32[4]>(ci->getOperand(0)));
const auto a_fraction = (a >> splat<u32[4]>(18)) & splat<u32[4]>(0x3F);
const auto a_exponent = (a >> splat<u32[4]>(23)) & splat<u32[4]>(0xFF);
value_t<u32[4]> b = eval(splat<u32[4]>(0));
for (u32 i = 0; i < 4; i++)
{
const auto eval_fraction = eval(extract(a_fraction, i));
const auto eval_exponent = eval(extract(a_exponent, i));
value_t<u32> r_fraction = load_const<u32>(m_spu_frsqest_fraction_lut, eval_fraction);
value_t<u32> r_exponent = load_const<u32>(m_spu_frsqest_exponent_lut, eval_exponent);
b = eval(insert(b, i, eval(r_fraction | r_exponent)));
}
const auto base = (b & 0x007ffc00u) << 9; // Base fraction
const auto ymul = (b & 0x3ff) * (a & 0x7ffff); // Step fraction * Y fraction (fixed point at 2^-32)
const auto comparison = (ymul > base); // Should exponent be adjusted?
const auto bnew = (base - ymul) >> (zext<u32[4]>(comparison) ^ 9); // Shift one less bit if exponent is adjusted
const auto base_result = (b & 0xff800000u) | (bnew & ~0xff800000u); // Inject old sign and exponent
const auto adjustment = bitcast<u32[4]>(sext<s32[4]>(comparison)) & (1 << 23); // exponent adjustement for negative bnew
return bitcast<f32[4]>(base_result - adjustment);
});
break;
}
case xfloat_accuracy::relaxed:
{
// For relaxed, aggressively optimize and use intrinsics; those make the results vary per CPU
register_intrinsic("spu_re", [&](llvm::CallInst* ci)
@ -6104,7 +6360,15 @@ public:
const auto a = value<f32[4]>(ci->getOperand(0));
return frsqe(a);
});
break;
}
default:
break;
}
// Do not pattern match for accurate
if(g_cfg.core.spu_xfloat_accuracy == xfloat_accuracy::approximate || g_cfg.core.spu_xfloat_accuracy == xfloat_accuracy::relaxed)
{
if (const auto [ok, mb] = match_expr(b, frest(match<f32[4]>())); ok && mb.eq(a))
{
erase_stores(b);
@ -6120,11 +6384,9 @@ public:
}
}
// Do not optimize yet for approximate until we have a full accuracy sequence
const auto r = eval(fi(a, b));
// if (!m_interp_magn)
// spu_log.todo("[%s:0x%05x] Unmatched spu_fi found", m_hash, m_pos);
if (!m_interp_magn && g_cfg.core.spu_xfloat_accuracy != xfloat_accuracy::accurate)
spu_log.todo("[%s:0x%05x] Unmatched spu_fi found", m_hash, m_pos);
set_vr(op.rt, r);
}