feat: Added support for adding custom providers through C#

This commit is contained in:
WerWolv 2024-03-10 22:05:26 +01:00
parent d817a813b0
commit 0186f2f456
8 changed files with 221 additions and 32 deletions

View file

@ -931,7 +931,7 @@ namespace hex {
void addProviderName(const UnlocalizedString &unlocalizedName); void addProviderName(const UnlocalizedString &unlocalizedName);
using ProviderCreationFunction = std::unique_ptr<prv::Provider>(*)(); using ProviderCreationFunction = std::function<std::unique_ptr<prv::Provider>()>;
void add(const std::string &typeName, ProviderCreationFunction creationFunction); void add(const std::string &typeName, ProviderCreationFunction creationFunction);
const std::vector<std::string>& getEntries(); const std::vector<std::string>& getEntries();

View file

@ -1,5 +1,6 @@
using System.Reflection; using System.Reflection;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Runtime.InteropServices.ComTypes;
using System.Runtime.Loader; using System.Runtime.Loader;
namespace ImHex namespace ImHex
@ -12,7 +13,7 @@ namespace ImHex
{ {
try try
{ {
return ExecuteScript(Marshal.PtrToStringUTF8(arg, argLength)) ? 0 : 1; return ExecuteScript(Marshal.PtrToStringUTF8(arg, argLength));
} }
catch (Exception e) catch (Exception e)
{ {
@ -21,61 +22,108 @@ namespace ImHex
} }
} }
private static List<string> loadedPlugins = new();
private static bool ExecuteScript(string path) private static int ExecuteScript(string args)
{ {
// Parse input in the form of "execType||path"
var splitArgs = args.Split("||");
var type = splitArgs[0];
var methodName = splitArgs[1];
var path = splitArgs[2];
// Get the parent folder of the passed path
string? basePath = Path.GetDirectoryName(path); string? basePath = Path.GetDirectoryName(path);
if (basePath == null) if (basePath == null)
{ {
Console.WriteLine("[.NET Script] Failed to get base path"); Console.WriteLine("[.NET Script] Failed to get base path");
return false; return 1;
} }
// Create a new assembly context
AssemblyLoadContext? context = new("ScriptDomain_" + basePath, true); AssemblyLoadContext? context = new("ScriptDomain_" + basePath, true);
int result = 0;
try try
{ {
if (type is "LOAD")
{
if (loadedPlugins.Contains(path))
{
return 0;
}
// Check if the plugin is already loaded
loadedPlugins.Add(path);
}
// Load all assemblies in the parent folder
foreach (var file in Directory.GetFiles(basePath, "*.dll")) foreach (var file in Directory.GetFiles(basePath, "*.dll"))
{ {
// Skip main Assembly
if (file.EndsWith("Main.dll"))
{
continue;
}
context.LoadFromStream(new MemoryStream(File.ReadAllBytes(file))); context.LoadFromStream(new MemoryStream(File.ReadAllBytes(file)));
} }
// Load the script assembly
var assembly = context.LoadFromStream(new MemoryStream(File.ReadAllBytes(path))); var assembly = context.LoadFromStream(new MemoryStream(File.ReadAllBytes(path)));
// Find a class named "Script"
var entryPointType = assembly.GetType("Script"); var entryPointType = assembly.GetType("Script");
if (entryPointType == null) if (entryPointType == null)
{ {
Console.WriteLine("[.NET Script] Failed to find Script type"); Console.WriteLine("[.NET Script] Failed to find Script type");
return false; return 1;
} }
var entryPointMethod = entryPointType.GetMethod("Main", BindingFlags.Static | BindingFlags.Public); if (type is "EXEC" or "LOAD")
if (entryPointMethod == null)
{ {
Console.WriteLine("[.NET Script] Failed to find ScriptMain method"); // Load the function
return false; var method = entryPointType.GetMethod(methodName, BindingFlags.Static | BindingFlags.Public);
} if (method == null)
{
return 2;
}
entryPointMethod.Invoke(null, null); // Execute it
method.Invoke(null, null);
}
else if (type == "CHECK")
{
return entryPointType.GetMethod(methodName, BindingFlags.Static | BindingFlags.Public) != null ? 0 : 1;
}
else
{
return 1;
}
} }
catch (Exception e) catch (Exception e)
{ {
Console.WriteLine("[.NET Script] Exception in AssemblyLoader: " + e.ToString()); Console.WriteLine("[.NET Script] Exception in AssemblyLoader: " + e.ToString());
return false; return 3;
} }
finally finally
{ {
context.Unload(); if (type != "LOAD")
context = null;
for (int i = 0; i < 10; i++)
{ {
GC.Collect(); // Unload all assemblies associated with this script
GC.WaitForPendingFinalizers(); context.Unload();
context = null;
// Run the garbage collector multiple times to make sure that the
// assemblies are unloaded for sure
for (int i = 0; i < 10; i++)
{
GC.Collect();
GC.WaitForPendingFinalizers();
}
} }
} }
return true; return result;
} }
} }

