feat: Added support for adding custom providers through C#

This commit is contained in:
WerWolv 2024-03-10 22:05:26 +01:00
parent d817a813b0
commit 0186f2f456
8 changed files with 221 additions and 32 deletions

View file

@ -931,7 +931,7 @@ namespace hex {
void addProviderName(const UnlocalizedString &unlocalizedName);
using ProviderCreationFunction = std::unique_ptr<prv::Provider>(*)();
using ProviderCreationFunction = std::function<std::unique_ptr<prv::Provider>()>;
void add(const std::string &typeName, ProviderCreationFunction creationFunction);
const std::vector<std::string>& getEntries();

View file

@ -1,5 +1,6 @@
using System.Reflection;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.ComTypes;
using System.Runtime.Loader;
namespace ImHex
@ -12,7 +13,7 @@ namespace ImHex
{
try
{
return ExecuteScript(Marshal.PtrToStringUTF8(arg, argLength)) ? 0 : 1;
return ExecuteScript(Marshal.PtrToStringUTF8(arg, argLength));
}
catch (Exception e)
{
@ -21,61 +22,108 @@ namespace ImHex
}
}
private static bool ExecuteScript(string path)
private static List<string> loadedPlugins = new();
private static int ExecuteScript(string args)
{
// Parse input in the form of "execType||path"
var splitArgs = args.Split("||");
var type = splitArgs[0];
var methodName = splitArgs[1];
var path = splitArgs[2];
// Get the parent folder of the passed path
string? basePath = Path.GetDirectoryName(path);
if (basePath == null)
{
Console.WriteLine("[.NET Script] Failed to get base path");
return false;
return 1;
}
// Create a new assembly context
AssemblyLoadContext? context = new("ScriptDomain_" + basePath, true);
int result = 0;
try
{
if (type is "LOAD")
{
if (loadedPlugins.Contains(path))
{
return 0;
}
// Check if the plugin is already loaded
loadedPlugins.Add(path);
}
// Load all assemblies in the parent folder
foreach (var file in Directory.GetFiles(basePath, "*.dll"))
{
// Skip main Assembly
if (file.EndsWith("Main.dll"))
{
continue;
}
context.LoadFromStream(new MemoryStream(File.ReadAllBytes(file)));
}
// Load the script assembly
var assembly = context.LoadFromStream(new MemoryStream(File.ReadAllBytes(path)));
// Find a class named "Script"
var entryPointType = assembly.GetType("Script");
if (entryPointType == null)
{
Console.WriteLine("[.NET Script] Failed to find Script type");
return false;
return 1;
}
var entryPointMethod = entryPointType.GetMethod("Main", BindingFlags.Static | BindingFlags.Public);
if (entryPointMethod == null)
if (type is "EXEC" or "LOAD")
{
Console.WriteLine("[.NET Script] Failed to find ScriptMain method");
return false;
// Load the function
var method = entryPointType.GetMethod(methodName, BindingFlags.Static | BindingFlags.Public);
if (method == null)
{
return 2;
}
entryPointMethod.Invoke(null, null);
// Execute it
method.Invoke(null, null);
}
else if (type == "CHECK")
{
return entryPointType.GetMethod(methodName, BindingFlags.Static | BindingFlags.Public) != null ? 0 : 1;
}
else
{
return 1;
}
}
catch (Exception e)
{
Console.WriteLine("[.NET Script] Exception in AssemblyLoader: " + e.ToString());
return false;
return 3;
}
finally
{
if (type != "LOAD")
{
// Unload all assemblies associated with this script
context.Unload();
context = null;
// Run the garbage collector multiple times to make sure that the
// assemblies are unloaded for sure
for (int i = 0; i < 10; i++)
{
GC.Collect();
GC.WaitForPendingFinalizers();
}
}
}
return true;
return result;
}
}

View file

@ -17,7 +17,8 @@ namespace hex::script::loader {
bool loadAll() override;
private:
std::function<bool(const std::fs::path&)> m_loadAssembly;
std::function<int(const std::string &, bool, const std::fs::path&)> m_runMethod;
std::function<bool(const std::string &, const std::fs::path&)> m_methodExists;
std::fs::path::string_type m_assemblyLoaderPathString;
};

