LibGfx/JPEGXL: Add support for the self-correcting predictor

This predictor is much more complicated than the others. Indeed, to be
computed, it needs its own value but for other pixels. As you can guess,
implementing it involved the introduction of a structure to hold that
data.

Fundamentally, this predictor uses the value of the error between the
predicted value and the true value (aka decoded value) of pixels around.
One of this computed error (namely max_error) is used as a property, so
this patch also solves a FIXME in `get_properties`.

To ease the access to value that are close in the channel and moving
their values around, this patch adds a `Neighborhood` struct which holds
this data. It has been used in `prediction()` and it allowed to simplify
the signature and to remove the explicit retrieval of the underlying
data.

All this work allows us to decode the default image that appears when
loading `https://jxl-art.surma.technology`. However, we still render it
incorrectly due to the lack of support for orientation values different
from 1.
This commit is contained in:
Lucas CHOLLET 2023-07-30 23:08:54 -04:00 committed by Andreas Kling
parent a01a2f712e
commit cd9bb985d4

View file

@ -1071,10 +1071,7 @@ struct WPHeader {
u8 wp_p3c { 7 };
u8 wp_p3d { 0 };
u8 wp_p3e { 0 };
u8 wp_w0 { 13 };
u8 wp_w1 { 12 };
u8 wp_w2 { 12 };
u8 wp_w3 { 12 };
Array<u8, 4> wp_w { 13, 12, 12, 12 };
};
static ErrorOr<WPHeader> read_self_correcting_predictor(LittleEndianInputBitStream& stream)
@ -1272,6 +1269,177 @@ private:
};
///
/// H.5 - Self-correcting predictor
struct Neighborhood {
i32 N {};
i32 NW {};
i32 NE {};
i32 W {};
i32 NN {};
i32 WW {};
i32 NEE {};
};
class SelfCorrectingData {
public:
struct Predictions {
i32 prediction {};
Array<i32, 4> subpred {};
i32 max_error {};
i32 true_err {};
Array<i32, 4> err {};
};
static ErrorOr<SelfCorrectingData> create(WPHeader const& wp_params, u32 width)
{
SelfCorrectingData self_correcting_data { wp_params };
self_correcting_data.m_width = width;
self_correcting_data.m_previous = TRY(FixedArray<Predictions>::create(width));
self_correcting_data.m_current_row = TRY(FixedArray<Predictions>::create(width));
self_correcting_data.m_next_row = TRY(FixedArray<Predictions>::create(width));
return self_correcting_data;
}
void register_next_row()
{
auto tmp = move(m_previous);
m_previous = move(m_current_row);
m_current_row = move(m_next_row);
// We reuse m_previous to avoid an allocation, no values are kept
// everything will be overridden.
m_next_row = move(tmp);
m_current_row_index++;
}
Predictions compute_predictions(Neighborhood const& neighborhood, u32 x)
{
auto& current_predictions = m_next_row[x];
auto const N3 = neighborhood.N << 3;
auto const NW3 = neighborhood.NW << 3;
auto const NE3 = neighborhood.NE << 3;
auto const W3 = neighborhood.W << 3;
auto const NN3 = neighborhood.NN << 3;
auto const predictions_W = predictions_for(x, Direction::West);
auto const predictions_N = predictions_for(x, Direction::North);
auto const predictions_NE = predictions_for(x, Direction::NorthEast);
auto const predictions_NW = predictions_for(x, Direction::NorthWest);
auto const predictions_WW = predictions_for(x, Direction::WestWest);
current_predictions.subpred[0] = W3 + NE3 - N3;
current_predictions.subpred[1] = N3 - (((predictions_W.true_err + predictions_N.true_err + predictions_NE.true_err) * wp_params.wp_p1) >> 5);
current_predictions.subpred[2] = W3 - (((predictions_W.true_err + predictions_N.true_err + predictions_NW.true_err) * wp_params.wp_p2) >> 5);
current_predictions.subpred[3] = N3 - ((predictions_NW.true_err * wp_params.wp_p3a + predictions_N.true_err * wp_params.wp_p3b + predictions_NE.true_err * wp_params.wp_p3c + (NN3 - N3) * wp_params.wp_p3d + (NW3 - W3) * wp_params.wp_p3e) >> 5);
auto const error2weight = [](i32 err_sum, u8 maxweight) -> i32 {
i32 shift = floor(log2(err_sum + 1)) - 5;
if (shift < 0)
shift = 0;
return 4 + ((static_cast<u64>(maxweight) * ((1 << 24) / ((err_sum >> shift) + 1))) >> shift);
};
Array<i32, 4> weight {};
for (u8 i = 0; i < weight.size(); ++i) {
auto err_sum = predictions_N.err[i] + predictions_W.err[i] + predictions_NW.err[i] + predictions_WW.err[i] + predictions_NE.err[i];
if (x == m_width - 1)
err_sum += predictions_W.err[i];
weight[i] = error2weight(err_sum, wp_params.wp_w[i]);
}
auto sum_weights = weight[0] + weight[1] + weight[2] + weight[3];
i32 const log_weight = floor(log2(sum_weights)) + 1;
for (u8 i = 0; i < 4; i++)
weight[i] = weight[i] >> (log_weight - 5);
sum_weights = weight[0] + weight[1] + weight[2] + weight[3];
auto s = (sum_weights >> 1) - 1;
for (u8 i = 0; i < 4; i++)
s += current_predictions.subpred[i] * weight[i];
current_predictions.prediction = static_cast<u64>(s) * ((1 << 24) / sum_weights) >> 24;
// if true_err_N, true_err_W and true_err_NW don't have the same sign
if (((predictions_N.true_err ^ predictions_W.true_err) | (predictions_N.true_err ^ predictions_NW.true_err)) <= 0) {
current_predictions.prediction = clamp(current_predictions.prediction, min(W3, min(N3, NE3)), max(W3, max(N3, NE3)));
}
auto& max_error = current_predictions.max_error;
max_error = predictions_W.true_err;
if (abs(predictions_N.true_err) > abs(max_error))
max_error = predictions_N.true_err;
if (abs(predictions_NW.true_err) > abs(max_error))
max_error = predictions_NW.true_err;
if (abs(predictions_NE.true_err) > abs(max_error))
max_error = predictions_NE.true_err;
return current_predictions;
}
// H.5.1 - General
void compute_errors(u32 x, i32 true_value)
{
auto& current_predictions = m_next_row[x];
current_predictions.true_err = current_predictions.prediction - (true_value << 3);
for (u8 i = 0; i < 4; ++i)
current_predictions.err[i] = (abs(current_predictions.subpred[i] - (true_value << 3)) + 3) >> 3;
}
private:
SelfCorrectingData(WPHeader const& wp)
: wp_params(wp)
{
}
enum class Direction {
North,
NorthWest,
NorthEast,
West,
NorthNorth,
WestWest
};
Predictions predictions_for(u32 x, Direction direction) const
{
// H.5.2 - Prediction
auto const north = [&]() {
return m_current_row_index < 1 ? Predictions {} : m_current_row[x];
};
switch (direction) {
case Direction::North:
return north();
case Direction::NorthWest:
return x < 1 ? north() : m_current_row[x - 1];
case Direction::NorthEast:
return x + 1 >= m_current_row.size() ? north() : m_current_row[x + 1];
case Direction::West:
return x < 1 ? Predictions {} : m_next_row[x - 1];
case Direction::NorthNorth:
return m_current_row_index < 2 ? Predictions {} : m_previous[x];
case Direction::WestWest:
return x < 2 ? Predictions {} : m_next_row[x - 2];
}
VERIFY_NOT_REACHED();
}
WPHeader const& wp_params {};
u32 m_width {};
u32 m_current_row_index {};
FixedArray<Predictions> m_previous {};
FixedArray<Predictions> m_current_row {};
FixedArray<Predictions> m_next_row {};
};
///
/// H.2 - Image decoding
struct ModularHeader {
bool use_global_tree {};
@ -1279,7 +1447,7 @@ struct ModularHeader {
Vector<TransformInfo> transform {};
};
static ErrorOr<Vector<i32>> get_properties(Vector<Channel> const& channels, u16 i, u32 x, u32 y)
static ErrorOr<Vector<i32>> get_properties(Vector<Channel> const& channels, u16 i, u32 x, u32 y, i32 max_error)
{
Vector<i32> properties;
@ -1320,8 +1488,7 @@ static ErrorOr<Vector<i32>> get_properties(Vector<Channel> const& channels, u16
TRY(properties.try_append(N - NN));
TRY(properties.try_append(W - WW));
// FIXME: Correctly compute max_error
TRY(properties.try_append(0));
TRY(properties.try_append(max_error));
for (i16 j = i - 1; j >= 0; j--) {
if (channels[j].width() != channels[i].width())
@ -1345,48 +1512,62 @@ static ErrorOr<Vector<i32>> get_properties(Vector<Channel> const& channels, u16
return properties;
}
static i32 prediction(Channel const& channel, u32 x, u32 y, u32 predictor)
static i32 prediction(Neighborhood const& neighborhood, i32 self_correcting, u32 predictor)
{
switch (predictor) {
case 0:
return 0;
case 1:
return neighborhood.W;
case 2:
return neighborhood.N;
case 3:
return (neighborhood.W + neighborhood.N) / 2;
case 4:
return abs(neighborhood.N - neighborhood.NW) < abs(neighborhood.W - neighborhood.NW) ? neighborhood.W : neighborhood.N;
case 5:
return clamp(neighborhood.W + neighborhood.N - neighborhood.NW, min(neighborhood.W, neighborhood.N), max(neighborhood.W, neighborhood.N));
case 6:
return (self_correcting + 3) >> 3;
case 7:
return neighborhood.NE;
case 8:
return neighborhood.NW;
case 9:
return neighborhood.WW;
case 10:
return (neighborhood.W + neighborhood.NW) / 2;
case 11:
return (neighborhood.N + neighborhood.NW) / 2;
case 12:
return (neighborhood.N + neighborhood.NE) / 2;
case 13:
return (6 * neighborhood.N - 2 * neighborhood.NN + 7 * neighborhood.W + neighborhood.WW + neighborhood.NEE + 3 * neighborhood.NE + 8) / 16;
}
VERIFY_NOT_REACHED();
}
static Neighborhood retrieve_neighborhood(Channel const& channel, u32 x, u32 y)
{
i32 const W = x > 0 ? channel.get(x - 1, y) : (y > 0 ? channel.get(x, y - 1) : 0);
i32 const N = y > 0 ? channel.get(x, y - 1) : W;
i32 const NW = x > 0 && y > 0 ? channel.get(x - 1, y - 1) : W;
i32 const NE = x + 1 < channel.width() && y > 0 ? channel.get(x + 1, y - 1) : N;
i32 const NN = y > 1 ? channel.get(x, y - 2) : N;
i32 const NEE = x + 2 < channel.width() and y > 0 ? channel.get(x + 2, y - 1) : NE;
i32 const WW = x > 1 ? channel.get(x - 2, y) : W;
i32 const NEE = x + 2 < channel.width() && y > 0 ? channel.get(x + 2, y - 1) : NE;
switch (predictor) {
case 0:
return 0;
case 1:
return W;
case 2:
return N;
case 3:
return (W + N) / 2;
case 4:
return abs(N - NW) < abs(W - NW) ? W : N;
case 5:
return clamp(W + N - NW, min(W, N), max(W, N));
case 6:
TODO();
return (0 + 3) >> 3;
case 7:
return NE;
case 8:
return NW;
case 9:
return WW;
case 10:
return (W + NW) / 2;
case 11:
return (N + NW) / 2;
case 12:
return (N + NE) / 2;
case 13:
return (6 * N - 2 * NN + 7 * W + WW + NEE + 3 * NE + 8) / 16;
}
VERIFY_NOT_REACHED();
Neighborhood const neighborhood {
.N = N,
.NW = NW,
.NE = NE,
.W = W,
.NN = NN,
.WW = WW,
.NEE = NEE,
};
return neighborhood;
}
static ErrorOr<ModularHeader> read_modular_header(LittleEndianInputBitStream& stream,
@ -1415,17 +1596,26 @@ static ErrorOr<ModularHeader> read_modular_header(LittleEndianInputBitStream& st
auto const& tree = local_tree.has_value() ? *local_tree : global_tree;
for (u16 i {}; i < num_channels; ++i) {
auto self_correcting_data = TRY(SelfCorrectingData::create(modular_header.wp_params, image.channels()[i].width()));
for (u32 y {}; y < image.channels()[i].height(); y++) {
for (u32 x {}; x < image.channels()[i].width(); x++) {
auto const neighborhood = retrieve_neighborhood(image.channels()[i], x, y);
auto const properties = TRY(get_properties(image.channels(), i, x, y));
auto const self_prediction = self_correcting_data.compute_predictions(neighborhood, x);
auto const properties = TRY(get_properties(image.channels(), i, x, y, self_prediction.max_error));
auto const leaf_node = tree.get_leaf(properties);
auto diff = unpack_signed(TRY(decoder->decode_hybrid_uint(stream, leaf_node.ctx)));
diff = (diff * leaf_node.multiplier) + leaf_node.offset;
auto const total = diff + prediction(image.channels()[i], x, y, leaf_node.predictor);
auto const total = diff + prediction(neighborhood, self_prediction.prediction, leaf_node.predictor);
self_correcting_data.compute_errors(x, total);
image.channels()[i].set(x, y, total);
}
self_correcting_data.register_next_row();
}
image.channels()[i].set_decoded(true);