View file

@ -17,7 +17,8 @@ namespace hex::script::loader {
bool loadAll() override; bool loadAll() override;
private: private:
std::function<bool(const std::fs::path&)> m_loadAssembly; std::function<int(const std::string &, bool, const std::fs::path&)> m_runMethod;
std::function<bool(const std::string &, const std::fs::path&)> m_methodExists;
std::fs::path::string_type m_assemblyLoaderPathString; std::fs::path::string_type m_assemblyLoaderPathString;
}; };

View file

@ -179,8 +179,19 @@ namespace hex::script::loader {
continue; continue;
} }
m_loadAssembly = [entryPoint](const std::fs::path &path) -> bool { m_runMethod = [entryPoint](const std::string &methodName, bool keepLoaded, const std::fs::path &path) -> int {
auto string = wolv::util::toUTF8String(path); auto pathString = wolv::util::toUTF8String(path);
auto string = hex::format("{}||{}||{}", keepLoaded ? "LOAD" : "EXEC", methodName, pathString);
auto result = entryPoint(string.data(), string.size());
return result;
};
m_methodExists = [entryPoint](const std::string &methodName, const std::fs::path &path) -> bool {
auto pathString = wolv::util::toUTF8String(path);
auto string = hex::format("CHECK||{}||{}", methodName, pathString);
auto result = entryPoint(string.data(), string.size()); auto result = entryPoint(string.data(), string.size());
return result == 0; return result == 0;
@ -211,9 +222,15 @@ namespace hex::script::loader {
if (!std::fs::exists(scriptPath)) if (!std::fs::exists(scriptPath))
continue; continue;
this->addScript(entry.path().stem().string(), [this, scriptPath] { if (m_methodExists("Main", scriptPath)) {
hex::unused(m_loadAssembly(scriptPath)); this->addScript(entry.path().stem().string(), [this, scriptPath] {
}); hex::unused(m_runMethod("Main", false, scriptPath));
});
}
if (m_methodExists("OnLoad", scriptPath)) {
hex::unused(m_runMethod("OnLoad", true, scriptPath));
}
} }
} }

View file