View file

@ -179,8 +179,19 @@ namespace hex::script::loader {
continue;
}
m_loadAssembly = [entryPoint](const std::fs::path &path) -> bool {
auto string = wolv::util::toUTF8String(path);
m_runMethod = [entryPoint](const std::string &methodName, bool keepLoaded, const std::fs::path &path) -> int {
auto pathString = wolv::util::toUTF8String(path);
auto string = hex::format("{}||{}||{}", keepLoaded ? "LOAD" : "EXEC", methodName, pathString);
auto result = entryPoint(string.data(), string.size());
return result;
};
m_methodExists = [entryPoint](const std::string &methodName, const std::fs::path &path) -> bool {
auto pathString = wolv::util::toUTF8String(path);
auto string = hex::format("CHECK||{}||{}", methodName, pathString);
auto result = entryPoint(string.data(), string.size());
return result == 0;
@ -211,10 +222,16 @@ namespace hex::script::loader {
if (!std::fs::exists(scriptPath))
continue;
if (m_methodExists("Main", scriptPath)) {
this->addScript(entry.path().stem().string(), [this, scriptPath] {
hex::unused(m_loadAssembly(scriptPath));
hex::unused(m_runMethod("Main", false, scriptPath));
});
}
if (m_methodExists("OnLoad", scriptPath)) {
hex::unused(m_runMethod("OnLoad", true, scriptPath));
}
}
}
return true;

View file

@ -69,10 +69,10 @@ namespace {
}
void addScriptsMenu() {
static std::vector<const Script*> scripts;
static TaskHolder runnerTask, updaterTask;
hex::ContentRegistry::Interface::addMenuItemSubMenu({ "hex.builtin.menu.extras" }, 5000, [] {
static bool menuJustOpened = true;
static std::vector<const Script*> scripts;
if (ImGui::BeginMenu("hex.script_loader.menu.run_script"_lang)) {
if (menuJustOpened) {
@ -107,6 +107,10 @@ namespace {
}, [] {
return !runnerTask.isRunning();
});
updaterTask = TaskManager::createBackgroundTask("Updating Scripts...", [] (auto&) {
scripts = loadAllScripts();
});
}
}
@ -119,5 +123,4 @@ IMHEX_PLUGIN_SETUP("Script Loader", "WerWolv", "Script Loader plugin") {
if (initializeAllLoaders()) {
addScriptsMenu();
}
}

View file

@ -1,4 +1,5 @@
#include <script_api.hpp>
#include <hex/api/content_registry.hpp>
#include <hex/api/imhex_api.hpp>
#include <hex/providers/provider.hpp>
@ -42,3 +43,63 @@ SCRIPT_API(bool getSelection, u64 *start, u64 *end) {
return true;
}
class ScriptDataProvider : public hex::prv::Provider {
public:
using ReadFunction = void(*)(u64, void*, u64);
using WriteFunction = void(*)(u64, const void*, u64);
using GetSizeFunction = u64(*)();
using GetNameFunction = std::string(*)();
bool open() override { return true; }
void close() override { }
[[nodiscard]] bool isAvailable() const override { return true; }
[[nodiscard]] bool isReadable() const override { return true; }
[[nodiscard]] bool isWritable() const override { return true; }
[[nodiscard]] bool isResizable() const override { return true; }
[[nodiscard]] bool isSavable() const override { return true; }
[[nodiscard]] bool isDumpable() const override { return true; }
void readRaw(u64 offset, void *buffer, size_t size) override {
m_readFunction(offset, buffer, size);
}
void writeRaw(u64 offset, const void *buffer, size_t size) override {
m_writeFunction(offset, const_cast<void*>(buffer), size);
}
void setFunctions(ReadFunction readFunc, WriteFunction writeFunc, GetSizeFunction getSizeFunc) {
m_readFunction = readFunc;
m_writeFunction = writeFunc;
m_getSizeFunction = getSizeFunc;
}
[[nodiscard]] u64 getActualSize() const override { return m_getSizeFunction(); }
void setTypeName(std::string typeName) { m_typeName = std::move(typeName);}
void setName(std::string name) { m_name = std::move(name);}
[[nodiscard]] std::string getTypeName() const override { return m_typeName; }
[[nodiscard]] std::string getName() const override { return m_name; }
private:
ReadFunction m_readFunction = nullptr;
WriteFunction m_writeFunction = nullptr;
GetSizeFunction m_getSizeFunction = nullptr;
GetNameFunction m_getNameFunction = nullptr;
std::string m_typeName, m_name;
};
SCRIPT_API(void registerProvider, const char *typeName, const char *name, ScriptDataProvider::ReadFunction readFunc, ScriptDataProvider::WriteFunction writeFunc, ScriptDataProvider::GetSizeFunction getSizeFunc) {
auto typeNameString = std::string(typeName);
auto nameString = std::string(name);
hex::ContentRegistry::Provider::impl::add(typeNameString, [typeNameString, nameString, readFunc, writeFunc, getSizeFunc] -> std::unique_ptr<hex::prv::Provider> {
auto provider = std::make_unique<ScriptDataProvider>();
provider->setTypeName(typeNameString);
provider->setName(nameString);
provider->setFunctions(readFunc, writeFunc, getSizeFunc);
return provider;
});
hex::ContentRegistry::Provider::impl::addProviderName(typeNameString);
}

View file

@ -5,8 +5,42 @@ using System.Runtime.InteropServices;
namespace ImHex
{
public interface IProvider
{
void readRaw(UInt64 address, IntPtr buffer, UInt64 size)
{
unsafe
{
Span<byte> data = new(buffer.ToPointer(), (int)size);
read(address, data);
}
}
void writeRaw(UInt64 address, IntPtr buffer, UInt64 size)
{
unsafe
{
ReadOnlySpan<byte> data = new(buffer.ToPointer(), (int)size);
write(address, data);
}
}
void read(UInt64 address, Span<byte> data);
void write(UInt64 address, ReadOnlySpan<byte> data);
UInt64 getSize();
string getTypeName();
string getName();
}
public class Memory
{
private static List<IProvider> _registeredProviders = new();
private static List<Delegate> _registeredProviderDelegates = new();
private delegate void DataAccessDelegate(UInt64 address, IntPtr buffer, UInt64 size);
private delegate UInt64 GetSizeDelegate();
[DllImport(Library.Name)]
private static extern void readMemoryV1(UInt64 address, UInt64 size, IntPtr buffer);
@ -16,6 +50,9 @@ namespace ImHex
[DllImport(Library.Name)]
private static extern bool getSelectionV1(IntPtr start, IntPtr end);
[DllImport(Library.Name)]
private static extern int registerProviderV1([MarshalAs(UnmanagedType.LPStr)] string typeName, [MarshalAs(UnmanagedType.LPStr)] string name, IntPtr readFunction, IntPtr writeFunction, IntPtr getSizeFunction);
public static byte[] Read(ulong address, ulong size)
{
@ -58,5 +95,24 @@ namespace ImHex
}
}
public static int RegisterProvider<T>() where T : IProvider, new()
{
_registeredProviders.Add(new T());
ref var provider = ref CollectionsMarshal.AsSpan(_registeredProviders)[^1];
_registeredProviderDelegates.Add(new DataAccessDelegate(provider.readRaw));
_registeredProviderDelegates.Add(new DataAccessDelegate(provider.writeRaw));
_registeredProviderDelegates.Add(new GetSizeDelegate(provider.getSize));
return registerProviderV1(
_registeredProviders[^1].getTypeName(),
_registeredProviders[^1].getName(),
Marshal.GetFunctionPointerForDelegate(_registeredProviderDelegates[^3]),
Marshal.GetFunctionPointerForDelegate(_registeredProviderDelegates[^2]),
Marshal.GetFunctionPointerForDelegate(_registeredProviderDelegates[^1])
);
}
}
}

View file

@ -1,10 +1,13 @@
using ImHex;
using System.Drawing;
class Script {
public static void OnLoad() {
// This function is executed the first time the Plugin is loaded
}
class Script
{
public static void Main()
{
UI.ShowMessageBox("Hello World!");
// This function is executed when the plugin is selected in the "Run Script..." menu
}
}