LibRegex: Add a basic optimization pass

This currently tries to convert forking loops to atomic groups, and
to unify the common leading part (left side) of alternations.
This commit is contained in:
Ali Mohammad Pur 2021-09-12 17:30:27 +04:30 committed by Ali Mohammad Pur
parent 913382734c
commit 246ab432ff
9 changed files with 677 additions and 24 deletions

View file

@ -8,6 +8,7 @@
#include <LibTest/TestCase.h> // import first, to prevent warning of VERIFY* redefinition
#include <AK/StringBuilder.h>
#include <AK/Tuple.h>
#include <LibRegex/Regex.h>
#include <LibRegex/RegexDebug.h>
#include <stdio.h>
@ -887,3 +888,21 @@ BENCHMARK_CASE(fork_performance)
auto result = re.match(g_lots_of_a_s);
EXPECT_EQ(result.success, true);
}
// Checks that patterns targeted by the bytecode optimizer still produce the expected match result.
// Each table entry is { pattern, subject, expected value of result.success }.
TEST_CASE(optimizer_atomic_groups)
{
Array tests {
// Fork -> ForkReplace
// The subject contains no 'b', so both loops must still fail to match.
Tuple { "a*b"sv, "aaaaa"sv, false },
Tuple { "a+b"sv, "aaaaa"sv, false },
// Alternative fuse
// The alternatives share leading instructions (here, even entire alternatives are identical).
Tuple { "(abcfoo|abcbar|abcbaz).*x"sv, "abcbarx"sv, true },
Tuple { "(a|a)"sv, "a"sv, true },
};
for (auto& test : tests) {
// Compile the pattern with the ECMA262 flavour and match it against the subject.
Regex<ECMA262> re(test.get<0>());
auto result = re.match(test.get<1>());
EXPECT_EQ(result.success, test.get<2>());
}
}

View file

@ -3,6 +3,7 @@ set(SOURCES
RegexByteCode.cpp
RegexLexer.cpp
RegexMatcher.cpp
RegexOptimizer.cpp
RegexParser.cpp
)

View file

