serenity/Userland/Libraries/LibSQL/SQLClient.cpp
Andrew Kaster afb3a4a030 LibSQL: Block signals while forking SQLServer in Lagom
When debugging in Xcode, the waitpid() for the initial forked process
would always return EINTR or ECHILD. Work around this by blocking all
signals until we're ready to wait for the initial child.
2023-03-28 09:18:50 +01:00

238 lines
7.7 KiB
C++

/*
* Copyright (c) 2021, Jan de Visser <jan@de-visser.net>
* Copyright (c) 2022, the SerenityOS developers.
*
* SPDX-License-Identifier: BSD-2-Clause
*/
#include <AK/DeprecatedString.h>
#include <AK/String.h>
#include <LibSQL/SQLClient.h>
#if !defined(AK_OS_SERENITY)
# include <LibCore/Directory.h>
# include <LibCore/SocketAddress.h>
# include <LibCore/StandardPaths.h>
# include <LibCore/System.h>
# include <LibFileSystem/FileSystem.h>
# include <signal.h>
#endif
namespace SQL {
#if !defined(AK_OS_SERENITY)
// This is heavily based on how SystemServer's Service creates its socket.
//
// Creates a non-blocking, close-on-exec AF_LOCAL stream socket, binds it to `socket_path` and starts
// listening on it. Any existing file at `socket_path` is unlinked first.
// Returns the listening descriptor. On any failure after the socket has been created, the descriptor
// is closed before the error is returned so it is not leaked. A `socket_path` too long for
// sockaddr_un yields ENAMETOOLONG instead of an assertion failure.
static ErrorOr<int> create_database_socket(DeprecatedString const& socket_path)
{
    if (FileSystem::exists(socket_path))
        TRY(Core::System::unlink(socket_path));

# ifdef SOCK_NONBLOCK
    auto socket_fd = TRY(Core::System::socket(AF_LOCAL, SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
# else
    auto socket_fd = TRY(Core::System::socket(AF_LOCAL, SOCK_STREAM, 0));
# endif

    // Every step after socket creation can fail; run them in one place so a failure closes the fd.
    auto setup_result = [&]() -> ErrorOr<void> {
# ifndef SOCK_NONBLOCK
        // No SOCK_NONBLOCK/SOCK_CLOEXEC flags on this platform: apply both properties manually.
        int option = 1;
        TRY(Core::System::ioctl(socket_fd, FIONBIO, &option));
        TRY(Core::System::fcntl(socket_fd, F_SETFD, FD_CLOEXEC));
# endif
# if !defined(AK_OS_BSD_GENERIC)
        TRY(Core::System::fchmod(socket_fd, 0600));
# endif
        auto socket_address = Core::SocketAddress::local(socket_path);
        auto socket_address_un = socket_address.to_sockaddr_un();
        // to_sockaddr_un() produces no value when the path does not fit into sun_path.
        if (!socket_address_un.has_value())
            return Error::from_errno(ENAMETOOLONG);
        TRY(Core::System::bind(socket_fd, reinterpret_cast<sockaddr*>(&socket_address_un.value()), sizeof(sockaddr_un)));
        TRY(Core::System::listen(socket_fd, 16));
        return {};
    }();

    if (setup_result.is_error()) {
        (void)Core::System::close(socket_fd);
        return setup_result.release_error();
    }
    return socket_fd;
}
// Runs in the first child forked by launch_server(); the caller must terminate the process whatever
// this returns, so that no forked copy ever unwinds back into client code.
// Steps:
//   1. Start a new session and ignore SIGCHLD.
//   2. Fork again. The intermediate child writes the grandchild's PID to `pid_path` and terminates
//      itself with SIGTERM. If that kill() returns, it returns success here.
//   3. The grandchild hands the listening socket to SQLServer through SOCKET_TAKEOVER and exec()s the
//      first candidate that can be launched.
// Returns an error if any step fails, including when no candidate could be exec()ed.
static ErrorOr<void> spawn_detached_server(int server_fd, DeprecatedString const& pid_path, Vector<String> const& candidate_server_paths)
{
    TRY(Core::System::setsid());
    TRY(Core::System::signal(SIGCHLD, SIG_IGN));

    auto server_pid = TRY(Core::System::fork());
    if (server_pid != 0) {
        // Intermediate child: publish the server's PID, then go away so launch_server()'s waitpid() returns.
        auto server_pid_file = TRY(Core::File::open(pid_path, Core::File::OpenMode::Write));
        TRY(server_pid_file->write_until_depleted(DeprecatedString::number(server_pid).bytes()));
        TRY(Core::System::kill(getpid(), SIGTERM));
        // Only reached if SIGTERM did not terminate us (blocked or handled); must not fall through to exec().
        return {};
    }

    // Grandchild: server_fd was created close-on-exec, so pass SQLServer a dup()'d copy, which is inheritable.
    auto inheritable_fd = TRY(Core::System::dup(server_fd));
    auto takeover_string = DeprecatedString::formatted("SQLServer:{}", inheritable_fd);
    TRY(Core::System::setenv("SOCKET_TAKEOVER"sv, takeover_string, true));

    // Start out failed so that an empty candidate list is reported instead of tripping an assertion.
    ErrorOr<void> result = Error::from_string_literal("No SQLServer candidate paths were provided");
    for (auto const& server_path : candidate_server_paths) {
        auto arguments = Array {
            server_path.bytes_as_string_view(),
            "--pid-file"sv,
            pid_path,
        };
        // exec() replaces the process image on success, so it only ever returns an error.
        result = Core::System::exec(arguments[0], arguments, Core::System::SearchInPath::Yes);
    }

    warnln("Could not launch any of {}: {}", candidate_server_paths, result.error());
    // Best effort: the intermediate child may not have written the PID file yet.
    (void)Core::System::unlink(pid_path);
    return result.release_error();
}

// Starts SQLServer as a detached process listening on a fresh socket at `socket_path`. Its PID is
// recorded in `pid_path`. Returns once the intermediate child has exited.
static ErrorOr<void> launch_server(DeprecatedString const& socket_path, DeprecatedString const& pid_path, Vector<String> candidate_server_paths)
{
    auto server_fd_or_error = create_database_socket(socket_path);
    if (server_fd_or_error.is_error()) {
        warnln("Failed to create a database socket at {}: {}", socket_path, server_fd_or_error.error());
        return server_fd_or_error.release_error();
    }
    auto server_fd = server_fd_or_error.value();

    // Block all signals until the initial child has been waited for. When debugging in Xcode,
    // waitpid() on that child otherwise returned EINTR or ECHILD.
    sigset_t original_set;
    sigset_t blocked_set;
    sigfillset(&blocked_set);
    (void)pthread_sigmask(SIG_BLOCK, &blocked_set, &original_set);

    auto fork_result = Core::System::fork();
    if (fork_result.is_error()) {
        // Never leave the caller with every signal blocked.
        (void)pthread_sigmask(SIG_SETMASK, &original_set, nullptr);
        (void)Core::System::close(server_fd);
        return fork_result.release_error();
    }

    auto child_pid = fork_result.value();
    if (child_pid == 0) {
        (void)pthread_sigmask(SIG_SETMASK, &original_set, nullptr);
        auto result = spawn_detached_server(server_fd, pid_path, candidate_server_paths);
        if (result.is_error())
            warnln("Failed to launch SQLServer: {}", result.error());
        // A forked child must never return into the client's code path.
        _exit(result.is_error() ? 1 : 0);
    }

    VERIFY(child_pid > 0);
    // The forked children hold their own copies of the listening socket; ours is no longer needed.
    // Dropping it also lets connect() fail instead of hanging if the server never starts.
    (void)Core::System::close(server_fd);

    auto wait_result = Core::System::waitpid(child_pid);
    (void)pthread_sigmask(SIG_SETMASK, &original_set, nullptr);
    if (wait_result.is_error())
        return wait_result.release_error();
    return {};
}
// Decides whether a new SQLServer has to be started, based on the PID file at `pid_path`.
// Returns true when there is no PID file, or when the file is stale. A stale file holds no usable
// positive PID, or names a process kill(pid, 0) cannot reach; it is unlinked first.
// Returns false when the recorded process appears to be alive.
// Failures to open or read the PID file are reported and propagated.
static ErrorOr<bool> should_launch_server(DeprecatedString const& pid_path)
{
    if (!FileSystem::exists(pid_path))
        return true;

    Optional<pid_t> pid;
    {
        auto server_pid_file = Core::File::open(pid_path, Core::File::OpenMode::Read);
        if (server_pid_file.is_error()) {
            warnln("Could not open SQLServer PID file '{}': {}", pid_path, server_pid_file.error());
            return server_pid_file.release_error();
        }

        auto contents = server_pid_file.value()->read_until_eof();
        if (contents.is_error()) {
            warnln("Could not read SQLServer PID file '{}': {}", pid_path, contents.error());
            return contents.release_error();
        }

        pid = StringView { contents.value() }.to_int<pid_t>();
    }

    // kill() treats 0 and negative PIDs as process groups, so they can never identify the server
    // and would make the liveness check below succeed spuriously.
    if (!pid.has_value() || *pid <= 0) {
        warnln("SQLServer PID file '{}' exists, but with an invalid PID", pid_path);
        TRY(Core::System::unlink(pid_path));
        return true;
    }

    // Signal 0 performs only the existence/permission check without delivering anything.
    if (kill(*pid, 0) < 0) {
        warnln("SQLServer PID file '{}' exists with PID {}, but process cannot be found", pid_path, *pid);
        TRY(Core::System::unlink(pid_path));
        return true;
    }

    return false;
}
// Connects to SQLServer, launching it first when should_launch_server() reports that its PID file is
// missing or stale. Both the socket and the PID file live in Core::StandardPaths::runtime_directory().
// The returned client's socket is switched to blocking mode.
ErrorOr<NonnullRefPtr<SQLClient>> SQLClient::launch_server_and_create_client(Vector<String> candidate_server_paths)
{
    auto runtime_directory = TRY(Core::StandardPaths::runtime_directory());
    auto pid_path = DeprecatedString::formatted("{}/SQLServer.pid", runtime_directory);
    auto socket_path = DeprecatedString::formatted("{}/SQLServer.socket", runtime_directory);

    auto server_needs_launch = TRY(should_launch_server(pid_path));
    if (server_needs_launch)
        TRY(launch_server(socket_path, pid_path, move(candidate_server_paths)));

    auto client_socket = TRY(Core::LocalSocket::connect(move(socket_path)));
    TRY(client_socket->set_blocking(true));

    auto* client = new (nothrow) SQLClient(move(client_socket));
    return adopt_nonnull_ref_or_enomem(client);
}
#endif
// Reports a successfully executed statement. The result is forwarded to on_execution_success when a
// handler is set; otherwise a row-count summary is printed to stdout.
void SQLClient::execution_success(u64 statement_id, u64 execution_id, Vector<DeprecatedString> const& column_names, bool has_results, size_t created, size_t updated, size_t deleted)
{
    if (on_execution_success) {
        // NOTE(review): the column names are moved out of a const reference, which assumes the caller
        // does not use them afterwards — confirm against the IPC caller.
        auto& stolen_column_names = const_cast<Vector<DeprecatedString>&>(column_names);
        on_execution_success(ExecutionSuccess {
            .statement_id = statement_id,
            .execution_id = execution_id,
            .column_names = move(stolen_column_names),
            .has_results = has_results,
            .rows_created = created,
            .rows_updated = updated,
            .rows_deleted = deleted,
        });
        return;
    }

    outln("{} row(s) created, {} updated, {} deleted", created, updated, deleted);
}
// Reports a failed statement execution. The error is forwarded to on_execution_error when a handler is
// set; otherwise the message and numeric error code are printed to stderr.
void SQLClient::execution_error(u64 statement_id, u64 execution_id, SQLErrorCode const& code, DeprecatedString const& message)
{
    if (on_execution_error) {
        // NOTE(review): the message is moved out of a const reference, assuming the caller no longer needs it.
        auto& stolen_message = const_cast<DeprecatedString&>(message);
        on_execution_error(ExecutionError {
            .statement_id = statement_id,
            .execution_id = execution_id,
            .error_code = code,
            .error_message = move(stolen_message),
        });
        return;
    }

    warnln("Execution error for statement_id {}: {} ({})", statement_id, message, to_underlying(code));
}
// Delivers one result row. The row is forwarded to on_next_result when a handler is set; otherwise it
// is printed to stdout as a comma-separated list of double-quoted values.
void SQLClient::next_result(u64 statement_id, u64 execution_id, Vector<Value> const& row)
{
    if (on_next_result) {
        // NOTE(review): the row is moved out of a const reference, assuming the caller no longer needs it.
        auto& stolen_row = const_cast<Vector<Value>&>(row);
        on_next_result(ExecutionResult {
            .statement_id = statement_id,
            .execution_id = execution_id,
            .values = move(stolen_row),
        });
        return;
    }

    StringBuilder line;
    line.join(", "sv, row, "\"{}\""sv);
    outln("{}", line.string_view());
}
// Signals that all rows of an execution have been delivered. The total is forwarded to
// on_results_exhausted when a handler is set; otherwise it is printed to stdout.
void SQLClient::results_exhausted(u64 statement_id, u64 execution_id, size_t total_rows)
{
    if (on_results_exhausted) {
        on_results_exhausted(ExecutionComplete {
            .statement_id = statement_id,
            .execution_id = execution_id,
            .total_rows = total_rows,
        });
        return;
    }

    outln("{} total row(s)", total_rows);
}
}