@ -69,10 +69,10 @@ namespace {
} }
void addScriptsMenu() { void addScriptsMenu() {
static std::vector<const Script*> scripts;
static TaskHolder runnerTask, updaterTask; static TaskHolder runnerTask, updaterTask;
hex::ContentRegistry::Interface::addMenuItemSubMenu({ "hex.builtin.menu.extras" }, 5000, [] { hex::ContentRegistry::Interface::addMenuItemSubMenu({ "hex.builtin.menu.extras" }, 5000, [] {
static bool menuJustOpened = true; static bool menuJustOpened = true;
static std::vector<const Script*> scripts;
if (ImGui::BeginMenu("hex.script_loader.menu.run_script"_lang)) { if (ImGui::BeginMenu("hex.script_loader.menu.run_script"_lang)) {
if (menuJustOpened) { if (menuJustOpened) {
@ -107,6 +107,10 @@ namespace {
}, [] { }, [] {
return !runnerTask.isRunning(); return !runnerTask.isRunning();
}); });
updaterTask = TaskManager::createBackgroundTask("Updating Scripts...", [] (auto&) {
scripts = loadAllScripts();
});
} }
} }
@ -119,5 +123,4 @@ IMHEX_PLUGIN_SETUP("Script Loader", "WerWolv", "Script Loader plugin") {
if (initializeAllLoaders()) { if (initializeAllLoaders()) {
addScriptsMenu(); addScriptsMenu();
} }
} }

View file

@ -1,4 +1,5 @@
#include <script_api.hpp> #include <script_api.hpp>
#include <hex/api/content_registry.hpp>
#include <hex/api/imhex_api.hpp> #include <hex/api/imhex_api.hpp>
#include <hex/providers/provider.hpp> #include <hex/providers/provider.hpp>
@ -41,4 +42,64 @@ SCRIPT_API(bool getSelection, u64 *start, u64 *end) {
*end = selection->getEndAddress(); *end = selection->getEndAddress();
return true; return true;
}
class ScriptDataProvider : public hex::prv::Provider {
public:
using ReadFunction = void(*)(u64, void*, u64);
using WriteFunction = void(*)(u64, const void*, u64);
using GetSizeFunction = u64(*)();
using GetNameFunction = std::string(*)();
bool open() override { return true; }
void close() override { }
[[nodiscard]] bool isAvailable() const override { return true; }
[[nodiscard]] bool isReadable() const override { return true; }
[[nodiscard]] bool isWritable() const override { return true; }
[[nodiscard]] bool isResizable() const override { return true; }
[[nodiscard]] bool isSavable() const override { return true; }
[[nodiscard]] bool isDumpable() const override { return true; }
void readRaw(u64 offset, void *buffer, size_t size) override {
m_readFunction(offset, buffer, size);
}
void writeRaw(u64 offset, const void *buffer, size_t size) override {
m_writeFunction(offset, const_cast<void*>(buffer), size);
}
void setFunctions(ReadFunction readFunc, WriteFunction writeFunc, GetSizeFunction getSizeFunc) {
m_readFunction = readFunc;
m_writeFunction = writeFunc;
m_getSizeFunction = getSizeFunc;
}
[[nodiscard]] u64 getActualSize() const override { return m_getSizeFunction(); }
void setTypeName(std::string typeName) { m_typeName = std::move(typeName);}
void setName(std::string name) { m_name = std::move(name);}
[[nodiscard]] std::string getTypeName() const override { return m_typeName; }
[[nodiscard]] std::string getName() const override { return m_name; }
private:
ReadFunction m_readFunction = nullptr;
WriteFunction m_writeFunction = nullptr;
GetSizeFunction m_getSizeFunction = nullptr;
GetNameFunction m_getNameFunction = nullptr;
std::string m_typeName, m_name;
};
SCRIPT_API(void registerProvider, const char *typeName, const char *name, ScriptDataProvider::ReadFunction readFunc, ScriptDataProvider::WriteFunction writeFunc, ScriptDataProvider::GetSizeFunction getSizeFunc) {
auto typeNameString = std::string(typeName);
auto nameString = std::string(name);
hex::ContentRegistry::Provider::impl::add(typeNameString, [typeNameString, nameString, readFunc, writeFunc, getSizeFunc] -> std::unique_ptr<hex::prv::Provider> {
auto provider = std::make_unique<ScriptDataProvider>();
provider->setTypeName(typeNameString);
provider->setName(nameString);
provider->setFunctions(readFunc, writeFunc, getSizeFunc);
return provider;
});
hex::ContentRegistry::Provider::impl::addProviderName(typeNameString);
} }

View file

@ -5,8 +5,42 @@ using System.Runtime.InteropServices;
namespace ImHex namespace ImHex
{ {
public interface IProvider
{
void readRaw(UInt64 address, IntPtr buffer, UInt64 size)
{
unsafe
{
Span<byte> data = new(buffer.ToPointer(), (int)size);
read(address, data);
}
}
void writeRaw(UInt64 address, IntPtr buffer, UInt64 size)
{
unsafe
{
ReadOnlySpan<byte> data = new(buffer.ToPointer(), (int)size);
write(address, data);
}
}
void read(UInt64 address, Span<byte> data);
void write(UInt64 address, ReadOnlySpan<byte> data);
UInt64 getSize();
string getTypeName();
string getName();
}
public class Memory public class Memory
{ {
private static List<IProvider> _registeredProviders = new();
private static List<Delegate> _registeredProviderDelegates = new();
private delegate void DataAccessDelegate(UInt64 address, IntPtr buffer, UInt64 size);
private delegate UInt64 GetSizeDelegate();
[DllImport(Library.Name)] [DllImport(Library.Name)]
private static extern void readMemoryV1(UInt64 address, UInt64 size, IntPtr buffer); private static extern void readMemoryV1(UInt64 address, UInt64 size, IntPtr buffer);
@ -15,6 +49,9 @@ namespace ImHex
[DllImport(Library.Name)] [DllImport(Library.Name)]
private static extern bool getSelectionV1(IntPtr start, IntPtr end); private static extern bool getSelectionV1(IntPtr start, IntPtr end);
[DllImport(Library.Name)]
private static extern int registerProviderV1([MarshalAs(UnmanagedType.LPStr)] string typeName, [MarshalAs(UnmanagedType.LPStr)] string name, IntPtr readFunction, IntPtr writeFunction, IntPtr getSizeFunction);
public static byte[] Read(ulong address, ulong size) public static byte[] Read(ulong address, ulong size)
@ -57,6 +94,25 @@ namespace ImHex
return (start, end); return (start, end);
} }
} }
public static int RegisterProvider<T>() where T : IProvider, new()
{
_registeredProviders.Add(new T());
ref var provider = ref CollectionsMarshal.AsSpan(_registeredProviders)[^1];
_registeredProviderDelegates.Add(new DataAccessDelegate(provider.readRaw));
_registeredProviderDelegates.Add(new DataAccessDelegate(provider.writeRaw));
_registeredProviderDelegates.Add(new GetSizeDelegate(provider.getSize));
return registerProviderV1(
_registeredProviders[^1].getTypeName(),
_registeredProviders[^1].getName(),
Marshal.GetFunctionPointerForDelegate(_registeredProviderDelegates[^3]),
Marshal.GetFunctionPointerForDelegate(_registeredProviderDelegates[^2]),
Marshal.GetFunctionPointerForDelegate(_registeredProviderDelegates[^1])
);
}
} }
} }

View file

@ -1,10 +1,13 @@
using ImHex; using ImHex;
using System.Drawing; class Script {
public static void OnLoad() {
// This function is executed the first time the Plugin is loaded
}
class Script
{
public static void Main() public static void Main()
{ {
UI.ShowMessageBox("Hello World!"); // This function is executed when the plugin is selected in the "Run Script..." menu
} }
} }