@ -245,12 +245,26 @@ ALWAYS_INLINE ExecutionResult OpCode_ForkJump::execute(MatchInput const&, MatchS
return ExecutionResult::Fork_PrioHigh;
}
// Like ForkJump, but additionally publishes this instruction's position through
// `input.fork_to_replace`, so the matcher loop can overwrite the pending state
// created by an earlier execution of this same fork instead of queueing another one.
ALWAYS_INLINE ExecutionResult OpCode_ForkReplaceJump::execute(MatchInput const& input, MatchState& state) const
{
    // Identify the fork whose pending state should be replaced.
    input.fork_to_replace = state.instruction_position;

    // The offset is relative to the end of this instruction.
    state.fork_at_position = state.instruction_position + size() + offset();

    return ExecutionResult::Fork_PrioHigh;
}
// Records where the forked state should resume and asks the matcher for a low-priority fork.
ALWAYS_INLINE ExecutionResult OpCode_ForkStay::execute(MatchInput const&, MatchState& state) const
{
    // The offset is relative to the end of this instruction.
    auto const fork_target = state.instruction_position + size() + offset();
    state.fork_at_position = fork_target;
    return ExecutionResult::Fork_PrioLow;
}
// Like ForkStay, but additionally publishes this instruction's position through
// `input.fork_to_replace`, so the matcher loop can overwrite the pending state
// created by an earlier execution of this same fork instead of queueing another one.
ALWAYS_INLINE ExecutionResult OpCode_ForkReplaceStay::execute(MatchInput const& input, MatchState& state) const
{
    // Identify the fork whose pending state should be replaced.
    input.fork_to_replace = state.instruction_position;

    // The offset is relative to the end of this instruction.
    state.fork_at_position = state.instruction_position + size() + offset();

    return ExecutionResult::Fork_PrioLow;
}
ALWAYS_INLINE ExecutionResult OpCode_CheckBegin::execute(MatchInput const& input, MatchState& state) const
{
if (0 == state.string_position && (input.regex_options & AllFlags::MatchNotBeginOfLine))
@ -778,6 +792,40 @@ String const OpCode_Compare::arguments_string() const
return String::formatted("argc={}, args={} ", arguments_count(), arguments_size());
}
// Decodes the arguments of this Compare instruction into a flat list of
// (compare type, value) pairs.
//
// - Char, Reference, CharClass and CharRange contribute their single value word.
// - String contributes only the first word of its payload, and nothing at all
//   when its length is zero.
// - Every other compare type carries no value word here and is recorded with value 0.
Vector<CompareTypeAndValuePair> OpCode_Compare::flat_compares() const
{
    Vector<CompareTypeAndValuePair> result;

    // Skip the three words in front of the first argument
    // (presumably opcode id, argument count and argument size — confirm against the encoder).
    size_t offset { state().instruction_position + 3 };

    for (size_t i = 0; i < arguments_count(); ++i) {
        auto compare_type = static_cast<CharacterCompareType>(m_bytecode->at(offset++));

        switch (compare_type) {
        case CharacterCompareType::Char:
        case CharacterCompareType::Reference:
        case CharacterCompareType::CharClass:
        case CharacterCompareType::CharRange: {
            // All of these are followed by exactly one value word.
            auto value = m_bytecode->at(offset++);
            result.append({ compare_type, value });
            break;
        }
        case CharacterCompareType::String: {
            // A length word followed by `length` payload words; only the first one is reported.
            auto length = m_bytecode->at(offset++);
            if (length > 0)
                result.append({ compare_type, m_bytecode->at(offset) });
            offset += length;
            break;
        }
        default:
            result.append({ compare_type, 0 });
            break;
        }
    }

    return result;
}
Vector<String> const OpCode_Compare::variable_arguments_to_string(Optional<MatchInput> input) const
{
Vector<String> result;
@ -834,7 +882,7 @@ Vector<String> const OpCode_Compare::variable_arguments_to_string(Optional<Match
input.value().view.substring_view(string_start_offset, state().string_position > view.length() ? 0 : 1).to_string()));
} else if (compare_type == CharacterCompareType::CharRange) {
auto value = (CharRange)m_bytecode->at(offset++);
result.empend(String::formatted("ch_range='{:c}'-'{:c}'", value.from, value.to));
result.empend(String::formatted("ch_range={:x}-{:x}", value.from, value.to));
if (!view.is_null() && view.length() > state().string_position)
result.empend(String::formatted(
"compare against: '{}'",
@ -896,6 +944,16 @@ ALWAYS_INLINE ExecutionResult OpCode_JumpNonEmpty::execute(MatchInput const& inp
if (form == OpCodeId::ForkStay)
return ExecutionResult::Fork_PrioLow;
if (form == OpCodeId::ForkReplaceStay) {
input.fork_to_replace = state.instruction_position;
return ExecutionResult::Fork_PrioLow;
}
if (form == OpCodeId::ForkReplaceJump) {
input.fork_to_replace = state.instruction_position;
return ExecutionResult::Fork_PrioHigh;
}
}
return ExecutionResult::Continue;

View file

@ -6,6 +6,7 @@
#pragma once
#include "RegexBytecodeStreamOptimizer.h"
#include "RegexMatch.h"
#include "RegexOptions.h"
@ -30,6 +31,8 @@ using ByteCodeValueType = u64;
__ENUMERATE_OPCODE(JumpNonEmpty) \
__ENUMERATE_OPCODE(ForkJump) \
__ENUMERATE_OPCODE(ForkStay) \
__ENUMERATE_OPCODE(ForkReplaceJump) \
__ENUMERATE_OPCODE(ForkReplaceStay) \
__ENUMERATE_OPCODE(FailForks) \
__ENUMERATE_OPCODE(SaveLeftCaptureGroup) \
__ENUMERATE_OPCODE(SaveRightCaptureGroup) \
@ -306,7 +309,7 @@ public:
VERIFY_NOT_REACHED();
}
void insert_bytecode_alternation(ByteCode&& left, ByteCode&& right)
void insert_bytecode_alternation(ByteCode left, ByteCode right)
{
// FORKJUMP _ALT
@ -316,21 +319,8 @@ public:
// REGEXP ALT1
// LABEL _END
ByteCode byte_code;
empend(static_cast<ByteCodeValueType>(OpCodeId::ForkJump));
empend(right.size() + 2); // Jump to the _ALT label
extend(right);
empend(static_cast<ByteCodeValueType>(OpCodeId::Jump));
empend(left.size()); // Jump to the _END label
// LABEL _ALT = bytecode.size() + 2
extend(left);
// LABEL _END = alterantive_bytecode.size
// Optimisation: Eliminate extra work by unifying common pre-and-postfix exprs.
Optimizer::append_alternation(*this, left, right);
}
template<typename T>
@ -625,7 +615,7 @@ public:
}
};
class OpCode_ForkJump final : public OpCode {
class OpCode_ForkJump : public OpCode {
public:
ExecutionResult execute(MatchInput const& input, MatchState& state) const override;
ALWAYS_INLINE OpCodeId opcode_id() const override { return OpCodeId::ForkJump; }
@ -637,7 +627,13 @@ public:
}
};
class OpCode_ForkStay final : public OpCode {
// A ForkJump variant with its own execute() that identifies itself as OpCodeId::ForkReplaceJump;
// everything else is inherited from OpCode_ForkJump.
class OpCode_ForkReplaceJump final : public OpCode_ForkJump {
public:
ExecutionResult execute(MatchInput const& input, MatchState& state) const override;
ALWAYS_INLINE OpCodeId opcode_id() const override { return OpCodeId::ForkReplaceJump; }
};
class OpCode_ForkStay : public OpCode {
public:
ExecutionResult execute(MatchInput const& input, MatchState& state) const override;
ALWAYS_INLINE OpCodeId opcode_id() const override { return OpCodeId::ForkStay; }
@ -649,6 +645,12 @@ public:
}
};
// A ForkStay variant with its own execute() that identifies itself as OpCodeId::ForkReplaceStay;
// everything else is inherited from OpCode_ForkStay.
class OpCode_ForkReplaceStay final : public OpCode_ForkStay {
public:
ExecutionResult execute(MatchInput const& input, MatchState& state) const override;
ALWAYS_INLINE OpCodeId opcode_id() const override { return OpCodeId::ForkReplaceStay; }
};
class OpCode_CheckBegin final : public OpCode {
public:
ExecutionResult execute(MatchInput const& input, MatchState& state) const override;
@ -725,6 +727,7 @@ public:
ALWAYS_INLINE size_t arguments_size() const { return argument(1); }
String const arguments_string() const override;
Vector<String> const variable_arguments_to_string(Optional<MatchInput> input = {}) const;
Vector<CompareTypeAndValuePair> flat_compares() const;
private:
ALWAYS_INLINE static void compare_char(MatchInput const& input, MatchState& state, u32 ch1, bool inverse, bool& inverse_matched);

View file

@ -0,0 +1,18 @@
/*
* Copyright (c) 2021, Ali Mohammad Pur <mpfard@serenityos.org>
*
* SPDX-License-Identifier: BSD-2-Clause
*/
#pragma once
#include "Forward.h"
namespace regex {
// Static helpers that produce optimized bytecode sequences.
class Optimizer {
public:
// Appends the bytecode for an alternation between `left` and `right` to `target`,
// emitting any leading instructions the two alternatives have in common only once.
static void append_alternation(ByteCode& target, ByteCode& left, ByteCode& right);
};
}

View file

@ -6,6 +6,7 @@
#pragma once
#include "Forward.h"
#include "RegexOptions.h"
#include <AK/FlyString.h>
@ -514,6 +515,7 @@ struct MatchInput {
mutable Vector<size_t> saved_positions;
mutable Vector<size_t> saved_code_unit_positions;
mutable HashMap<u64, u64> checkpoints;
mutable Optional<size_t> fork_to_replace;
};
struct MatchState {
@ -522,6 +524,7 @@ struct MatchState {
size_t string_position_in_code_units { 0 };
size_t instruction_position { 0 };
size_t fork_at_position { 0 };
Optional<size_t> initiating_fork;
Vector<Match> matches;
Vector<Vector<Match>> capture_group_matches;
Vector<u64> repetition_marks;

View file

@ -39,6 +39,7 @@ Regex<Parser>::Regex(String pattern, typename ParserTraits<Parser>::OptionsType
Parser parser(lexer, regex_options);
parser_result = parser.parse();
run_optimization_passes();
if (parser_result.error == regex::Error::NoError)
matcher = make<Matcher<Parser>>(this, regex_options);
}
@ -48,6 +49,7 @@ Regex<Parser>::Regex(regex::Parser::Result parse_result, String pattern, typenam
: pattern_value(move(pattern))
, parser_result(move(parse_result))
{
run_optimization_passes();
if (parser_result.error == regex::Error::NoError)
matcher = make<Matcher<Parser>>(this, regex_options);
}
@ -370,6 +372,9 @@ public:
return m_first == nullptr;
}
auto reverse_begin() { return ReverseIterator(m_last); }
auto reverse_end() { return ReverseIterator(); }
private:
struct Node {
T value;
@ -377,6 +382,27 @@ private:
Node* previous { nullptr };
};
// Iterator that walks the list backwards by following each node's `previous` link.
// A null node pointer is the end position; a default-constructed iterator is an end iterator.
struct ReverseIterator {
    ReverseIterator() = default;
    explicit ReverseIterator(Node* node)
        : m_node(node)
    {
    }

    // Dereferencing is only valid while the iterator is not at the end position.
    T* operator->() { return &m_node->value; }
    T& operator*() { return m_node->value; }

    bool operator==(ReverseIterator const& it) const { return m_node == it.m_node; }

    // Steps to the previous node; stepping an end iterator leaves it unchanged.
    ReverseIterator& operator++()
    {
        if (m_node)
            m_node = m_node->previous;
        return *this;
    }

private:
    // Initialized explicitly so that every default-constructed iterator is a well-defined
    // end iterator, not only those created through value-initialization.
    Node* m_node { nullptr };
};
UniformBumpAllocator<Node, true, 2 * MiB> m_allocator;
Node* m_first { nullptr };
Node* m_last { nullptr };
@ -413,15 +439,48 @@ Optional<bool> Matcher<Parser>::execute(MatchInput const& input, MatchState& sta
state.instruction_position += opcode.size();
switch (result) {
case ExecutionResult::Fork_PrioLow:
states_to_try_next.append(state);
states_to_try_next.last().instruction_position = state.fork_at_position;
case ExecutionResult::Fork_PrioLow: {
bool found = false;
if (input.fork_to_replace.has_value()) {
for (auto it = states_to_try_next.reverse_begin(); it != states_to_try_next.reverse_end(); ++it) {
if (it->initiating_fork == input.fork_to_replace.value()) {
(*it) = state;
it->instruction_position = state.fork_at_position;
it->initiating_fork = *input.fork_to_replace;
found = true;
break;
}
}
input.fork_to_replace.clear();
}
if (!found) {
states_to_try_next.append(state);
states_to_try_next.last().initiating_fork = state.instruction_position - opcode.size();
states_to_try_next.last().instruction_position = state.fork_at_position;
}
continue;
case ExecutionResult::Fork_PrioHigh:
states_to_try_next.append(state);
}
case ExecutionResult::Fork_PrioHigh: {
bool found = false;
if (input.fork_to_replace.has_value()) {
for (auto it = states_to_try_next.reverse_begin(); it != states_to_try_next.reverse_end(); ++it) {
if (it->initiating_fork == input.fork_to_replace.value()) {
(*it) = state;
it->initiating_fork = *input.fork_to_replace;
found = true;
break;
}
}
input.fork_to_replace.clear();
}
if (!found) {
states_to_try_next.append(state);
states_to_try_next.last().initiating_fork = state.instruction_position - opcode.size();
}
state.instruction_position = state.fork_at_position;
++recursion_level;
continue;
}
case ExecutionResult::Continue:
continue;
case ExecutionResult::Succeeded:

View file

@ -24,6 +24,15 @@
namespace regex {
namespace Detail {
// A basic block of regex bytecode, described by two indices into the bytecode stream.
struct Block {
// Index of the first instruction of the block.
size_t start;
// As produced by split_basic_blocks(): the index of the jump-like instruction that closes the
// block, or the bytecode size for the trailing block that runs to the end of the program.
size_t end;
};
}
static constexpr const size_t c_max_recursion = 5000;
static constexpr const size_t c_match_preallocation_count = 0;
@ -217,6 +226,12 @@ public:
RegexResult result = matcher->match(views, AllOptions { regex_options.value_or({}) } | AllFlags::SkipSubExprResults);
return result.success;
}
private:
void run_optimization_passes();
using BasicBlockList = Vector<Detail::Block>;
BasicBlockList split_basic_blocks();
void attempt_rewrite_loops_as_atomic_groups(BasicBlockList const&);
};
// free standing functions for match, search and has_match

View file

@ -0,0 +1,477 @@
/*
* Copyright (c) 2021, Ali Mohammad Pur <mpfard@serenityos.org>
*
* SPDX-License-Identifier: BSD-2-Clause
*/
#include <AK/QuickSort.h>
#include <AK/RedBlackTree.h>
#include <AK/Stack.h>
#include <LibRegex/Regex.h>
#include <LibRegex/RegexBytecodeStreamOptimizer.h>
namespace regex {
using Detail::Block;
// Runs every optimization pass over the parsed bytecode.
template<typename Parser>
void Regex<Parser>::run_optimization_passes()
{
    // Currently the only pass: rewrite fork loops as atomic groups,
    // e.g. a*b -> (ATOMIC a*)b. It works on the program's basic blocks.
    auto const basic_blocks = split_basic_blocks();
    attempt_rewrite_loops_as_atomic_groups(basic_blocks);
}
// Splits the parsed bytecode into basic blocks and returns them sorted by start position.
// A block is closed at every Jump, JumpNonEmpty, ForkJump, ForkStay, FailForks and Repeat
// instruction; whatever follows the last such instruction forms a trailing block.
template<typename Parser>
typename Regex<Parser>::BasicBlockList Regex<Parser>::split_basic_blocks()
{
    BasicBlockList blocks;
    auto& bytecode = parser_result.bytecode;
    size_t current_block_start = 0;

    MatchState state;
    state.instruction_position = 0;

    // Closes the block in progress at a jump-like instruction of type T.
    // A backwards jump that lands strictly inside the block in progress splits that block
    // at the jump target; any other jump simply ends the block at the jump itself.
    auto close_block_at_jump = [&]<typename T>(OpCode const& opcode) {
        auto& jump = static_cast<T const&>(opcode);
        auto const ip = state.instruction_position;
        ssize_t const relative_target = jump.size() + jump.offset();

        bool const lands_inside_current_block = relative_target < 0 && relative_target + ip > current_block_start;
        if (lands_inside_current_block) {
            auto const target = relative_target + ip;
            blocks.append({ current_block_start, target });
            blocks.append({ target, ip });
        } else {
            blocks.append({ current_block_start, ip });
        }

        current_block_start = ip + opcode.size();
    };

    while (true) {
        auto& opcode = bytecode.get_opcode(state);
        auto const ip = state.instruction_position;

        switch (opcode.opcode_id()) {
        case OpCodeId::Jump:
            close_block_at_jump.template operator()<OpCode_Jump>(opcode);
            break;
        case OpCodeId::JumpNonEmpty:
            close_block_at_jump.template operator()<OpCode_JumpNonEmpty>(opcode);
            break;
        case OpCodeId::ForkJump:
            close_block_at_jump.template operator()<OpCode_ForkJump>(opcode);
            break;
        case OpCodeId::ForkStay:
            close_block_at_jump.template operator()<OpCode_ForkStay>(opcode);
            break;
        case OpCodeId::FailForks:
            blocks.append({ current_block_start, ip });
            current_block_start = ip + opcode.size();
            break;
        case OpCodeId::Repeat: {
            // Repeat yields one block for the repeated expression, preceded by a block
            // for anything that came before the start of that expression.
            auto const repeated_expr_start = ip - static_cast<OpCode_Repeat const&>(opcode).offset();
            if (repeated_expr_start > current_block_start)
                blocks.append({ current_block_start, repeated_expr_start });
            blocks.append({ repeated_expr_start, ip });
            current_block_start = ip + opcode.size();
            break;
        }
        default:
            break;
        }

        auto const next_ip = ip + opcode.size();
        if (next_ip >= bytecode.size())
            break;
        state.instruction_position = next_ip;
    }

    // Whatever is left after the last block-closing instruction forms the final block.
    if (current_block_start < bytecode.size())
        blocks.append({ current_block_start, bytecode.size() });

    quick_sort(blocks, [](auto& a, auto& b) { return a.start < b.start; });
    return blocks;
}
// Decides whether the loop body `repeated_block` may be turned into an atomic group, given
// that `following_block` runs after the loop: the first compare of the following block must
// not be able to match anything the repeated block compares against.
// The check is deliberately conservative and only compares (type, value) pairs for equality.
static bool block_satisfies_atomic_rewrite_precondition(ByteCode const& bytecode, Block const& repeated_block, Block const& following_block)
{
    auto contains_any_char = [](auto& compares) {
        return any_of(compares, [](auto& compare) { return compare.type == CharacterCompareType::AnyChar; });
    };

    // Step 1: collect everything the repeated block compares against.
    Vector<Vector<CompareTypeAndValuePair>> loop_compares;
    MatchState state;
    state.instruction_position = repeated_block.start;
    while (state.instruction_position < repeated_block.end) {
        auto& opcode = bytecode.get_opcode(state);
        switch (opcode.opcode_id()) {
        case OpCodeId::Compare: {
            auto compares = static_cast<OpCode_Compare const&>(opcode).flat_compares();
            // A loop whose first compare can match any character is never rewritten.
            if (loop_compares.is_empty() && contains_any_char(compares))
                return false;
            loop_compares.append(move(compares));
            break;
        }
        case OpCodeId::CheckBegin:
        case OpCodeId::CheckEnd:
            if (loop_compares.is_empty())
                return true;
            break;
        case OpCodeId::CheckBoundary:
            // FIXME: What should we do with these? for now, let's fail.
            return false;
        case OpCodeId::Restore:
        case OpCodeId::GoBack:
            return false;
        default:
            break;
        }
        state.instruction_position += opcode.size();
    }

    dbgln_if(REGEX_DEBUG, "Found {} entries in reference", loop_compares.size());

    // Step 2: find the first compare in the following block, it must NOT match any of the
    // values collected from the repeated block.
    state.instruction_position = following_block.start;
    while (state.instruction_position < following_block.end) {
        auto& opcode = bytecode.get_opcode(state);
        switch (opcode.opcode_id()) {
        case OpCodeId::Compare: {
            auto following_compares = static_cast<OpCode_Compare const&>(opcode).flat_compares();
            // A compare without arguments tells us nothing; keep looking.
            if (following_compares.is_empty())
                break;

            // If either side can match _anything_, fail.
            if (contains_any_char(following_compares))
                return false;

            for (auto& loop_compare_list : loop_compares) {
                // FIXME: This is too naive!
                if (contains_any_char(loop_compare_list))
                    return false;

                for (auto& loop_compare : loop_compare_list) {
                    // FIXME: This is too naive! it will miss _tons_ of cases since it doesn't check ranges!
                    auto is_same_compare = [&](auto& compare) {
                        return compare.type == loop_compare.type && compare.value == loop_compare.value;
                    };
                    if (any_of(following_compares, is_same_compare))
                        return false;
                }
            }
            return true;
        }
        case OpCodeId::CheckBegin:
        case OpCodeId::CheckEnd:
            return true; // Nothing can match the end!
        case OpCodeId::CheckBoundary:
            // FIXME: What should we do with these? For now, consider them a failure.
            return false;
        default:
            break;
        }
        state.instruction_position += opcode.size();
    }

    // The following block never compares against anything.
    return true;
}
// Looks for fork-based loops among `basic_blocks` and, when the atomic-rewrite precondition
// holds, patches the loop's ForkJump/ForkStay (also when wrapped in a JumpNonEmpty) into its
// ForkReplace counterpart, directly in parser_result.bytecode.
// For loops without a header, a two-word ForkStay is additionally inserted in front of the loop,
// after which the relative offsets of jump-like instructions crossing the insertion are fixed up.
template<typename Parser>
void Regex<Parser>::attempt_rewrite_loops_as_atomic_groups(BasicBlockList const& basic_blocks)
{
auto& bytecode = parser_result.bytecode;
if constexpr (REGEX_DEBUG) {
RegexDebug dbg;
dbg.print_bytecode(*this);
for (auto& block : basic_blocks)
dbgln("block from {} to {}", block.start, block.end);
}
// A pattern such as:
// bb0 | RE0
// | ForkX bb0
// -------------------------
// bb1 | RE1
// can be rewritten as:
// loop.hdr | ForkStay bb1
// -------------------------
// bb0 | RE0
// | ForkReplaceX bb0
// -------------------------
// bb1 | RE1
// provided that first(RE1) not-in end(RE0), which is to say
// that RE1 cannot start with whatever RE0 has matched (ever).
//
// Alternatively, a second form of this pattern can also occur:
// bb0 | *
// | ForkX bb2
// ------------------------
// bb1 | RE0
// | Jump bb0
// ------------------------
// bb2 | RE1
// which can be transformed (with the same preconditions) to:
// bb0 | *
// | ForkReplaceX bb2
// ------------------------
// bb1 | RE0
// | Jump bb0
// ------------------------
// bb2 | RE1
enum class AlternateForm {
DirectLoopWithoutHeader, // loop without proper header, a block forking to itself. i.e. the first form.
DirectLoopWithHeader, // loop with proper header, i.e. the second form.
};
// A loop that passed the precondition check and is waiting to be patched.
struct CandidateBlock {
// The block whose closing instruction is the fork to be replaced.
Block forking_block;
// First form only: the block following the loop, which the inserted ForkStay targets.
Optional<Block> new_target_block;
AlternateForm form;
};
Vector<CandidateBlock> candidate_blocks;
// Returns whether `opcode`, located at `ip`, is the kind of jump expected for `alternate_form`
// and lands exactly on `block_start` (offsets are relative to the end of the instruction).
auto is_an_eligible_jump = [](OpCode const& opcode, size_t ip, size_t block_start, AlternateForm alternate_form) {
switch (opcode.opcode_id()) {
case OpCodeId::JumpNonEmpty: {
auto& op = static_cast<OpCode_JumpNonEmpty const&>(opcode);
auto form = op.form();
// With a header the back edge must be a plain jump; without one it must be a fork.
if (form != OpCodeId::Jump && alternate_form == AlternateForm::DirectLoopWithHeader)
return false;
if (form != OpCodeId::ForkJump && form != OpCodeId::ForkStay && alternate_form == AlternateForm::DirectLoopWithoutHeader)
return false;
return op.offset() + ip + opcode.size() == block_start;
}
case OpCodeId::ForkJump:
if (alternate_form == AlternateForm::DirectLoopWithHeader)
return false;
return static_cast<OpCode_ForkJump const&>(opcode).offset() + ip + opcode.size() == block_start;
case OpCodeId::ForkStay:
if (alternate_form == AlternateForm::DirectLoopWithHeader)
return false;
return static_cast<OpCode_ForkStay const&>(opcode).offset() + ip + opcode.size() == block_start;
case OpCodeId::Jump:
// Infinite loop does *not* produce forks.
if (alternate_form == AlternateForm::DirectLoopWithoutHeader)
return false;
if (alternate_form == AlternateForm::DirectLoopWithHeader)
return static_cast<OpCode_Jump const&>(opcode).offset() + ip + opcode.size() == block_start;
VERIFY_NOT_REACHED();
default:
return false;
}
};
// NOTE(review): every `break` inside this loop leaves the whole scan, so at most one candidate
// is ever collected — confirm that stopping at the first eligible loop is intended.
for (size_t i = 0; i < basic_blocks.size(); ++i) {
auto forking_block = basic_blocks[i];
Optional<Block> fork_fallback_block;
if (i + 1 < basic_blocks.size())
fork_fallback_block = basic_blocks[i + 1];
MatchState state;
// Check if the last instruction in this block is a jump to the block itself:
{
// Block::end is used here as the position of the instruction that closes the block.
state.instruction_position = forking_block.end;
auto& opcode = bytecode.get_opcode(state);
if (is_an_eligible_jump(opcode, state.instruction_position, forking_block.start, AlternateForm::DirectLoopWithoutHeader)) {
// We've found RE0 (and RE1 is just the following block, if any), let's see if the precondition applies.
// if RE1 is empty, there's no first(RE1), so this is an automatic pass.
if (!fork_fallback_block.has_value() || fork_fallback_block->end == fork_fallback_block->start) {
candidate_blocks.append({ forking_block, fork_fallback_block, AlternateForm::DirectLoopWithoutHeader });
break;
}
if (block_satisfies_atomic_rewrite_precondition(bytecode, forking_block, *fork_fallback_block)) {
candidate_blocks.append({ forking_block, fork_fallback_block, AlternateForm::DirectLoopWithoutHeader });
break;
}
}
}
// Check if the last instruction in the last block is a direct jump to this block
if (fork_fallback_block.has_value()) {
state.instruction_position = fork_fallback_block->end;
auto& opcode = bytecode.get_opcode(state);
if (is_an_eligible_jump(opcode, state.instruction_position, forking_block.start, AlternateForm::DirectLoopWithHeader)) {
// We've found bb1 and bb0, let's just make sure that bb0 forks to bb2.
state.instruction_position = forking_block.end;
auto& opcode = bytecode.get_opcode(state);
if (opcode.opcode_id() == OpCodeId::ForkJump || opcode.opcode_id() == OpCodeId::ForkStay) {
Optional<Block> block_following_fork_fallback;
if (i + 2 < basic_blocks.size())
block_following_fork_fallback = basic_blocks[i + 2];
// Here RE0 is the fallback block (bb1) and RE1 is the block after it (bb2), if any.
if (!block_following_fork_fallback.has_value() || block_satisfies_atomic_rewrite_precondition(bytecode, *fork_fallback_block, *block_following_fork_fallback)) {
candidate_blocks.append({ forking_block, {}, AlternateForm::DirectLoopWithHeader });
break;
}
}
}
}
}
dbgln_if(REGEX_DEBUG, "Found {} candidate blocks", candidate_blocks.size());
if (candidate_blocks.is_empty()) {
dbgln_if(REGEX_DEBUG, "Failed to find anything for {}", pattern_value);
return;
}
// Maps the bytecode position of each insertion to the number of words inserted there.
RedBlackTree<size_t, size_t> needed_patches;
// Reverse the blocks, so we can patch the bytecode without messing with the latter patches.
// NOTE(review): as written, this comparator is true when a.start < b.start, which looks like
// ascending rather than reverse order — harmless while only one candidate exists; confirm.
quick_sort(candidate_blocks, [](auto& a, auto& b) { return b.forking_block.start > a.forking_block.start; });
for (auto& candidate : candidate_blocks) {
// Note that both forms share a ForkReplace patch in forking_block.
// Patch the ForkX in forking_block to be a ForkReplaceX instead.
auto& opcode_id = bytecode[candidate.forking_block.end];
if (opcode_id == (ByteCodeValueType)OpCodeId::ForkStay) {
opcode_id = (ByteCodeValueType)OpCodeId::ForkReplaceStay;
} else if (opcode_id == (ByteCodeValueType)OpCodeId::ForkJump) {
opcode_id = (ByteCodeValueType)OpCodeId::ForkReplaceJump;
} else if (opcode_id == (ByteCodeValueType)OpCodeId::JumpNonEmpty) {
// The word at +3 is patched as the wrapped jump form of JumpNonEmpty
// (presumably the value returned by form() — confirm against the instruction layout).
auto& jump_opcode_id = bytecode[candidate.forking_block.end + 3];
if (jump_opcode_id == (ByteCodeValueType)OpCodeId::ForkStay)
jump_opcode_id = (ByteCodeValueType)OpCodeId::ForkReplaceStay;
else if (jump_opcode_id == (ByteCodeValueType)OpCodeId::ForkJump)
jump_opcode_id = (ByteCodeValueType)OpCodeId::ForkReplaceJump;
else
VERIFY_NOT_REACHED();
} else {
VERIFY_NOT_REACHED();
}
if (candidate.form == AlternateForm::DirectLoopWithoutHeader) {
if (candidate.new_target_block.has_value()) {
// Insert a fork-stay targeted at the second block.
// The stored offset is computed from the positions before the insertion; the two inserted
// words shift the target by the same amount as the end of the new instruction.
bytecode.insert(candidate.forking_block.start, (ByteCodeValueType)OpCodeId::ForkStay);
bytecode.insert(candidate.forking_block.start + 1, candidate.new_target_block->start - candidate.forking_block.start);
needed_patches.insert(candidate.forking_block.start, 2u);
}
}
}
// Inserting words moved everything behind them, so relative offsets that span an insertion
// point have to be adjusted by the number of inserted words.
if (!needed_patches.is_empty()) {
MatchState state;
state.instruction_position = 0;
// One relative operand that may need adjusting.
struct Patch {
// The operand's current value, as a signed relative offset.
ssize_t value;
// Index into the bytecode of the word holding the operand.
size_t offset;
// Set for Repeat, whose stored operand is subtracted from the position to find its target.
bool should_negate { false };
};
for (;;) {
if (state.instruction_position >= bytecode.size())
break;
auto& opcode = bytecode.get_opcode(state);
Stack<Patch, 2> patch_points;
switch (opcode.opcode_id()) {
case OpCodeId::Jump:
patch_points.push({ static_cast<OpCode_Jump const&>(opcode).offset(), state.instruction_position + 1 });
break;
case OpCodeId::JumpNonEmpty:
// Both operands are considered (the checkpoint is presumably a relative offset too — confirm).
patch_points.push({ static_cast<OpCode_JumpNonEmpty const&>(opcode).offset(), state.instruction_position + 1 });
patch_points.push({ static_cast<OpCode_JumpNonEmpty const&>(opcode).checkpoint(), state.instruction_position + 2 });
break;
case OpCodeId::ForkJump:
patch_points.push({ static_cast<OpCode_ForkJump const&>(opcode).offset(), state.instruction_position + 1 });
break;
case OpCodeId::ForkStay:
patch_points.push({ static_cast<OpCode_ForkStay const&>(opcode).offset(), state.instruction_position + 1 });
break;
case OpCodeId::Repeat:
patch_points.push({ -(ssize_t) static_cast<OpCode_Repeat const&>(opcode).offset(), state.instruction_position + 1, true });
break;
default:
break;
}
while (!patch_points.is_empty()) {
auto& patch_point = patch_points.top();
auto target_offset = patch_point.value + state.instruction_position + opcode.size();
// Adjusts the operand when the insertion at patch_it.key() lies between the instruction
// and its target; an instruction sitting exactly at the insertion point is left alone.
constexpr auto do_patch = [](auto& patch_it, auto& patch_point, auto& target_offset, auto& bytecode, auto ip) {
if (patch_it.key() == ip)
return;
if (patch_point.value < 0 && target_offset < patch_it.key() && ip > patch_it.key())
bytecode[patch_point.offset] += (patch_point.should_negate ? 1 : -1) * (*patch_it);
else if (patch_point.value > 0 && target_offset > patch_it.key() && ip < patch_it.key())
bytecode[patch_point.offset] += (patch_point.should_negate ? -1 : 1) * (*patch_it);
};
// Look for an insertion at or below the target first; if there is none, at or below the instruction.
if (auto patch_it = needed_patches.find_largest_not_above_iterator(target_offset); !patch_it.is_end())
do_patch(patch_it, patch_point, target_offset, bytecode, state.instruction_position);
else if (auto patch_it = needed_patches.find_largest_not_above_iterator(state.instruction_position); !patch_it.is_end())
do_patch(patch_it, patch_point, target_offset, bytecode, state.instruction_position);
patch_points.pop();
}
state.instruction_position += opcode.size();
}
}
if constexpr (REGEX_DEBUG) {
warnln("Transformed to:");
RegexDebug dbg;
dbg.print_bytecode(*this);
}
}
// Appends bytecode that matches either `left` or `right` to `target`.
//
// Leading instructions that are identical in both alternatives are emitted once, in front of
// the fork; only the differing remainders are placed in the two fork branches:
//
//     <common prefix>
//     FORKJUMP _ALT
//     <remainder of right>
//     JUMP _END            (only emitted when the remainder of left is non-empty)
//   _ALT:
//     <remainder of left>
//   _END:
//
// An empty alternative (or an empty remainder) is a valid branch that matches the empty
// string, so it must not be dropped: `a|` has to be able to match nothing. The fork is only
// omitted when both remainders are empty, i.e. when the alternatives are identical.
void Optimizer::append_alternation(ByteCode& target, ByteCode& left, ByteCode& right)
{
    // Two empty alternatives contribute no instructions at all.
    if (left.is_empty() && right.is_empty())
        return;

    // Find the longest run of leading instructions shared by both alternatives.
    // When either side is empty the loop body never runs and nothing is shared.
    size_t left_skip = 0;
    MatchState state;
    for (state.instruction_position = 0; state.instruction_position < left.size() && state.instruction_position < right.size();) {
        auto left_size = left.get_opcode(state).size();
        auto right_size = right.get_opcode(state).size();
        if (left_size != right_size)
            break;
        if (left.span().slice(state.instruction_position, left_size) == right.span().slice(state.instruction_position, right_size))
            left_skip = state.instruction_position + left_size;
        else
            break;
        state.instruction_position += left_size;
    }

    // FIXME: Implement postfix unification too.
    size_t right_skip = 0;

    if (left_skip)
        target.append(left.data(), left_skip);

    dbgln_if(REGEX_DEBUG, "Skipping {}/{} bytecode entries from {}/{}", left_skip, right_skip, left.size(), right.size());

    auto left_slice = left.span().slice(left_skip, left.size() - left_skip - right_skip);
    auto right_slice = right.span().slice(left_skip, right.size() - left_skip - right_skip);

    // Identical alternatives leave nothing to choose between, so no fork is needed.
    if (!left_slice.is_empty() || !right_slice.is_empty()) {
        // The fork offset is relative to the end of the ForkJump and must land on _ALT.
        // The two-word `Jump _END` only sits in between when it is actually emitted.
        size_t const jump_to_end_size = left_slice.is_empty() ? 0u : 2u;

        target.empend(static_cast<ByteCodeValueType>(OpCodeId::ForkJump));
        target.empend(right_slice.size() + jump_to_end_size); // Jump to the _ALT label

        if (!right_slice.is_empty())
            target.append(right_slice.data(), right_slice.size());

        if (!left_slice.is_empty()) {
            target.empend(static_cast<ByteCodeValueType>(OpCodeId::Jump));
            target.empend(left_slice.size()); // Jump to the _END label

            // LABEL _ALT
            target.append(left_slice.data(), left_slice.size());
        }
        // LABEL _END
    }

    if (right_skip)
        target.append(left.span().slice_from_end(right_skip).data(), right_skip);
}
// Explicit instantiations of the optimization entry point for each parser type,
// needed because the member templates above are defined in this source file.
template void Regex<PosixBasicParser>::run_optimization_passes();
template void Regex<PosixExtendedParser>::run_optimization_passes();
template void Regex<ECMA262Parser>::run_optimization_passes();
}