GP-2204 addressing code review comments

GP-2204_added_machinelearning_extension
This commit is contained in:
James 2022-09-29 09:11:51 -04:00
parent 1cdb68b03e
commit 956a276387
52 changed files with 6696 additions and 0 deletions

View file

@ -0,0 +1,10 @@
MODULE FILE LICENSE: lib/olcut-config-protobuf-5.2.0.jar BSD-2-ORACLE
MODULE FILE LICENSE: lib/olcut-core-5.2.0.jar BSD-2-ORACLE
MODULE FILE LICENSE: lib/protobuf-java-3.17.3.jar BSD-3-GOOGLE
MODULE FILE LICENSE: lib/tribuo-classification-core-4.2.0.jar Apache License 2.0
MODULE FILE LICENSE: lib/tribuo-classification-tree-4.2.0.jar Apache License 2.0
MODULE FILE LICENSE: lib/tribuo-common-tree-4.2.0.jar Apache License 2.0
MODULE FILE LICENSE: lib/tribuo-core-4.2.0.jar Apache License 2.0
MODULE FILE LICENSE: lib/tribuo-data-4.2.0.jar Apache License 2.0
MODULE FILE LICENSE: lib/tribuo-math-4.2.0.jar Apache License 2.0
MODULE FILE LICENSE: lib/tribuo-util-onnx-4.2.0.jar Apache License 2.0

View file

@ -0,0 +1,46 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Pull in the root project's shared gradle scripts; going by their file names these cover
// extension packaging, Java compilation, Java tests, and help building.
apply from: "$rootProject.projectDir/gradle/distributableGhidraExtension.gradle"
apply from: "$rootProject.projectDir/gradle/javaProject.gradle"
apply from: "$rootProject.projectDir/gradle/javaTestProject.gradle"
apply from: "$rootProject.projectDir/gradle/helpProject.gradle"
apply plugin: 'eclipse'
// Name given to this module's Eclipse project
eclipse.project.name = 'Xtra MachineLearning'
dependencies {
// This module builds against Base and also consumes Base's help configuration
api project(':Base')
helpPath project(path: ":Base", configuration: 'helpPath')
// OLCUT 5.2.0 and protobuf-java 3.17.3
api "com.oracle.labs.olcut:olcut-config-protobuf:5.2.0" //{exclude group: "com.google.protobuf", module: "protobuf-java"}
// the org.jline group is excluded from olcut-core's transitive dependencies
api ("com.oracle.labs.olcut:olcut-core:5.2.0") {exclude group: "org.jline"}
// NOTE(review): declared with 'api' scope although the trailing comment says it is only
// needed for junits - confirm whether a test-only configuration would be sufficient
api "com.google.protobuf:protobuf-java:3.17.3" //only needed for running junits
// Tribuo 4.2.0 modules
api "org.tribuo:tribuo-classification-core:4.2.0"
api "org.tribuo:tribuo-classification-tree:4.2.0"
api "org.tribuo:tribuo-common-tree:4.2.0"
api 'org.tribuo:tribuo-core:4.2.0'
// the com.opencsv group is excluded from tribuo-data's transitive dependencies
// NOTE(review): ExampleTribuoRunner in this commit uses org.tribuo.data.csv.CSVLoader -
// confirm that it still runs with com.opencsv excluded
api ("org.tribuo:tribuo-data:4.2.0") {exclude group: "com.opencsv"}
api "org.tribuo:tribuo-math:4.2.0"
api ("org.tribuo:tribuo-util-onnx:4.2.0") //{exclude group: "com.google.protobuf", module: "protobuf-java"}
// test code uses the 'testArtifacts' configuration of SoftwareModeling
testImplementation project(path: ':SoftwareModeling', configuration: 'testArtifacts')
}

View file

@ -0,0 +1,10 @@
##VERSION: 2.0
##MODULE IP: Apache License 2.0
##MODULE IP: BSD-2-ORACLE
##MODULE IP: BSD-3-GOOGLE
Module.manifest||GHIDRA||||END|
extension.properties||GHIDRA||||END|
lib/README.txt||GHIDRA||||END|
src/main/help/help/TOC_Source.xml||GHIDRA||||END|
src/main/help/help/topics/RandomForestFunctionFinderPlugin/RandomForestFunctionFinderPlugin.htm||GHIDRA||||END|
src/main/resources/images/README.txt||GHIDRA||||END|

View file

@ -0,0 +1,58 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
//Writes a list of the addresses of all call sites to a file.
//@category machineLearning
import java.io.*;
import ghidra.app.script.GhidraScript;
import ghidra.program.model.listing.*;
import ghidra.program.model.pcode.PcodeOp;
/**
 * Script that writes the address of every call site in the current program to the file
 * {@code DATA_DIR/<program name>_calls}, one address per line.
 * <p>
 * A call site is an instruction whose pcode contains a {@code CALL} or {@code CALLIND} op.
 * An address is written once for each such op, so an instruction containing several call
 * ops appears several times.
 */
public class DumpCalls extends GhidraScript {
    private static final String DATA_DIR = "/local/calls";

    /**
     * Walks all instructions of the current program in forward order, records the call
     * sites, and prints the number of calls and the number of instructions that had pcode.
     *
     * @throws Exception if the output file cannot be opened or written
     */
    @Override
    protected void run() throws Exception {
        File outFile = new File(DATA_DIR + File.separator + currentProgram.getName() + "_calls");
        int numCalls = 0;
        int numInstructions = 0;
        //try-with-resources: the file is flushed and closed even if iteration or a write throws
        try (BufferedWriter bWriter = new BufferedWriter(new FileWriter(outFile))) {
            InstructionIterator instIter = currentProgram.getListing().getInstructions(true);
            while (instIter.hasNext()) {
                Instruction inst = instIter.next();
                //fetch the pcode once per instruction instead of on every use
                PcodeOp[] pcodeOps = inst.getPcode();
                if (pcodeOps == null || pcodeOps.length == 0) {
                    continue;
                }
                //only instructions with pcode are counted
                numInstructions++;
                for (PcodeOp op : pcodeOps) {
                    int opCode = op.getOpcode();
                    if (opCode == PcodeOp.CALL || opCode == PcodeOp.CALLIND) {
                        numCalls++;
                        bWriter.write(inst.getAddress().toString() + "\n");
                    }
                }
            }
        }
        printf("total num calls: %d\n", numCalls);
        printf("total num instructions: %d\n", numInstructions);
    }
}

View file

@ -0,0 +1,47 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
//Writes a list of the addresses of all function starts and their sizes to a file
//@category machineLearning
import java.io.*;
import ghidra.app.script.GhidraScript;
import ghidra.program.model.listing.Function;
import ghidra.program.model.listing.FunctionIterator;
/**
 * Script that writes one line per non-external function of the current program to the file
 * {@code DATA_DIR/<program name>_stripped_funcs}. Each line has the form
 * {@code <entry point>,<size>}, where the size is the number of addresses in the
 * function's body.
 */
public class DumpFunctionStarts extends GhidraScript {
    private static final String DATA_DIR = "/local/funcstarts/stripped";

    /**
     * Iterates over the program's functions in forward order and records the entry point
     * and body size of each one, skipping external functions.
     *
     * @throws Exception if the output file cannot be opened or written
     */
    @Override
    protected void run() throws Exception {
        File outFile =
            new File(DATA_DIR + File.separator + currentProgram.getName() + "_stripped_funcs");
        //try-with-resources: the file is flushed and closed even if iteration or a write throws
        try (BufferedWriter bWriter = new BufferedWriter(new FileWriter(outFile))) {
            FunctionIterator fIter = currentProgram.getFunctionManager().getFunctions(true);
            while (fIter.hasNext()) {
                Function func = fIter.next();
                if (func.isExternal()) {
                    continue;
                }
                long size = func.getBody().getNumAddresses();
                bWriter.write(func.getEntryPoint().toString() + "," + size + "\n");
            }
        }
    }
}

View file

@ -0,0 +1,74 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.IOException;
import java.nio.file.Paths;
import org.tribuo.*;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.dtree.CARTClassificationTrainer;
import org.tribuo.classification.ensemble.VotingCombiner;
import org.tribuo.classification.evaluation.LabelEvaluation;
import org.tribuo.classification.evaluation.LabelEvaluator;
import org.tribuo.common.tree.RandomForestTrainer;
import org.tribuo.data.csv.CSVLoader;
import org.tribuo.ensemble.EnsembleModel;
import org.tribuo.evaluation.TrainTestSplitter;
/**
 * Stand-alone example of the Tribuo API: loads the iris data set from a CSV file, splits it
 * into training and test sets, trains a random forest of CART trees, and prints an
 * evaluation of the forest on the test set.
 */
public class ExampleTribuoRunner {
    /** Data file that is read when no path is given on the command line. */
    private static final String DEFAULT_IRIS_DATA = "/home/jmworth/ml/bezdekIris.data";

    /**
     * Runs the example.
     *
     * @param args optional; if present, {@code args[0]} is the path to the iris CSV file,
     * otherwise {@link #DEFAULT_IRIS_DATA} is used
     * @throws IOException if the data file cannot be read
     */
    public static void main(String[] args) throws IOException {
        //allow the data file to be supplied by the caller instead of requiring the
        //hard-coded location to exist
        String irisPath = (args != null && args.length > 0) ? args[0] : DEFAULT_IRIS_DATA;

        var irisHeaders =
            new String[] { "sepalLength", "sepalWidth", "petalLength", "petalWidth", "species" };
        DataSource<Label> irisData = new CSVLoader<>(new LabelFactory()).loadDataSource(
            Paths.get(irisPath), /* Output column */ irisHeaders[4],
            /* Column headers */ irisHeaders);

        // Split iris data into training set (70%) and test set (30%)
        var splitIrisData =
            new TrainTestSplitter<>(irisData, /* Train fraction */ 0.7, /* RNG seed */ 1L);
        var trainData = new MutableDataset<>(splitIrisData.getTrain());
        var testData = new MutableDataset<>(splitIrisData.getTest());

        // We can train a decision tree
        // NOTE(review): the resulting tree is not used below; the call is kept because an
        // extra training run may affect the trainer's internal state and therefore the
        // forest trained afterwards - confirm before removing
        var cartTrainer = new CARTClassificationTrainer(100, 0.2f, 0);
        var decisionTree = cartTrainer.train(trainData);

        // ...or a random forest built from that tree trainer
        var trainer = new RandomForestTrainer<>(cartTrainer, // trainer - the tree trainer
            new VotingCombiner(), // combiner - the combining function for the ensemble
            10 // numMembers - the number of ensemble members to train
        );
        EnsembleModel<Label> tree = trainer.train(trainData);

        // Finally we make predictions on unseen data
        // Each prediction is a map from the output names (i.e. the labels) to the scores/probabilities
        Prediction<Label> prediction = tree.predict(testData.getExample(0));

        // Or we can evaluate the full test dataset, calculating the accuracy, F1 etc.
        LabelEvaluation evaluation = new LabelEvaluator().evaluate(tree, testData);

        // we can inspect the evaluation manually
        double acc = evaluation.accuracy();
        // which returns 0.978

        // or print a formatted evaluation string
        System.out.println(evaluation.toString());
    }
}

View file

@ -0,0 +1,5 @@
name=MachineLearning
description=Finds functions using ML
author=Ghidra Team
createdOn=9/25/2022
version=@extversion@

View file

@ -0,0 +1,200 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
//Example script for training random forests to find function starts
//@category machineLearning
import java.util.*;
import java.util.Map.Entry;
import java.util.stream.Collectors;
import ghidra.app.cmd.disassemble.DisassembleCommand;
import ghidra.app.cmd.function.CreateFunctionCmd;
import ghidra.app.script.GhidraScript;
import ghidra.machinelearning.functionfinding.*;
import ghidra.program.model.address.Address;
import ghidra.program.model.address.AddressSet;
import ghidra.program.model.block.BasicBlockModel;
//NOTE: This script is referenced by name in the help for the
//RandomForestFunctionFinderPlugin. If you change the name be
//sure to update the help.
/**
 * Example script showing how to use the random forest function finder programmatically:
 * it trains several models on the current program, picks the one with the fewest false
 * positives, reports that model's test set errors, and then applies it to disassemble and
 * create functions at addresses the model considers likely function starts.
 */
public class FindFunctionsRFExampleScript extends GhidraScript {

    /**
     * Trains the models, selects the best one, and applies it to the current program.
     * If training produces no models, a message is printed and nothing else is done.
     *
     * @throws Exception if training, classification, or applying the results fails
     */
    @Override
    protected void run() throws Exception {
        //get the parameters controlling how many models are trained and
        //what data is used to train/test them
        FunctionStartRFParams params = new FunctionStartRFParams(currentProgram);

        //maximum number of function starts to use in the training set
        //a warning will be issued if there are no functions left over for the test set
        params.setMaxStarts(1000);

        //minimum size of a function to be included in the training/test sets
        params.setMinFuncSize(16);

        //number of bytes before a function start
        params.setPreBytes(Arrays.asList(2, 8));

        //number of bytes after (and including) a function start
        params.setInitialBytes(Arrays.asList(8, 16));

        //number of non-starts to sample for each start in the training set
        params.setFactors(Arrays.asList(10, 50));

        //for every function start in the (test,training) set, add the code
        //units immediately before and immediately after to the (test,training) set
        //as non-starts.
        params.setIncludePrecedingAndFollowing(true);

        //uncomment to include features for each bit rather than just byte-level features
        //params.setIncludeBitFeatures(true);

        //bound for reducing the size of the test sets
        long testSetMax = 1_000_000L;

        //this is where the trained models will go
        List<RandomForestRowObject> trainedModels = new ArrayList<>();
        RandomForestTrainingTask trainingTask = new RandomForestTrainingTask(currentProgram, params,
            r -> trainedModels.add(r), testSetMax);

        //launch the task to train the models in parallel
        trainingTask.run(monitor);

        //nothing to select from if training did not produce any models; bail out
        //instead of failing on get(0) below
        if (trainedModels.isEmpty()) {
            printf("No models were trained\n");
            return;
        }

        //sort the models by the number of false positives (ascending)
        //if you actually need *unsigned* comparison it's likely that something has gone
        //horribly wrong
        trainedModels.sort(
            (x, y) -> Integer.compareUnsigned(x.getNumFalsePositives(), y.getNumFalsePositives()));

        //grab the model with the fewest false positives
        //note: there could be ties; could sort the winners by recall
        RandomForestRowObject best = trainedModels.get(0);
        printf(
            "Best model: pre-bytes: %d, initialBytes: %d, sampling factor: %d, false positives: %d," +
                " precision: %g, recall: %g\n",
            best.getNumPreBytes(), best.getNumInitialBytes(), best.getSamplingFactor(),
            best.getNumFalsePositives(), best.getPrecision(), best.getRecall());

        //to get more information about the test set errors, apply the model to the error set
        FunctionStartClassifier classifier = new FunctionStartClassifier(currentProgram, best,
            RandomForestFunctionFinderPlugin.FUNC_START);

        //uncomment to see false negatives as well
        //classifier.setProbabilityThreshold(0.0);
        Map<Address, Double> errors = classifier.classify(best.getTestErrors(), monitor);

        //false positives are reported errors at addresses where no function exists,
        //sorted by descending probability
        List<Entry<Address, Double>> falsePositives = errors.entrySet()
                .stream()
                .filter(x -> currentProgram.getFunctionManager().getFunctionAt(x.getKey()) == null)
                .sorted((x, y) -> Double.compare(y.getValue(), x.getValue()))
                .toList();

        //print out addresses of false positives
        printf("False positives:\n");
        falsePositives.forEach(x -> printf("  %s   %g\n", x.getKey().toString(), x.getValue()));

        //show the true function starts most similar to one of the false positives
        if (!falsePositives.isEmpty()) {
            SimilarStartsFinder finder = new SimilarStartsFinder(currentProgram, best);
            List<SimilarStartRowObject> neighbors =
                finder.getSimilarFunctionStarts(falsePositives.get(0).getKey(), 10);
            printf("\nClosest function starts to false positive at %s :\n",
                falsePositives.get(0).getKey());
            neighbors.forEach(n -> printf("  %s   %d\n", n.funcStart(), n.numAgreements()));
        }

        //grab the set of bytes in executable memory which are undefined (i.e., initialized but
        //not yet assigned to be code or data) or instructions which are not assigned to a function
        //body

        //don't bother looking in small undefined ranges
        long minUndefinedRange = 16;
        GetAddressesToClassifyTask getAddressTask =
            new GetAddressesToClassifyTask(currentProgram, minUndefinedRange);
        getAddressTask.run(monitor);
        AddressSet toClassify = getAddressTask.getAddressesToClassify();
        Map<Address, Double> potentialStarts = classifier.classify(toClassify, monitor);

        //grab all of the addresses with probability of being a function start >= .7
        //and disassemble any that are currently undefined

        //block model is needed to get the interpretation of an address
        //e.g., undefined, block start, within block,...
        BasicBlockModel blockModel = new BasicBlockModel(currentProgram);
        //@formatter:off
        List<Address> addresses = potentialStarts.entrySet()
                .stream()
                .filter(x -> x.getValue() >= 0.7d)
                .map(x -> x.getKey())
                .collect(Collectors.toList());
        //@formatter:on
        AddressSet toDisassemble = new AddressSet();
        for (Address addr : addresses) {
            Interpretation inter =
                Interpretation.getInterpretation(currentProgram, addr, blockModel, monitor);
            if (inter.equals(Interpretation.UNDEFINED)) {
                toDisassemble.add(addr);
            }
        }

        //see DisassemblyAndApplyContextAction#actionPerformed if you need to
        //apply context register values
        printf("Found %d addresses to disassemble\n", toDisassemble.getNumAddresses());
        DisassembleCommand cmd = new DisassembleCommand(toDisassemble, null, true);
        cmd.applyTo(currentProgram);

        //create functions at any address with probability of being a function start >= .8,
        //where an instruction exists,
        //which is defined as a BLOCK_START (so not already a function start)
        //and with no conditional flow references to it

        //FunctionStartRowObject represents a row in a table in the gui, but you don't
        //actually need the gui and it's a convenient container for information about an address
        //@formatter:off
        List<FunctionStartRowObject> funcRows = potentialStarts.entrySet()
                .stream()
                .filter(x -> x.getValue() >= 0.8d)
                .map(x -> new FunctionStartRowObject(x.getKey(), x.getValue()))
                .collect(Collectors.toList());
        //@formatter:on

        //interpretations are computed here, i.e., after the disassembly above has been applied
        for (FunctionStartRowObject funcRow : funcRows) {
            funcRow.setCurrentInterpretation(Interpretation.getInterpretation(currentProgram,
                funcRow.getAddress(), blockModel, monitor));
            FunctionStartRowObject.setReferenceData(funcRow, currentProgram);
        }
        AddressSet entries = new AddressSet();
        funcRows.stream()
                .filter(x -> x.getCurrentInterpretation().equals(Interpretation.BLOCK_START))
                .filter(x -> x.getNumConditionalFlowRefs() == 0)
                .forEach(x -> entries.add(x.getAddress()));
        printf("Found %d addresses to create functions\n", entries.getNumAddresses());
        CreateFunctionCmd createCmd = new CreateFunctionCmd(entries);
        createCmd.applyTo(currentProgram);
    }
}

View file

@ -0,0 +1,33 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Turns off function start searching (intended for use with the
// headless analyzer as a prescript)
//@category machineLearning
import ghidra.app.script.GhidraScript;
/**
 * Prescript that sets a fixed list of analysis options of the current program to
 * {@code "false"}.
 */
public class TurnOffFuncStartSearch extends GhidraScript {

    /**
     * Sets each of the listed analysis options to "false", in order.
     *
     * @throws Exception declared by the script entry point
     */
    @Override
    protected void run() throws Exception {
        //names of the analysis options to switch off, in the order they are applied
        String[] optionsToDisable = {
            "Function Start Search",
            "Function Start Search After Code",
            "Function Start Search After Data",
            "Function ID"
        };
        for (String optionName : optionsToDisable) {
            setAnalysisOption(currentProgram, optionName, "false");
        }
    }
}

View file

@ -0,0 +1,3 @@
The "lib" directory is intended to hold Jar files which this contrib
is dependent upon. This directory may be eliminated from a specific
contrib if no other Jar files are needed.

View file

@ -0,0 +1,57 @@
<?xml version="1.0" encoding="ISO-8859-1"?>
<!--
This is an XML file intended to be parsed by the Ghidra help system. It is loosely based
upon the JavaHelp table of contents document format. The Ghidra help system uses a
TOC_Source.xml file to allow a module with help to define how its contents appear in the
Ghidra help viewer's table of contents. The main document (in the Base module)
defines a basic structure for the
Ghidra table of contents system. Other TOC_Source.xml files may use this structure to insert
their files directly into this structure (and optionally define a substructure).
In this document, a tag can be either a <tocdef> or a <tocref>. The former is a definition
of an XML item that may have a link and may contain other <tocdef> and <tocref> children.
<tocdef> items may be referred to in other documents by using a <tocref> tag with the
appropriate id attribute value. Using these two tags allows any module to define a place
in the table of contents system (<tocdef>), which also provides a place for
other TOC_Source.xml files to insert content (<tocref>).
During the help build time, all TOC_Source.xml files will be parsed and validated to ensure
that all <tocref> tags point to valid <tocdef> tags. From these files will be generated
<module name>_TOC.xml files, which are table of contents files written in the format
desired by the JavaHelp system. Additionally, the generated files will be merged together
as they are loaded by the JavaHelp system. In the end, when displaying help in the Ghidra
help GUI, there will be one table of contents that has been created from the definitions in
all of the modules' TOC_Source.xml files.
Tags and Attributes
<tocdef>
-id - the name of the definition (this must be unique across all TOC_Source.xml files)
-text - the display text of the node, as seen in the help GUI
-target** - the file to display when the node is clicked in the GUI
-sortgroup - this is a string that defines where a given node should appear under a given
parent. The string values will be sorted by the JavaHelp system using
a java.text.RuleBasedCollator. If this attribute is not specified, then
the text attribute will be used.
<tocref>
-id - The id of the <tocdef> that this reference points to
**The URL for the target is relative and should start with 'help/topics'. This text is
used by the Ghidra help system to provide a universal starting point for all links so that
they can be resolved at runtime, across modules.
-->
<tocroot>
<tocref id="Program Search">
<tocdef id="Random Forest Function Finder"
text="Search for Code and Functions"
target="help/topics/RandomForestFunctionFinderPlugin/RandomForestFunctionFinderPlugin.htm"
sortgroup="zzz">
</tocdef>
</tocref>
</tocroot>

View file

@ -0,0 +1,64 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
WARNING!
This file is copied to all help directories. If you change this file, you must copy it
to each src/main/help/help/shared directory.
Java Help Note: JavaHelp does not accept sizes (like in 'margin-top') in anything but
px (pixel) or with no type marking.
*/
body { margin-bottom: 50px; margin-left: 10px; margin-right: 10px; margin-top: 10px; } /* some padding to improve readability */
li { font-family:times new roman; font-size:14pt; }
h1 { color:#000080; font-family:times new roman; font-size:36pt; font-style:italic; font-weight:bold; text-align:center; }
h2 { margin: 10px; margin-top: 20px; color:#984c4c; font-family:times new roman; font-size:18pt; font-weight:bold; }
h3 { margin-left: 10px; margin-top: 20px; color:#0000ff; font-family:times new roman; font-size:14pt; font-weight:bold; }
h4 { margin-left: 10px; margin-top: 20px; font-family:times new roman; font-size:14pt; font-style:italic; }
/*
P tag code. Most of the help files nest P tags inside of blockquote tags (this was the
way it had been done in the beginning). The net effect is that the text is indented. In
modern HTML we would use CSS to do this. We need to support the Ghidra P tags, nested in
blockquote tags, as well as naked P tags. The following two lines accomplish this. Note
that the 'blockquote p' definition will inherit from the first 'p' definition.
*/
p { margin-left: 40px; font-family:times new roman; font-size:14pt; }
blockquote p { margin-left: 10px; }
p.providedbyplugin { color:#7f7f7f; margin-left: 10px; font-size:14pt; margin-top:100px }
p.ProvidedByPlugin { color:#7f7f7f; margin-left: 10px; font-size:14pt; margin-top:100px }
p.relatedtopic { color:#800080; margin-left: 10px; font-size:14pt; }
p.RelatedTopic { color:#800080; margin-left: 10px; font-size:14pt; }
/*
We wish for a table to have space between it and the preceding element, so that text
is not too close to the top of the table. Also, nest the table a bit so that it is clear
the table relates to the preceding text.
*/
table { margin-left: 20px; margin-top: 10px; width: 80%;}
td { font-family:times new roman; font-size:14pt; vertical-align: top; }
th { font-family:times new roman; font-size:14pt; font-weight:bold; background-color: #EDF3FE; }
/*
Code-like formatting for things such as file system paths and proper names of classes,
methods, etc. To apply this to a file path, use this syntax:
<CODE CLASS="path">...</CODE>
*/
code { color: black; font-weight: bold; font-family: courier new, monospace; font-size: 14pt; white-space: nowrap; }
code.path { color: #4682B4; font-weight: bold; font-family: courier new, monospace; font-size: 14pt; white-space: nowrap; }

View file

@ -0,0 +1,216 @@
<!DOCTYPE html PUBLIC "-//W3C//DTD HTML 4.01 Transitional//EN">
<HTML>
<HEAD>
<META http-equiv="Content-Language" content="en-us">
<META http-equiv="Content-Type" content="text/html; charset=windows-1252">
<TITLE>Random Forest Function Finder Plugin</TITLE>
<LINK rel="stylesheet" type="text/css" href="../../shared/Frontpage.css">
</HEAD>
<BODY>
<H1><A name="RandomForestFunctionFinderPlugin"></A> Random Forest Function Finder Plugin</H1>
<P> This plugin trains models used to find function starts within a program. Essentially,
the training set consists of addresses in a program where Ghidra's analysis was able to
find functions. The models are then applied to the rest of the program.
Models can also be applied to other programs. </P>
<P> In the motivating use case, you either don't know the toolchain which produced a program
or do not have a large number of sample programs to train other types of models. </P>
<P> Note: in general, this plugin ensures that addresses used for training, testing, or
searching for function starts are aligned relative to the processor's instruction alignment.
Defined data within an executable block is an exception - all such bytes are added to
the test set as examples of non-starts.</P>
<H2><A name="SuggestedWorkflow"></A> Basic Suggested Workflow</H2>
<ol>
<li> To begin, select <em>Search-&gt;For Code And Functions...</em> from the Code Browser.</li>
<li> Click the <em>Train</em> button to train models using the default parameters.</li>
<li> Choose the model with the fewest false positives (which will be apparent from
the <em>Model Statistics</em> table).</li>
<li> Right-click on that model's row and select <em>DEBUG - Show test set errors</em>.</li>
<li> Examine the resulting table to determine if there is a good cutoff for
the probabilities. Note that some of the "errors" might not actually be
errors of the model: see the discussion in
<A href="#DebugModelTable">Debug Model Table</A>.</li>
<li> If you're satisfied with the performance of the model, right-click on the
row and select <em>Apply Model</em>. If you aren't, you can try changing the parameters
and training again. You can also try the <em>Include Bit Features</em> training option.</li>
<li> In the resulting table, select all addresses with an <em>Undefined</em> interpretation whose
probability is above your threshold, right-click, and select <em>Disassemble</em>. This will
start disassembly (and follow-on analysis) at each selected address.</li>
<li> Now, select all addresses whose interpretation is <em>Block Start</em> and whose probability
of being a function is above your threshold, right-click, and select <em>Create Function(s)</em>.
It's also probably worth filtering out any addresses which are the targets of
conditional references (which can be seen in the <em>Conditional Flow Refs</em> column). </li>
</ol>
<P> The script <em>FindFunctionsRFExampleScript.java</em> shows how to access the functionality of
this plugin programmatically. </P>
<H2><A name="ModelTrainingTable"></A> Model Training Table</H2>
<P> This table is the main interface for training and applying models. </P>
<H3><A name="DataGatheringParameters"></A> Data Gathering Parameters</H3>
<P> The values in this panel control the number of models trained and the data used to train them.
The first three fields: <A href="#NumberOfPreBytes">Number of Pre-Bytes (CSV)</A>,
<A href="#NumberOfInitialBytes">Number of Initial Bytes (CSV)</A>, and
<A href="#StartToNonStartFactors"> Start to Non-start Sampling Factors (CSV)</A> accept
CSVs of positive integers as input (a single integer with no comma is allowed). Models
corresponding to all possible choices of the three values will be trained and evaluated.
That is, if you enter two values for the <em>Pre-bytes</em> field, three values for the
<em>Initial Bytes</em> field, and four values for the <em>Sampling Factors</em> field, a total
of 2*3*4 = 24 models will be trained and evaluated. </P>
<H4><A name="NumberOfPreBytes"></A> Number of Pre-bytes (CSV) </H4>
<P> Values in this list control how many bytes before an address are used to construct its
feature vector. </P>
<H4><A name="NumberOfInitialBytes"></A> Number of Initial Bytes (CSV) </H4>
<P> Values in this list control how many bytes are used to construct the feature vector
of an address, starting at the address. </P>
<H4><A name="StartToNonStartFactors"></A> Start to Non-start Sampling Factors (CSV) </H4>
<P> Values in this list control how many non-starts (i.e., addresses in the interiors
of functions) are added to the training set for each function start in the training set.</P>
<H4><A name="MaximumNumberOfStarts"></A> Maximum Number Of Starts </H4>
<P> This field controls the maximum number of function starts that are added to the training
set. </P>
<H4><A name="ContextRegsAndValues"></A> Context Registers and Values (CSV) </H4>
<P> This field allows you to specify values of context registers. Addresses will only
be added to the training/test sets if they agree with these values, and the disassembly
action on the <a href="#FunctionStartTable"> Potential Functions Table</a> will apply the
context register values first. This field accepts CSVs of the form "creg1=x,creg2=y,...".
For example, to restrict to Thumb mode in an ARM program, you would enter "TMode=1" in this field.
</P>
<H4><A name="IncludePrecedingAndFollowing"></A> Include Preceding and Following </H4>
<P> If this is selected, for every function entry in the training set, the code units immediately
before it and after it are added to the training set as negative examples (and similarly for the
test set). </P>
<H4><A name="IncludeBitFeatures"></A> Include Bit Features </H4>
<P> If this is selected, a binary feature is added to the feature vector for each bit in the
recorded bytes. </P>
<H4><A name="MinimumFunctionSize"></A> Minimum Function Size </H4>
<P> This value is the minimum size a function must be for its entry and interior to be included
in the training and test sets. </P>
<H3><A name="FunctionInformation"></A> Function Information</H3>
<P> This panel displays information about the functions in the program. </P>
<H4><A name="FunctionsMeetingSizeBound"></A> Functions Meeting Size Bound</H4>
<P> This field displays the number of functions meeting the size bound in the
<A href="#MinimumFunctionSize"> Minimum Function Size</A> field. You can use
this to ensure that the value in <A href="#MaximumNumberOfStarts">Maximum Number
of Starts</A> field doesn't cause all starts to be used for training (leaving
none for testing).</P>
<H4><A name="RestrictSearchToAlignedAddresses"></A> Restrict Search to Aligned Addresses </H4>
<P> If this is checked, only addresses which are zero modulo the value in the
<A href="#AlignmentModulus">Alignment Modulus</A> combo box are searched for function starts.
This does not affect training or testing, but can be a useful optimization when applying
models, for instance when the <A href="#FunctionAlignmentTable">Function Alignment Table</A>
shows that all (known) functions in the program are aligned on 16-byte boundaries. </P>
<H4><A name="AlignmentModulus"></A> Alignment Modulus </H4>
<P> The value in this combo box determines the modulus used when computing the values in
the <A href="#FunctionAlignmentTable">Function Alignment Table</A>. </P>
<H4><A name="FunctionAlignmentTable"></A> Function Alignment Table </H4>
<P> The rows in this table display the number of (known) functions in the program
whose address has the given remainder modulo the alignment modulus.</P>
<H3><A name="ModelStatistics"></A> Model Statistics</H3>
<P> This panel displays the statistics about the trained models as rows in a table.
Actions on these rows allow you to apply the models or see the test set failures.</P>
<H4><A name="ApplyModel"></A> Apply Model Action </H4>
<P> This action will apply the model to the program used to train it. The addresses
searched consist of all addresses which are loaded, initialized, marked as executable,
and not already in a function body (this set can be modified by the user via the
<A href="#RestrictSearchToAlignedAddresses"> Restrict Search to Aligned Addresses</A>
and <A href="#MinLengthUndefinedRange"> Minimum Length of Undefined Ranges to Search</A>
options). The results are displayed in a
<A href="#FunctionStartTable"> Potential Functions Table</A>. </P>
<H4><A name="ApplyModelTo"></A> Apply Model To... Action </H4>
<P> This action will open a dialog to select another program in the current project and
then apply the model to it. Note that the only check that the model is compatible with
the selected program is that any context registers specified when training must be
present in the selected program. </P>
<H4><A name="DebugModel"></A> Debug Model Action </H4>
<P> This action will display a <A href="#DebugModelTable"> Debug Model Table</A>, which shows
all of the errors encountered when applying the model to its test set. </P>
<H2><A name="FunctionStartTable"></A> Potential Functions Table</H2>
<P> This table displays all addresses in the search set which the model thinks are function starts
with probability at least .5. The table also shows the current "Interpretation" (e.g., undefined,
instruction at start of basic block, etc) of the address along with the numbers of certain types
of references to the address. </P>
<P> The following actions are defined on this table:</P>
<H3><A name="DisassembleAction"></A> Disassemble Action </H3>
<P> This action is enabled when at least one of the selected rows corresponds to an address
with an interpretation of "Undefined". It begins disassembly at each "Undefined" address
corresponding to a row in the selection. </P>
<H3><A name="DisassembleAndApplyContextAction"></A> Disassemble and Apply Context Action </H3>
<P> This action is similar to the <A href="#DisassembleAction">Disassemble Action</A>, except
it sets the context register values specified before training the model at the addresses
and then disassembles. </P>
<H3><A name="CreateFunctionsAction"></A> Create Functions Action </H3>
<P> This action is enabled whenever the selection contains at least one row whose corresponding
address is the start of a basic block. This action creates functions at all such addresses.</P>
<H3><A name="ShowSimilarStartsAction"></A> Show Similar Function Starts Action </H3>
<P> This action is enabled when the selection contains exactly one row. It displays
a <A href="#SimilarStartsTable"> table</A> of the function starts in the training set
which are most similar to the bytes at the address of the row. </P>
<H2><A name="SimilarStartsTable"></A> Similar Function Starts Table </H2>
<P> This table displays the function starts in the training set which are most similar
to a potential function start "from the model's point of view". Formally, similarity
is measured using <b>random forest proximity</b>. Given a potential start <i>p</i> and
a known start <i>s</i>, the similarity of <i>p</i> and <i>s</i> is the proportion of trees
in the forest for which <i>p</i> and <i>s</i> end up in the same leaf node. </P>
<P> For convenience, the potential start is also displayed as a row in the table. In
the Address column, its address is surrounded by asterisks.</P>
<H2><A name="DebugModelTable"></A> Debug Model Table </H2>
<P> This table has the same format as the <A href="#FunctionStartTable">Potential Functions Table</A>
but does not have the disassembly or function-creating actions (it does have the action to
display similar function starts). It displays all addresses in the test set where the classifier
made an error. Note that in some cases, it might be the classifier which is correct and the
original analysis which was wrong. A common example is a tail call which
was optimized to a jump during compilation. If there is only one jump to this address, then analysis
may (reasonably) think that the function is just part of the function containing the jump even though
the classifier thinks the jump target is a function start.</P>
<H2><A name="Options"></A> Options </H2>
<P> This plugin has the following options. They can be set in the Tool Options menu. </P>
<H3><A name="MaxTestSetSize"></A> Maximum Test Set Size </H3>
<P> This option controls the maximum size of the test sets (the test set of function
starts and the test set of known non-starts which together form the model's "test set").
Each set that is larger than the maximum will be replaced with a random subset of the maximum size. </P>
<H3><A name="MinLengthUndefinedRange"></A> Minimum Length of Undefined Ranges to Search </H3>
<P> This option controls the minimum length a run of undefined bytes must be in order to
be searched for function starts. This is an optimization which allows you to skip the
(often quite numerous) small runs of undefined bytes between adjacent functions. Note
that this option has no effect on model training or evaluation. </P>
<P class="providedbyplugin">Provided By: <I>RandomForestFunctionFinderPlugin</I></P>
</BODY>
</HTML>

View file

@ -0,0 +1,98 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import docking.ActionContext;
import docking.action.DockingAction;
import docking.action.MenuData;
import ghidra.app.cmd.function.CreateFunctionCmd;
import ghidra.framework.plugintool.Plugin;
import ghidra.program.model.address.AddressSet;
import ghidra.program.model.listing.Program;
import ghidra.util.HelpLocation;
import ghidra.util.table.GhidraTable;
/**
 * A {@link DockingAction} that creates functions at addresses taken from the selected rows of a
 * table backed by a {@link FunctionStartTableModel}. Only selected rows whose
 * {@link Interpretation} is {@link Interpretation#BLOCK_START} contribute an address; every
 * other selected row is ignored.
 */
public class CreateFunctionsAction extends DockingAction {
	private static final String MENU_TEXT = "Create Function(s)";
	private static final String ACTION_NAME = "CreateFunctionsAction";

	private Program program;
	private FunctionStartTableModel model;
	private GhidraTable table;
	private Plugin plugin;

	/**
	 * Builds the action and configures its popup menu entry, description and help location.
	 * @param plugin owning plugin; its name is used as the action owner and help topic
	 * @param program program in which the functions will be created
	 * @param table table whose selection drives the action
	 * @param model model used to map selected row indices to row objects
	 */
	public CreateFunctionsAction(Plugin plugin, Program program, GhidraTable table,
			FunctionStartTableModel model) {
		super(ACTION_NAME, plugin.getName());
		this.plugin = plugin;
		this.program = program;
		this.table = table;
		this.model = model;

		setPopupMenuData(new MenuData(new String[] { MENU_TEXT }));
		String description =
			String.format("Creates functions at all %s rows", Interpretation.BLOCK_START.name());
		setDescription(description);
		setHelpLocation(new HelpLocation(plugin.getName(), ACTION_NAME));
	}

	/**
	 * The action is always offered in the popup menu.
	 */
	@Override
	public boolean isAddToPopup(ActionContext context) {
		return true;
	}

	/**
	 * Enabled as soon as one selected row is a {@link Interpretation#BLOCK_START} row.
	 */
	@Override
	public boolean isEnabledForContext(ActionContext context) {
		for (FunctionStartRowObject selected : model.getRowObjects(table.getSelectedRows())) {
			if (selected.getCurrentInterpretation() == Interpretation.BLOCK_START) {
				return true;
			}
		}
		return false;
	}

	/**
	 * Gathers the addresses of all selected {@link Interpretation#BLOCK_START} rows and runs a
	 * single {@link CreateFunctionCmd} over them as a background command.
	 */
	@Override
	public void actionPerformed(ActionContext context) {
		AddressSet blockStarts = new AddressSet();
		for (FunctionStartRowObject selected : model.getRowObjects(table.getSelectedRows())) {
			if (selected.getCurrentInterpretation() == Interpretation.BLOCK_START) {
				blockStarts.add(selected.getAddress());
			}
		}
		CreateFunctionCmd createCmd = new CreateFunctionCmd(blockStarts);
		plugin.getTool().executeBackgroundCommand(createCmd, program);
	}
}

View file

@ -0,0 +1,100 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import java.math.BigInteger;
import java.util.List;
import docking.ActionContext;
import docking.action.DockingAction;
import docking.action.MenuData;
import ghidra.app.cmd.disassemble.DisassembleCommand;
import ghidra.framework.plugintool.Plugin;
import ghidra.program.model.address.AddressSet;
import ghidra.program.model.lang.Register;
import ghidra.program.model.lang.RegisterValue;
import ghidra.program.model.listing.Program;
import ghidra.program.model.listing.ProgramContext;
import ghidra.util.HelpLocation;
import ghidra.util.table.GhidraTable;
/**
* A {@link DockingAction} for disassembling at addresses corresponding to rows in a
* {@link FunctionStartTableModel}. Context register values specified before training
* will be applied before disassembly.
*/
public class DisassembleAndApplyContextAction extends DisassembleFunctionStartsAction {
private static final String ACTION_NAME = "DisassembleAndApplyContextAction";
private static final String MENU_TEXT = "Disassemble and Apply Context";
private Program program;
private FunctionStartTableModel model;
private Plugin plugin;
private GhidraTable table;
/**
* Creates an action for disassembling at rows in a {@link FunctionStartTableModel} if the
* {@link Interpretation} of the row is {@link Interpretation#UNDEFINED}. Specified
* context register values are set before disassembly.
* @param plugin owning plugin
* @param program source program
* @param table table
* @param model table model
*/
public DisassembleAndApplyContextAction(Plugin plugin, Program program, GhidraTable table,
FunctionStartTableModel model) {
super(plugin, program, table, model);
this.program = program;
this.model = model;
this.plugin = plugin;
this.table = table;
init();
}
@Override
public void actionPerformed(ActionContext context) {
AddressSet entries = new AddressSet();
for (FunctionStartRowObject row : model.getRowObjects(table.getSelectedRows())) {
switch (row.getCurrentInterpretation()) {
case UNDEFINED:
entries.add(row.getAddress());
default:
break;
}
}
DisassembleCommand cmd = new DisassembleCommand(entries, null, true);
RandomForestRowObject row = model.getRandomForestRowObject();
if (row.isContextRestricted()) {
ProgramContext programContext = program.getProgramContext();
List<String> regNames = row.getContextRegisterList();
List<BigInteger> regValues = row.getContextRegisterValues();
RegisterValue newValue = new RegisterValue(programContext.getBaseContextRegister());
for (int i = 0; i < regNames.size(); ++i) {
Register reg = program.getRegister(regNames.get(i));
newValue = newValue.combineValues(new RegisterValue(reg, regValues.get(i)));
}
cmd.setInitialContext(newValue);
}
plugin.getTool().executeBackgroundCommand(cmd, program);
}
private void init() {
setPopupMenuData(new MenuData(new String[] { MENU_TEXT }));
setDescription("Apply context and disassemble at the selected rows");
setHelpLocation(new HelpLocation(plugin.getName(), ACTION_NAME));
}
}

View file

@ -0,0 +1,96 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import docking.ActionContext;
import docking.action.DockingAction;
import docking.action.MenuData;
import ghidra.app.cmd.disassemble.DisassembleCommand;
import ghidra.framework.plugintool.Plugin;
import ghidra.program.model.address.AddressSet;
import ghidra.program.model.listing.Program;
import ghidra.util.HelpLocation;
import ghidra.util.table.GhidraTable;
/**
* A {@link DockingAction} for disassembling at addresses corresponding to rows in a
* {@link FunctionStartTableModel}.
*/
public class DisassembleFunctionStartsAction extends DockingAction {
private static final String ACTION_NAME = "DisassembleAction";
private static final String MENU_TEXT = "Disassemble";
private Program program;
private FunctionStartTableModel model;
private GhidraTable table;
private Plugin plugin;
/**
* Creates and action for disassembling at rows in a {@link FunctionStartTableModel} if the
* {@link Interpretation} of the row is {@link Interpretation#UNDEFINED}.
* @param plugin owning plugin
* @param program source program
* @param table table
* @param model table model
*/
public DisassembleFunctionStartsAction(Plugin plugin, Program program, GhidraTable table,
FunctionStartTableModel model) {
super(ACTION_NAME, plugin.getName());
this.program = program;
this.model = model;
this.plugin = plugin;
this.table = table;
init();
}
@Override
public boolean isAddToPopup(ActionContext context) {
return true;
}
@Override
public boolean isEnabledForContext(ActionContext context) {
for (FunctionStartRowObject row : model.getRowObjects(table.getSelectedRows())) {
switch (row.getCurrentInterpretation()) {
case UNDEFINED:
return true;
default:
break;
}
}
return false;
}
@Override
public void actionPerformed(ActionContext context) {
AddressSet entries = new AddressSet();
for (FunctionStartRowObject row : model.getRowObjects(table.getSelectedRows())) {
switch (row.getCurrentInterpretation()) {
case UNDEFINED:
entries.add(row.getAddress());
default:
break;
}
}
DisassembleCommand cmd = new DisassembleCommand(entries, null, true);
plugin.getTool().executeBackgroundCommand(cmd, program);
}
private void init() {
setPopupMenuData(new MenuData(new String[] { MENU_TEXT }));
setDescription("Disassemble at the selected rows");
setHelpLocation(new HelpLocation(plugin.getName(), ACTION_NAME));
}
}

View file

@ -0,0 +1,97 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import java.util.List;
import org.tribuo.*;
import org.tribuo.classification.Label;
import org.tribuo.ensemble.EnsembleModel;
import org.tribuo.impl.ArrayExample;
import generic.concurrent.QCallback;
import ghidra.program.model.address.Address;
import ghidra.program.model.listing.Program;
import ghidra.util.task.TaskMonitor;
/**
 * Callback for evaluating an ensemble in parallel. Instead of computing the exact
 * probability the ensemble assigns to a label at an address, it answers the yes/no question
 * "do at least half of the ensemble members predict the label?". Members are polled one at a
 * time and polling stops as soon as the answer can no longer change.
 */
public class EnsembleEvaluatorCallback implements QCallback<Address, Boolean> {
	private EnsembleModel<Label> ensemble;
	private int numModels;
	private int numPreBytes;
	private int numInitialBytes;
	private Label label;
	private Program program;
	private boolean includeBitFeatures;

	/**
	 * Create a new evaluator
	 * @param ensemble ensemble to evaluate
	 * @param p program containing addresses to test
	 * @param numPreBytes number of bytes before address
	 * @param numInitialBytes number of bytes after and including address
	 * @param includeBitFeatures whether to include bit-level features
	 * @param label target label
	 */
	public EnsembleEvaluatorCallback(EnsembleModel<Label> ensemble, Program p, int numPreBytes,
			int numInitialBytes, boolean includeBitFeatures, Label label) {
		this.ensemble = ensemble;
		this.numModels = ensemble.getNumModels();
		this.program = p;
		this.numPreBytes = numPreBytes;
		this.numInitialBytes = numInitialBytes;
		this.includeBitFeatures = includeBitFeatures;
		this.label = label;
	}

	/**
	 * Polls the ensemble members for {@code item}.
	 * @param item address to evaluate
	 * @param monitor task monitor (not consulted by this method)
	 * @return {@code null} if no feature vector could be built for {@code item}; {@code true}
	 * once {@code (numModels + 1) / 2} members have predicted the target label; {@code false}
	 * once {@code numModels / 2 + 1} members have predicted something else
	 */
	@Override
	public Boolean process(Address item, TaskMonitor monitor) throws Exception {
		List<Feature> features = ModelTrainingUtils.getFeatureVector(program, item,
			numPreBytes, numInitialBytes, includeBitFeatures);
		if (features.isEmpty()) {
			return null;
		}
		ArrayExample<Label> example = new ArrayExample<>(label, features);

		// vote counts at which the outcome is decided
		int votesNeededFor = (numModels + 1) / 2;
		int votesNeededAgainst = (numModels / 2) + 1;

		int votesFor = 0;
		int votesAgainst = 0;
		for (int member = 0; member < numModels; member++) {
			Model<Label> tree = ensemble.getModels().get(member);
			Prediction<Label> prediction = tree.predict(example);
			boolean agrees = prediction.getOutput().equals(label);
			if (agrees) {
				votesFor++;
				if (votesFor == votesNeededFor) {
					return true;
				}
			}
			else {
				votesAgainst++;
				if (votesAgainst == votesNeededAgainst) {
					return false;
				}
			}
		}
		// only reachable if the ensemble has no members
		throw new AssertionError("did not return value");
	}
}

View file

@ -0,0 +1,53 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
/**
 * One row of a {@link FunctionStartAlignmentTableModel}: a remainder paired with the number
 * of functions associated with that remainder.
 */
public class FunctionStartAlignmentRowObject {
	private final long residue;
	private final long functionCount;

	/**
	 * Creates a row recording how many functions have entry points whose
	 * remainder is {@code remainder} when divided by the alignment modulus.
	 * @param remainder remainder after division by alignment modulus
	 * @param numFuncs number of functions
	 */
	public FunctionStartAlignmentRowObject(long remainder, long numFuncs) {
		residue = remainder;
		functionCount = numFuncs;
	}

	/**
	 * Returns the remainder this row describes
	 * @return remainder
	 */
	public long getRemainder() {
		return residue;
	}

	/**
	 * Returns the number of functions recorded for this row's remainder
	 * @return num funcs
	 */
	public long getNumFuncs() {
		return functionCount;
	}
}

View file

@ -0,0 +1,82 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import java.util.List;
import docking.widgets.table.AbstractSortedTableModel;
/**
 * A two-column sortable table model listing, for each remainder relative to a given modulus,
 * how many functions have an address with that remainder.
 */
public class FunctionStartAlignmentTableModel
		extends AbstractSortedTableModel<FunctionStartAlignmentRowObject> {

	private List<FunctionStartAlignmentRowObject> rows;

	/**
	 * Creates a table with the supplied rows
	 * @param rows rows; the list is stored as-is and returned by {@link #getModelData()}
	 */
	public FunctionStartAlignmentTableModel(List<FunctionStartAlignmentRowObject> rows) {
		this.rows = rows;
	}

	/**
	 * Every column is sortable.
	 */
	@Override
	public boolean isSortable(int columnIndex) {
		return true;
	}

	@Override
	public int getColumnCount() {
		return 2;
	}

	/**
	 * @throws IllegalArgumentException if {@code columnIndex} is neither 0 nor 1
	 */
	@Override
	public String getColumnName(int columnIndex) {
		if (columnIndex == 0) {
			return "Remainder";
		}
		if (columnIndex == 1) {
			return "Number of Functions";
		}
		throw new IllegalArgumentException("Invalid column index");
	}

	@Override
	public String getName() {
		return "Alignment";
	}

	@Override
	public List<FunctionStartAlignmentRowObject> getModelData() {
		return rows;
	}

	/**
	 * Column 0 is the row's remainder, column 1 its function count.
	 * @throws IllegalArgumentException if {@code columnIndex} is neither 0 nor 1
	 */
	@Override
	public Object getColumnValueForRow(FunctionStartAlignmentRowObject t, int columnIndex) {
		if (columnIndex == 0) {
			return t.getRemainder();
		}
		if (columnIndex == 1) {
			return t.getNumFuncs();
		}
		throw new IllegalArgumentException("Invalid column index");
	}
}

View file

@ -0,0 +1,84 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import java.util.List;
import org.tribuo.Feature;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.ensemble.EnsembleModel;
import org.tribuo.impl.ArrayExample;
import generic.concurrent.QCallback;
import ghidra.program.model.address.Address;
import ghidra.program.model.listing.Program;
import ghidra.util.task.TaskMonitor;
/**
 * A {@link QCallback} which applies an ensemble model at a single address and reports the
 * score the model gives to a target label there (i.e., how likely the address is to be a
 * function start when the target is the function-start label).
 */
class FunctionStartCallback implements QCallback<Address, Double> {
	private EnsembleModel<Label> randomForest;
	private int numPreBytes;
	private int numInitialBytes;
	private Program program;
	private int alignment;
	private Label target;
	private boolean includeBitLevelFeatures;

	/**
	 * Creates a callback which applies {@code model} to addresses in {@code program} using
	 * the data-gathering parameters {@code numPreBytes} and {@code numInitialBytes} to
	 * determine the probability the address has label {@code target}. The instruction
	 * alignment of the program's language is recorded for use by
	 * {@link #process(Address, TaskMonitor)}.
	 * @param model model to apply
	 * @param numPreBytes bytes before address to gather
	 * @param numInitialBytes bytes after address to gather
	 * @param includeBitLevelFeatures whether to include bit-level features
	 * @param program source program
	 * @param target target label
	 */
	public FunctionStartCallback(EnsembleModel<Label> model, int numPreBytes, int numInitialBytes,
			boolean includeBitLevelFeatures, Program program, Label target) {
		randomForest = model;
		this.program = program;
		this.target = target;
		this.numPreBytes = numPreBytes;
		this.numInitialBytes = numInitialBytes;
		this.includeBitLevelFeatures = includeBitLevelFeatures;
		this.alignment = program.getLanguage().getInstructionAlignment();
	}

	/**
	 * Scores a single address.
	 * @param item address to score
	 * @param monitor task monitor (not consulted by this method)
	 * @return the score of the target label at {@code item}, or {@code null} if the address
	 * offset is not a multiple of the instruction alignment or the resulting example is empty
	 */
	@Override
	public Double process(Address item, TaskMonitor monitor) throws Exception {
		boolean aligned = Long.remainderUnsigned(item.getOffset(), alignment) == 0;
		if (!aligned) {
			return null;
		}
		List<Feature> features = ModelTrainingUtils.getFeatureVector(program, item,
			numPreBytes, numInitialBytes, includeBitLevelFeatures);
		ArrayExample<Label> example = new ArrayExample<>(LabelFactory.UNKNOWN_LABEL, features);
		if (example.size() == 0) {
			return null;
		}
		Prediction<Label> prediction = randomForest.predict(example);
		Label scored = prediction.getOutputScores().get(target.getLabel());
		return scored.getScore();
	}
}

View file

@ -0,0 +1,120 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import java.util.*;
import org.tribuo.classification.Label;
import org.tribuo.ensemble.EnsembleModel;
import generic.concurrent.*;
import ghidra.program.model.address.Address;
import ghidra.program.model.address.AddressSetView;
import ghidra.program.model.listing.Program;
import ghidra.util.Msg;
import ghidra.util.exception.CancelledException;
import ghidra.util.task.TaskMonitor;
/**
* This class uses a {@link GThreadPool} to execute a {@link FunctionStartCallback} in
* parallel in order to classify addresses in a program as function starts or non-starts.
*/
public class FunctionStartClassifier {
private static final double DEFAULT_PROB_THRESHOLD = 0.5d;
private Program program;
private RandomForestRowObject modelRow;
private double probabilityThreshold;
private Label target;
/**
* Creates an object used to apply the model in {@code modelRow} to addresses in {@code program}
* to the probability of having label {@code target}.
* @param program program to check
* @param modelRow row containing model and data gathering parameters
* @param target target addresses
*/
public FunctionStartClassifier(Program program, RandomForestRowObject modelRow, Label target) {
this.program = program;
this.modelRow = modelRow;
probabilityThreshold = DEFAULT_PROB_THRESHOLD;
this.target = target;
}
/**
* Sets the probability threshold. Address where the probability of having the target
* label is less than the threshold are not included in the map returned by
* {@link FunctionStartClassifier#classify(AddressSetView, TaskMonitor)}
* @param thresh new threshold
*/
public void setProbabilityThreshold(double thresh) {
probabilityThreshold = thresh;
}
/**
* Classifies the addresses in {@code addresses} in parallel.
* @param addresses addresses to classify
* @param monitor monitor
* @return map from addresses to probabilities
* @throws CancelledException if monitor is canceled
*/
public Map<Address, Double> classify(AddressSetView addresses, TaskMonitor monitor)
throws CancelledException {
monitor.initialize(addresses.getNumAddresses());
int preBytes = modelRow.getNumPreBytes();
int initialBytes = modelRow.getNumInitialBytes();
EnsembleModel<Label> randomForest = modelRow.getRandomForest();
Msg.info(this, "Number of addresses to classify: " + addresses.getNumAddresses());
GThreadPool threadPool = GThreadPool.getSharedThreadPool("FunctionStartClassifier");
ConcurrentQBuilder<Address, Double> classifyBuilder = new ConcurrentQBuilder<>();
ConcurrentQ<Address, Double> classifyQ = classifyBuilder.setThreadPool(threadPool)
.setCollectResults(true)
.setMonitor(monitor)
.build(new FunctionStartCallback(randomForest, preBytes, initialBytes,
modelRow.getIncludeBitLevelFeatures(), program, target));
classifyQ.add(addresses.getAddresses(true));
Collection<QResult<Address, Double>> results = Collections.emptyList();
long start = System.nanoTime();
try {
results = classifyQ.waitForResults();
}
catch (InterruptedException e) {
monitor.checkCanceled();
Msg.error(this, "Exception while classifying functions: " + e.getMessage());
}
long end = System.nanoTime();
Msg.info(this, String.format("Classification time: %g seconds",
(end - start) / RandomForestTrainingTask.NANOSECONDS_PER_SECOND));
Map<Address, Double> addrToProb = new HashMap<>();
for (QResult<Address, Double> result : results) {
Double score = null;
try {
score = result.getResult();
}
catch (Exception e) {
Msg.error(this, "Problem getting score of Address " + result.getItem() + ": " +
e.getMessage());
continue;
}
if (score != null && score >= probabilityThreshold) {
addrToProb.put(result.getItem(), score);
}
}
return addrToProb;
}
}

View file

@ -0,0 +1,347 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import java.math.BigInteger;
import java.util.*;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import ghidra.program.model.address.*;
import ghidra.program.model.lang.Register;
import ghidra.program.model.listing.*;
import ghidra.util.exception.CancelledException;
import ghidra.util.task.TaskMonitor;
/**
* This is a container class for the parameters that determine what data is collected for
* training random forests to recognize function starts.
*/
public class FunctionStartRFParams {
private List<Integer> preBytes; //number of bytes before a function start
private List<Integer> initialBytes; //number of bytes after (and including) the first byte
private List<Integer> samplingFactors; //how many non-starts to gather for each start gathered
private int minFuncSize; //minimum size of a function to mine
private int maxStarts; //maximum number of function starts to gather
private Program trainingSource;
private AddressSet funcEntries;
private AddressSet funcInteriors;
private int instructionAlignment;
private boolean includePrecedingAndFollowing;
private boolean includeBitFeatures;
//the following two lists are related: they must have the same size, and the ith
//entry in contextRegisterVals is the value of the ith element of contextRegisterNames
private List<String> contextRegisterNames; //names of context registers we care about
private List<BigInteger> contextRegisterVals; //values of context registers we care about
/**
* Constructs a FunctionStartRFParams object, a container object for data gathering
* parameters to use when training a random forest to recognize function starts.
*
* <p> Note that {@code randomSeed} is initialized to a random value.
*
* <p>Use setter methods to set the fields.
* @param trainingSource source program
*/
public FunctionStartRFParams(Program trainingSource) {
this.trainingSource = trainingSource;
preBytes = Collections.emptyList();
initialBytes = Collections.emptyList();
samplingFactors = Collections.emptyList();
contextRegisterNames = Collections.emptyList();
contextRegisterVals = Collections.emptyList();
instructionAlignment = trainingSource.getLanguage().getInstructionAlignment();
}
/**
*
* @return the number of bytes to gather before an address
*/
public List<Integer> getPreBytes() {
return preBytes;
}
/**
*
* @param preBytes the number of bytes to gather before an address
*/
public void setPreBytes(List<Integer> preBytes) {
this.preBytes = preBytes;
}
/**
 * Returns the initial-byte counts.
 * @return the number of bytes to gather after (and including) an address
 */
public List<Integer> getInitialBytes() {
    return this.initialBytes;
}
/**
* Sets the initial-byte counts. The list is stored as given (no copy is made).
* @param initialBytes the number of bytes to gather after (and including) an address
*/
public void setInitialBytes(List<Integer> initialBytes) {
this.initialBytes = initialBytes;
}
/**
 * Returns the minimum function size.
 * @return the minimum size a function must be to have its data gathered
 */
public int getMinFuncSize() {
    return this.minFuncSize;
}
/**
* Sets the minimum function size.
* @param minFuncSize the minimum size a function must be to have its data gathered
*/
public void setMinFuncSize(int minFuncSize) {
this.minFuncSize = minFuncSize;
}
/**
 * Returns the cap on gathered function starts.
 * @return the maximum number of function starts to gather
 */
public int getMaxStarts() {
    return this.maxStarts;
}
/**
 * Sets the cap on gathered function starts.
 * @param max the maximum number of function starts to gather
 */
public void setMaxStarts(int max) {
    this.maxStarts = max;
}
/**
 * Returns the sampling factors.
 * @return the number of non-starts to gather per function start
 */
public List<Integer> getSamplingFactors() {
    return this.samplingFactors;
}
/**
 * Sets the sampling factors. The list is stored as given (no copy is made).
 * @param factors the number of non-starts to gather per function start
 */
public void setFactors(List<Integer> factors) {
    this.samplingFactors = factors;
}
/**
 * Returns true precisely when at least one context register name has been recorded.
 * @return true if there are any context register values set
 */
public boolean isRestrictedByContext() {
    boolean noRegisters = contextRegisterNames.isEmpty();
    return !noRegisters;
}
/**
 * Returns the context register names; the ith name pairs with the ith entry of
 * {@link #getContextRegisterVals()}.
 * @return the list of names of context registers to set before disassembly
 */
public List<String> getContextRegisterNames() {
    return this.contextRegisterNames;
}
/**
 * Returns the values to assign to the context registers; the ith value pairs with the ith
 * entry of {@link #getContextRegisterNames()}.
 * @return context register values
 */
public List<BigInteger> getContextRegisterVals() {
    return this.contextRegisterVals;
}
/**
 * Parses register/value pairs of the form {@code creg1=x,creg2=y} from {@code csv} and
 * stores them. Any existing register/value pairs are discarded, even when parsing fails.
 *
 * <p>The pairs are accumulated in temporary lists and only stored once the whole string has
 * parsed, so the name list and the value list always have the same size. After a failure
 * both lists are empty.
 *
 * @param csv the list to parse
 * @throws IllegalArgumentException if an element is not of the form {@code name=value}, if
 * a name is not a register of the training program, or if a value is not a valid integer
 * (the latter is reported as a {@link NumberFormatException})
 */
public void setRegistersAndValues(String csv) {
    //discard the old pairs up front so that a failure leaves both lists empty and in sync
    contextRegisterNames = new ArrayList<>();
    contextRegisterVals = new ArrayList<>();
    List<String> names = new ArrayList<>();
    List<BigInteger> vals = new ArrayList<>();
    for (String part : csv.split(",")) {
        String[] regValPair = part.split("=");
        if (regValPair.length != 2) {
            throw new IllegalArgumentException("Error parsing register=value string " + part);
        }
        String regName = regValPair[0].trim();
        if (trainingSource.getRegister(regName) == null) {
            throw new IllegalArgumentException(
                "Register " + regName + " not found for program " + trainingSource.getName());
        }
        //parse the value before recording the name: a malformed value must never leave a
        //name behind without a matching value
        BigInteger value = new BigInteger(regValPair[1].trim());
        names.add(regName);
        vals.add(value);
    }
    contextRegisterNames = names;
    contextRegisterVals = vals;
}
/**
 * Parses a CSV into a sorted list of distinct integer values (duplicates are ignored).
 * Elements are decoded with {@link Integer#decode(String)} after trimming.
 *
 * <p>This method never returns an empty list to signal an error; every problem with the
 * input is reported by an exception.
 * @param csv csv string to parse
 * @return list of the distinct values in ascending order
 * @throws IllegalArgumentException if {@code csv} is blank, begins or ends with a comma,
 * contains a negative element, or contains an element that is not a valid integer (the
 * latter is reported as a {@link NumberFormatException})
 */
public static List<Integer> parseIntegerCSV(String csv) {
    if (StringUtils.isBlank(csv)) {
        throw new IllegalArgumentException("Entry cannot be blank");
    }
    String trimmed = csv.trim();
    if (trimmed.startsWith(",") || trimmed.endsWith(",")) {
        throw new IllegalArgumentException("String must not begin or end with a comma");
    }
    //the set drops duplicates; ordering is applied when the result list is built
    Set<Integer> results = new HashSet<>();
    for (String part : trimmed.split(",")) {
        int value = Integer.decode(part.trim());
        if (value < 0) {
            throw new IllegalArgumentException(
                "Invalid element " + part + " - must be non-negative");
        }
        results.add(value);
    }
    return results.stream().sorted().collect(Collectors.toList());
}
/**
 * Returns the {@link AddressSet} of function entries in the source program.
 * <P>
 * NB: The set is only created by
 * {@link FunctionStartRFParams#computeFuncEntriesAndInteriors}; invoke that method
 * before this one.
 * @return set of entries
 */
public AddressSet getFuncEntries() {
    return this.funcEntries;
}
/**
 * Returns the {@link AddressSet} of function interiors in the source program.
 * <P>
 * NB: The set is only created by
 * {@link FunctionStartRFParams#computeFuncEntriesAndInteriors}; invoke that method
 * before this one.
 * @return set of interiors
 */
public AddressSet getFuncInteriors() {
    return this.funcInteriors;
}
/**
 * Reports whether the code units immediately before and after a function start
 * should be part of the training set.
 * @return include preceding and following
 */
public boolean getIncludePrecedingAndFollowing() {
    return this.includePrecedingAndFollowing;
}
/**
 * Chooses whether the code units immediately before and after a function start
 * should be part of the training set.
 * @param b new value
 */
public void setIncludePrecedingAndFollowing(boolean b) {
    this.includePrecedingAndFollowing = b;
}
/**
 * Reports whether bit-level features are included in the feature vectors.
 * @return include bit level features
 */
public boolean getIncludeBitFeatures() {
    return this.includeBitFeatures;
}
/**
 * Chooses whether bit-level features are included in the feature vectors.
 * @param b new value
 */
public void setIncludeBitFeatures(boolean b) {
    this.includeBitFeatures = b;
}
/**
 * Builds the {@link AddressSet}s of function entries and function interiors for the source
 * program. The results are available afterwards through
 * {@link FunctionStartRFParams#getFuncEntries()} and
 * {@link FunctionStartRFParams#getFuncInteriors()}.
 * <p> Functions smaller than the minimum function size are skipped, as are functions whose
 * entry point is not compatible with the requested context register values (if any).
 * <p> Note: the interior of a function only contains addresses which are aligned relative
 * to the instruction alignment of the processor
 * @param monitor task monitor
 * @throws CancelledException if monitor is canceled
 */
public void computeFuncEntriesAndInteriors(TaskMonitor monitor) throws CancelledException {
    FunctionIterator functions = trainingSource.getFunctionManager().getFunctions(true);
    monitor.initialize(trainingSource.getFunctionManager().getFunctionCount());
    funcEntries = new AddressSet();
    funcInteriors = new AddressSet();
    while (functions.hasNext()) {
        monitor.checkCanceled();
        Function current = functions.next();
        monitor.incrementProgress(1);
        Address entry = current.getEntryPoint();
        boolean tooSmall = current.getBody().getNumAddresses() < minFuncSize;
        if (tooSmall || (isRestrictedByContext() && !isContextCompatible(entry))) {
            continue;
        }
        funcEntries.add(entry);
        //the interior is the function body minus the entry point itself
        AddressIterator interior =
            current.getBody().subtract(new AddressSet(entry)).getAddresses(true);
        while (interior.hasNext()) {
            Address candidate = interior.next();
            if (candidate.getOffset() % instructionAlignment != 0) {
                continue;
            }
            funcInteriors.add(candidate);
        }
    }
}
/**
 * Checks whether {@code addr} is consistent with context register values
 * supplied via {@link FunctionStartRFParams#setRegistersAndValues(String)}.
 * <p>
 * Every recorded register must have exactly the recorded value at {@code addr}. A register
 * for which the program context reports no value at {@code addr} counts as a mismatch.
 * If no registers have been recorded the result is {@code true}.
 * @param addr address to check
 * @return is consistent with context regs
 */
public boolean isContextCompatible(Address addr) {
    ProgramContext context = trainingSource.getProgramContext();
    for (int i = 0; i < contextRegisterNames.size(); i++) {
        Register reg = context.getRegister(contextRegisterNames.get(i));
        BigInteger expected = contextRegisterVals.get(i);
        BigInteger actual = context.getValue(reg, addr, false);
        //compare from the expected side: getValue may return null when the register has
        //no value at addr, and expected is never null
        if (!expected.equals(actual)) {
            return false;
        }
    }
    return true;
}
}

View file

@ -0,0 +1,654 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import java.awt.BorderLayout;
import java.io.IOException;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import javax.swing.*;
import org.apache.commons.lang3.StringUtils;
import docking.DialogComponentProvider;
import docking.action.DockingAction;
import docking.action.builder.ActionBuilder;
import docking.widgets.combobox.GComboBox;
import docking.widgets.label.GDLabel;
import docking.widgets.table.GTable;
import docking.widgets.table.threaded.GThreadedTablePanel;
import docking.widgets.textfield.IntegerTextField;
import ghidra.app.services.ProgramManager;
import ghidra.framework.main.DataTreeDialog;
import ghidra.framework.preferences.Preferences;
import ghidra.program.model.address.AddressSet;
import ghidra.program.model.listing.*;
import ghidra.util.HelpLocation;
import ghidra.util.Msg;
import ghidra.util.exception.CancelledException;
import ghidra.util.exception.VersionException;
import ghidra.util.layout.PairLayout;
import ghidra.util.table.SelectionNavigationAction;
import ghidra.util.table.actions.MakeProgramSelectionAction;
import ghidra.util.task.*;
/**
* This class creates a dialog window for the user to enter data mining parameters
* for learning function starts, train models, see performance statistics, and
* apply the models.
*/
public class FunctionStartRFParamsDialog extends DialogComponentProvider {
//label text (*_TEXT) and tooltip text (*_TIP) for the widgets of the dialog
private static final String INITIAL_BYTES_TEXT = "Number of Initial Bytes (CSV)";
private static final String INITIAL_BYTES_TIP =
"Number of initial bytes of a function to record";
private static final String PRE_BYTES_TEXT = "Number of Pre-bytes (CSV)";
private static final String PRE_BYTES_TIP =
"Number of bytes immediately before a function start to record";
private static final String MIN_FUNC_SIZE_TEXT = "Minimum Function Size";
private static final String MIN_FUNC_SIZE_TIP =
"Functions whose size in bytes are below this number are skipped";
private static final String CONTEXT_REGISTER_TEXT = "Context Registers and Values (CSV)";
private static final String CONTEXT_REGISTER_TIP =
"Restrict gathering to functions where context registers have been set." +
" Form: cReg1=x,cReg2=y,...";
private static final String FACTOR_TEXT = "Start to Non-start Sampling Factors (CSV)";
private static final String FACTOR_TIP =
"Number of non-starts to gather for each function start";
private static final String MAX_STARTS_TEXT = "Maximum Number of Starts";
private static final String MAX_STARTS_TIP = "Maximum number of function starts to gather";
private static final String INCLUDE_PRECEDING_FOLLOWING_TEXT =
"Include Preceding and Following";
private static final String INCLUDE_PRECEDING_FOLLOWING_TIP =
"Include code units immediately before and after a function start when testing and training";
private static final String INCLUDE_BIT_FEATURES_TEXT = "Include Bit Features";
private static final String INCLUDE_BIT_FEATURES_TIP =
"Include bit-level features. May improve models; will increase computation time.";
private static final String FUNCTIONS_MEETING_SIZE_TEXT = "Functions Meeting Size Bound";
private static final String FUNCTIONS_MEETING_SIZE_TIP =
"Number of functions meeting the size " + "bound";
private static final String RESTRICT_SEARCH_TEXT = "Restrict Search To Aligned Addresses ";
//NOTE(review): the concatenation below yields "NOTE:Does" with no space after the colon
private static final String RESTRICT_SEARCH_TIP =
"Only apply model to aligned addresses. NOTE:" + "Does not affect training or test sets!";
private static final String ALIGNMENT_MODULUS_TEXT = "Alignment Modulus";
private static final String ALIGNMENT_MODULUS_TIP =
"Use to define the alignment for restricted search";
//default values (DEFAULT_*) and the Preferences keys (*_PROPERTY) the values are saved under
private static final String DEFAULT_INITIAL_BYTES = "8,16";
private static final String INITIAL_BYTES_PROPERTY = "functionStartRFParams_initialBytes";
private static final String DEFAULT_PRE_BYTES = "2,8";
private static final String PRE_BYTES_PROPERTY = "functionStartRFParams_preBytes";
private static final String DEFAULT_MINIMUM_FUNCTION_SIZE = "16";
private static final String MIN_FUNC_SIZE_PROPERTY = "functionStartRFParams_minFuncSize";
private static final String DEFAULT_CONTEXT_REGISTERS = "";
private static final String CONTEXT_REGISTER_PROPERTY = "functionStartRFParams_cRegs";
private static final String DEFAULT_FACTOR = "10,50";
private static final String FACTOR_PROPERTY = "functionStartRFParams_factor";
private static final String DEFAULT_MAX_STARTS = "1000";
private static final String MAX_STARTS_PROPERTY = "functionStartRFParams_maxStarts";
private static final String INCLUDE_PRECEDING_AND_FOLLOWING_PROPERTY =
"functionStartRFParams_use_pf";
private static final String DEFAULT_INCLUDE_PF = "True";
//NOTE(review): this key is spelled "funcstion..." unlike the other keys; it is a persisted
//preference name, so confirm nothing depends on the current spelling before changing it
private static final String INCLUDE_BIT_FEATURES_PROPERTY =
"funcstionStartRFParams_includeBitFeatures";
private static final String DEFAULT_INCLUDE_BIT_FEATURES = "False";
//titles of the dialog and of its bordered panels
private static final String DATA_GATHERING_PARAMETERS = "Data Gathering Parameters";
private static final String MODEL_STATISTICS = "Model Statistics";
private static final String TITLE = "Random Forest Function Finder";
private static final String FUNCTION_INFO = "Function Information";
//names and popup menu text of the actions on the model statistics table
private static final String APPLY_MODEL_ACTION_NAME = "ApplyModel";
private static final String APPLY_MODEL_MENU_TEXT = "Apply Model";
private static final String APPLY_MODEL_TO_ACTION_NAME = "ApplyModelTo";
private static final String APPLY_MODEL_TO_MENU_TEXT = "Apply Model To...";
private static final String DEBUG_MODEL_ACTION_NAME = "DebugModel";
private static final String DEBUG_MODEL_MENU_TEXT = "DEBUG - Show test set errors";
//input widgets for the data gathering parameters
private JTextField initialBytesField;
private JTextField preBytesField;
private JTextField factorField;
private IntegerTextField minimumSizeField;
private IntegerTextField maxStartsField;
private JTextField contextRegistersField;
//function information panel: function count label and per-residue alignment table
private JLabel numFuncsField;
private JScrollPane tableScrollPane;
private JPanel funcInfoPanel;
private RandomForestFunctionFinderPlugin plugin;
//rows of the model statistics table and the model displaying them
private List<RandomForestRowObject> rowObjects;
private RandomForestTableModel tableModel;
private Program trainingSource; //the plugin's current program when the dialog was created
private FunctionStartRFParams params; //parameters of the most recent training request; null if they were invalid
private Set<Program> openPrograms; //programs opened through selectProgram()
//choices offered in modBox for the alignment modulus
private Vector<Long> moduli = new Vector<>(Arrays.asList(new Long[] { 4l, 8l, 16l, 32l }));
private GComboBox<Long> modBox;
private JButton trainButton;
private JCheckBox includeBeforeAndAfterBox;
private JCheckBox includeBitFeaturesBox;
private JCheckBox restrictBox; //when selected, models are only applied to aligned addresses
private JButton restoreDefaultsButton;
/**
 * Builds the dialog used to train function start models on the program that is
 * current in {@code plugin} at construction time.
 * @param plugin plugin owning this dialog
 */
public FunctionStartRFParamsDialog(RandomForestFunctionFinderPlugin plugin) {
    super(TITLE + ": " + plugin.getCurrentProgram().getDomainFile().getPathname(), false, true,
        true, true);
    this.plugin = plugin;
    this.trainingSource = plugin.getCurrentProgram();
    this.rowObjects = new ArrayList<>();
    this.openPrograms = new HashSet<>();
    //the work panel reads the fields set above, so it must be built after them
    addWorkPanel(createPanel());
    this.trainButton = addTrainModelsButton();
    addHideDialogButton();
    addRestoreDefaultsButton();
    String pluginName = plugin.getName();
    setHelpLocation(new HelpLocation(pluginName, plugin.getName()));
}
/**
 * Re-enables the parameter widgets and reports completion in the status area.
 */
@Override
public void taskCompleted(Task task) {
    super.taskCompleted(task);
    setEnabled(true);
    setStatusText("Training Completed");
}
/**
 * Re-enables the parameter widgets and reports the cancellation in the status area.
 */
@Override
public void taskCancelled(Task task) {
    super.taskCancelled(task);
    setEnabled(true);
    setStatusText("Training Canceled");
}
/**
 * Returns the program the models are trained on.
 * @return source program
 */
Program getTrainingSource() {
    return this.trainingSource;
}
/**
* Tears the dialog down: cancels the task monitor (if there is one), clears the status
* text, discards the table contents, tells the plugin to reset its dialog and disposes
* of this dialog.
*/
@Override
protected void dismissCallback() {
TaskMonitorComponent monitorComp = getTaskMonitorComponent();
if (monitorComp != null) {
monitorComp.cancel();
}
setStatusText("");
//rows in the table can have large objects as fields
//make sure that memory is reclaimed
tableModel.dispose();
rowObjects.clear();
plugin.resetDialog();
dispose();
}
/**
 * Builds the data gathering parameters from the current contents of the dialog's widgets.
 * If any value is invalid a message is shown and {@code null} is returned. On success the
 * current widget values are also saved via {@code setProperties()}.
 * @return the parameters, or {@code null} if some input is invalid
 */
private FunctionStartRFParams getMachineLearningParams() {
    FunctionStartRFParams rfParams = new FunctionStartRFParams(trainingSource);
    List<Integer> preBytes = parseCsvField(preBytesField, PRE_BYTES_TEXT);
    if (preBytes == null) {
        return null;
    }
    rfParams.setPreBytes(preBytes);
    List<Integer> initialBytes = parseCsvField(initialBytesField, INITIAL_BYTES_TEXT);
    if (initialBytes == null) {
        return null;
    }
    rfParams.setInitialBytes(initialBytes);
    List<Integer> factors = parseCsvField(factorField, FACTOR_TEXT);
    if (factors == null) {
        return null;
    }
    rfParams.setFactors(factors);
    int minSize = minimumSizeField.getIntValue();
    if (minSize <= 0) {
        Msg.showWarn(this, null, "Invalid Minimum Size", "Minimum size must be positive!");
        return null;
    }
    rfParams.setMinFuncSize(minSize);
    int maxStarts = maxStartsField.getIntValue();
    if (maxStarts <= 0) {
        Msg.showWarn(this, null, "Invalid Max Starts", "Max Starts must be positive!");
        return null;
    }
    rfParams.setMaxStarts(maxStarts);
    String csv = contextRegistersField.getText();
    if (!StringUtils.isBlank(csv)) {
        try {
            rfParams.setRegistersAndValues(csv);
        }
        catch (IllegalArgumentException e) {
            //the dialog, not the list of factors, is the originator of this message
            Msg.showError(this, null, "Context Register/Value Error", e);
            return null;
        }
    }
    rfParams.setIncludePrecedingAndFollowing(includeBeforeAndAfterBox.isSelected());
    rfParams.setIncludeBitFeatures(includeBitFeaturesBox.isSelected());
    setProperties();
    return rfParams;
}

/**
 * Parses the integer CSV in {@code field}, showing a warning instead of letting the
 * parse exception escape into the event handler.
 * @param field text field holding the CSV
 * @param fieldName name of the field to show in the warning
 * @return the parsed values, or {@code null} if the text is not a valid integer CSV
 */
private List<Integer> parseCsvField(JTextField field, String fieldName) {
    try {
        return FunctionStartRFParams.parseIntegerCSV(field.getText());
    }
    catch (IllegalArgumentException e) {
        Msg.showWarn(this, null, "Invalid " + fieldName, fieldName + ": " + e.getMessage());
        return null;
    }
}
/**
 * Handler for the Train button: empties the model table, reads the parameters from the
 * dialog and, if they are valid, starts a training task that adds a row to the table for
 * each result it reports. The parameter widgets are disabled while the task runs.
 */
private void trainModelsCallback() {
    rowObjects.clear();
    tableModel.reload();
    params = getMachineLearningParams();
    if (params != null) {
        RandomForestTrainingTask task = new RandomForestTrainingTask(trainingSource, params,
            row -> tableModel.addObject(row), plugin.getTestMaxSize());
        task.addTaskListener(this);
        setEnabled(false);
        executeProgressTask(task, 500);
    }
}
/**
 * Creates the Train button, wires it to {@code trainModelsCallback()} and adds it to the
 * dialog.
 * @return the button that was added
 */
private JButton addTrainModelsButton() {
    JButton button = new JButton("Train");
    button.addActionListener(e -> trainModelsCallback());
    button.setToolTipText("Train models using the specified parameters");
    addButton(button);
    return button;
}
/**
* Builds the work panel of the dialog: the data gathering parameters and the function
* information on the west side, and the model statistics table (with its popup actions)
* on the east side. Initial widget values come from the saved Preferences, falling back
* to the DEFAULT_* constants.
* @return the assembled panel
*/
private JPanel createPanel() {
JPanel mainPanel = new JPanel(new BorderLayout());
//east side: table of trained models
tableModel = new RandomForestTableModel(plugin.getTool(), rowObjects);
GThreadedTablePanel<RandomForestRowObject> evalPanel =
new GThreadedTablePanel<>(tableModel);
GTable modelStatsTable = evalPanel.getTable();
modelStatsTable.setSelectionMode(ListSelectionModel.SINGLE_SELECTION);
evalPanel.setBorder(BorderFactory.createTitledBorder(MODEL_STATISTICS));
mainPanel.add(evalPanel, BorderLayout.EAST);
//popup actions; each is enabled only when exactly one model row is selected
DockingAction applyAction = new ActionBuilder(APPLY_MODEL_ACTION_NAME, plugin.getName())
.description("Apply Model to Source Program")
.popupWhen(c -> trainingSource != null)
.enabledWhen(c -> tableModel.getLastSelectedObjects().size() == 1)
.popupMenuPath(APPLY_MODEL_MENU_TEXT)
.inWindow(ActionBuilder.When.ALWAYS)
.onAction(c -> {
searchTrainingProgram(tableModel.getLastSelectedObjects().get(0));
})
.build();
addAction(applyAction);
DockingAction applyToAction =
new ActionBuilder(APPLY_MODEL_TO_ACTION_NAME, plugin.getName())
.description("Choose Program and Apply Model to it")
.popupWhen(c -> trainingSource != null)
.enabledWhen(c -> tableModel.getLastSelectedObjects().size() == 1)
.popupMenuPath(APPLY_MODEL_TO_MENU_TEXT)
.inWindow(ActionBuilder.When.ALWAYS)
.onAction(c -> {
searchOtherProgram(tableModel.getLastSelectedObjects().get(0));
})
.build();
addAction(applyToAction);
DockingAction checkAction = new ActionBuilder(DEBUG_MODEL_ACTION_NAME, plugin.getName())
.description("Show Test Set Errors")
.popupWhen(c -> trainingSource != null)
.enabledWhen(c -> tableModel.getLastSelectedObjects().size() == 1)
.popupMenuPath(DEBUG_MODEL_MENU_TEXT)
.inWindow(ActionBuilder.When.ALWAYS)
.onAction(c -> {
showTestErrors(tableModel.getLastSelectedObjects().get(0));
})
.build();
addAction(checkAction);
//data gathering parameters: label/widget pairs
JPanel paramsPanel = new JPanel();
paramsPanel.setBorder(BorderFactory.createTitledBorder(DATA_GATHERING_PARAMETERS));
PairLayout pairLayout = new PairLayout();
paramsPanel.setLayout(pairLayout);
JLabel preLabel = new GDLabel(PRE_BYTES_TEXT);
preLabel.setToolTipText(PRE_BYTES_TIP);
paramsPanel.add(preLabel);
preBytesField = new JTextField();
String preBytes = Preferences.getProperty(PRE_BYTES_PROPERTY, DEFAULT_PRE_BYTES);
preBytesField.setText(preBytes);
paramsPanel.add(preBytesField);
JLabel initialLabel = new GDLabel(INITIAL_BYTES_TEXT);
initialLabel.setToolTipText(INITIAL_BYTES_TIP);
paramsPanel.add(initialLabel);
initialBytesField = new JTextField();
String initialBytes =
Preferences.getProperty(INITIAL_BYTES_PROPERTY, DEFAULT_INITIAL_BYTES);
initialBytesField.setText(initialBytes);
paramsPanel.add(initialBytesField);
JLabel factorLabel = new GDLabel(FACTOR_TEXT);
factorLabel.setToolTipText(FACTOR_TIP);
paramsPanel.add(factorLabel);
factorField = new JTextField();
String factor = Preferences.getProperty(FACTOR_PROPERTY, DEFAULT_FACTOR);
factorField.setText(factor);
paramsPanel.add(factorField);
JLabel maxStartsLabel = new GDLabel(MAX_STARTS_TEXT);
maxStartsLabel.setToolTipText(MAX_STARTS_TIP);
paramsPanel.add(maxStartsLabel);
maxStartsField = new IntegerTextField();
String maxStarts = Preferences.getProperty(MAX_STARTS_PROPERTY, DEFAULT_MAX_STARTS);
maxStartsField.setValue(Integer.parseInt(maxStarts));
paramsPanel.add(maxStartsField.getComponent());
JLabel contextLabel = new GDLabel(CONTEXT_REGISTER_TEXT);
contextLabel.setToolTipText(CONTEXT_REGISTER_TIP);
paramsPanel.add(contextLabel);
contextRegistersField = new JTextField(DEFAULT_CONTEXT_REGISTERS);
String cRegs = Preferences.getProperty(CONTEXT_REGISTER_PROPERTY, "");
if (!StringUtils.isEmpty(cRegs)) {
contextRegistersField.setText(cRegs);
}
paramsPanel.add(contextRegistersField);
JLabel includeBeforeAndAfterLabel = new JLabel(INCLUDE_PRECEDING_FOLLOWING_TEXT);
includeBeforeAndAfterLabel.setToolTipText(INCLUDE_PRECEDING_FOLLOWING_TIP);
includeBeforeAndAfterBox = new JCheckBox();
String defaultUseBeforeAfter =
Preferences.getProperty(INCLUDE_PRECEDING_AND_FOLLOWING_PROPERTY, DEFAULT_INCLUDE_PF);
includeBeforeAndAfterBox.setSelected(Boolean.valueOf(defaultUseBeforeAfter));
paramsPanel.add(includeBeforeAndAfterLabel);
paramsPanel.add(includeBeforeAndAfterBox);
JLabel includeSelectionLabel = new JLabel(INCLUDE_BIT_FEATURES_TEXT);
includeSelectionLabel.setToolTipText(INCLUDE_BIT_FEATURES_TIP);
includeBitFeaturesBox = new JCheckBox();
String defaultIncludeBitFeatures =
Preferences.getProperty(INCLUDE_BIT_FEATURES_PROPERTY, DEFAULT_INCLUDE_BIT_FEATURES);
includeBitFeaturesBox.setSelected(Boolean.valueOf(defaultIncludeBitFeatures));
paramsPanel.add(includeSelectionLabel);
paramsPanel.add(includeBitFeaturesBox);
JLabel minFuncLabel = new GDLabel(MIN_FUNC_SIZE_TEXT);
minFuncLabel.setToolTipText(MIN_FUNC_SIZE_TIP);
paramsPanel.add(minFuncLabel);
minimumSizeField = new IntegerTextField();
String minSize =
Preferences.getProperty(MIN_FUNC_SIZE_PROPERTY, DEFAULT_MINIMUM_FUNCTION_SIZE);
minimumSizeField.setValue(Integer.parseInt(minSize));
//the listener is registered after the initial setValue above; later changes to the
//minimum size refresh the function count and the alignment table
minimumSizeField.addChangeListener(e -> {
updateNumFuncsField();
updateModulusTable();
});
paramsPanel.add(minimumSizeField.getComponent());
//function information: function count, restricted-search checkbox and alignment modulus
JPanel funcDataPanel = new JPanel();
pairLayout = new PairLayout();
funcDataPanel.setLayout(pairLayout);
JLabel numFuncsLabel = new GDLabel(FUNCTIONS_MEETING_SIZE_TEXT);
numFuncsLabel.setToolTipText(FUNCTIONS_MEETING_SIZE_TIP);
funcDataPanel.add(numFuncsLabel);
numFuncsField = new GDLabel();
updateNumFuncsField();
funcDataPanel.add(numFuncsField);
JLabel restrictLabel = new GDLabel(RESTRICT_SEARCH_TEXT);
restrictLabel.setToolTipText(RESTRICT_SEARCH_TIP);
funcDataPanel.add(restrictLabel);
restrictBox = new JCheckBox();
funcDataPanel.add(restrictBox);
JLabel modulusLabel = new GDLabel(ALIGNMENT_MODULUS_TEXT);
modulusLabel.setToolTipText(ALIGNMENT_MODULUS_TIP);
funcDataPanel.add(modulusLabel);
modBox = new GComboBox<>(moduli);
modBox.setSelectedItem(Long.valueOf(16));
modBox.addActionListener(e -> updateModulusTable());
funcDataPanel.add(modBox);
//getFuncAlignmentScrollPane reads modBox and minimumSizeField, both created above
tableScrollPane = getFuncAlignmentScrollPane();
funcInfoPanel = new JPanel();
funcInfoPanel.setLayout(new BorderLayout());
funcInfoPanel.setBorder(BorderFactory.createTitledBorder(FUNCTION_INFO));
funcInfoPanel.add(funcDataPanel, BorderLayout.NORTH);
funcInfoPanel.add(tableScrollPane, BorderLayout.CENTER);
//west side: parameters on top of the function information
JPanel infoPanel = new JPanel(new BorderLayout());
infoPanel.add(paramsPanel, BorderLayout.NORTH);
infoPanel.add(funcInfoPanel, BorderLayout.CENTER);
mainPanel.add(infoPanel, BorderLayout.WEST);
return mainPanel;
}
/**
 * Builds a scroll pane holding a table with one row per residue modulo the selected
 * alignment modulus. Each row shows how many functions (stubs excluded) of at least the
 * minimum size have an entry point with that residue. If training parameters exist, only
 * functions whose entry point is compatible with their context register values are counted.
 * @return scroll pane containing the alignment table
 */
private JScrollPane getFuncAlignmentScrollPane() {
    long modulus = (Long) modBox.getSelectedItem();
    int minSize = minimumSizeField.getIntValue();
    //initialize map so that every residue gets a row, even when no function has it
    Map<Long, Long> countMap =
        LongStream.range(0, modulus).boxed().collect(Collectors.toMap(i -> i, i -> 0L));
    FunctionIterator fIter = trainingSource.getFunctionManager().getFunctionsNoStubs(true);
    //needs some thought to determine how to display number of functions compatible
    //with context register restrictions
    for (Function f : fIter) {
        if ((f.getBody().getNumAddresses() >= minSize) &&
            (params == null || params.isContextCompatible(f.getEntryPoint()))) {
            //offsets are unsigned values stored in a long: the signed % operator would give
            //a negative residue (and a bogus extra row) for offsets with the high bit set
            long residue = Long.remainderUnsigned(f.getEntryPoint().getOffset(), modulus);
            countMap.merge(residue, 1L, Long::sum);
        }
    }
    List<FunctionStartAlignmentRowObject> rows = countMap.entrySet()
            .stream()
            .map(e -> new FunctionStartAlignmentRowObject(e.getKey(), e.getValue()))
            .collect(Collectors.toList());
    FunctionStartAlignmentTableModel alignModel = new FunctionStartAlignmentTableModel(rows);
    GTable alignTable = new GTable(alignModel);
    return new JScrollPane(alignTable);
}
/**
 * Replaces the alignment table in the function information panel with a freshly
 * computed one and refreshes the dialog's component.
 */
private void updateModulusTable() {
    funcInfoPanel.remove(tableScrollPane);
    JScrollPane rebuilt = getFuncAlignmentScrollPane();
    tableScrollPane = rebuilt;
    funcInfoPanel.add(rebuilt);
    getComponent().updateUI();
}
/**
 * Counts the functions (stubs excluded) of the training program whose body has at least
 * as many addresses as the value of the minimum size field, and shows the count.
 */
private void updateNumFuncsField() {
    long minimumSize = minimumSizeField.getLongValue();
    int matching = 0;
    for (Function candidate : trainingSource.getFunctionManager().getFunctionsNoStubs(true)) {
        if (candidate.getBody().getNumAddresses() < minimumSize) {
            continue;
        }
        matching++;
    }
    numFuncsField.setText(String.valueOf(matching));
}
/**
* Applies the model of {@code modelRow} to the training program itself.
* @param modelRow row of the model to apply
*/
private void searchTrainingProgram(RandomForestRowObject modelRow) {
searchProgram(trainingSource, modelRow);
}
/**
 * Lets the user pick another program, opens it visibly in the tool and applies the model
 * of {@code modelRow} to it. Does nothing if no program was selected.
 * @param modelRow row of the model to apply
 */
private void searchOtherProgram(RandomForestRowObject modelRow) {
    Program target = selectProgram();
    if (target != null) {
        ProgramManager programManager = plugin.getTool().getService(ProgramManager.class);
        programManager.openProgram(target, ProgramManager.OPEN_VISIBLE);
        searchProgram(target, modelRow);
    }
}
/**
 * Shows the test set errors of the model of {@code modelRow} in a new table provider
 * for the training program.
 * @param modelRow row of the model whose errors are shown
 */
private void showTestErrors(RandomForestRowObject modelRow) {
    addGeneralActions(new FunctionStartTableProvider(plugin, trainingSource,
        modelRow.getTestErrors(), modelRow, true));
}
/**
 * Applies the model of {@code modelRow} to {@code prog}: gathers the addresses to classify
 * (only those aligned to the selected modulus if the restrict box is selected), shows them
 * in a new table provider and adds the disassemble and create-function actions to it.
 * Returns without doing anything further if address gathering is cancelled.
 * @param prog program to search
 * @param modelRow row of the model to apply
 */
private void searchProgram(Program prog, RandomForestRowObject modelRow) {
    GetAddressesToClassifyTask gatherTask =
        new GetAddressesToClassifyTask(prog, plugin.getMinUndefinedRangeSize());
    //don't want to use the dialog's progress bar
    TaskLauncher.launchModal("Gathering Addresses To Classify", gatherTask);
    if (gatherTask.isCancelled()) {
        return;
    }
    AddressSet execNonFunc = restrictBox.isSelected()
            ? gatherTask.getAddressesToClassify((long) modBox.getSelectedItem())
            : gatherTask.getAddressesToClassify();
    FunctionStartTableProvider provider =
        new FunctionStartTableProvider(plugin, prog, execNonFunc, modelRow, false);
    addGeneralActions(provider);
    //models trained with context register restrictions also apply the context
    DisassembleFunctionStartsAction disassembleAction = params.isRestrictedByContext()
            ? new DisassembleAndApplyContextAction(plugin, prog, provider.getTable(),
                provider.getTableModel())
            : new DisassembleFunctionStartsAction(plugin, prog, provider.getTable(),
                provider.getTableModel());
    plugin.getTool().addLocalAction(provider, disassembleAction);
    CreateFunctionsAction createAction =
        new CreateFunctionsAction(plugin, prog, provider.getTable(), provider.getTableModel());
    plugin.getTool().addLocalAction(provider, createAction);
}
/**
 * Registers {@code provider} with the plugin and adds the actions shared by all function
 * start tables: make program selection, selection navigation and show similar starts.
 * @param provider provider to register and decorate
 */
private void addGeneralActions(FunctionStartTableProvider provider) {
    plugin.addProvider(provider);
    DockingAction makeSelection = new MakeProgramSelectionAction(plugin, provider.getTable());
    makeSelection.setEnabled(true);
    plugin.getTool().addLocalAction(provider, makeSelection);
    DockingAction navigation = new SelectionNavigationAction(plugin, provider.getTable());
    plugin.getTool().addLocalAction(provider, navigation);
    ShowSimilarStartsAction showSimilar = new ShowSimilarStartsAction(plugin,
        plugin.getCurrentProgram(), provider.getTable(), provider.getTableModel());
    plugin.getTool().addLocalAction(provider, showSimilar);
}
/**
 * Lets the user choose a program from the project and opens it with the plugin as consumer.
 * @return the opened program, or {@code null} if the selection was cancelled, the program
 * could not be opened, or it is not compatible with the current training parameters
 */
private Program selectProgram() {
    DataTreeDialog dtd = new DataTreeDialog(null, "Select Program", DataTreeDialog.OPEN, f -> {
        Class<?> c = f.getDomainObjectClass();
        return Program.class.isAssignableFrom(c);
    });
    dtd.show();
    if (dtd.wasCancelled()) {
        return null;
    }
    Program otherProgram = null;
    try {
        otherProgram = (Program) dtd.getDomainFile()
                .getDomainObject(plugin, true, true, getTaskMonitorComponent());
    }
    catch (CancelledException e) {
        //the user cancelled the open; there is nothing to report
        return null;
    }
    catch (VersionException | IOException e) {
        //tell the user why nothing happened instead of failing silently
        Msg.showError(this, null, "Error Opening Program",
            "Unable to open the selected program: " + e.getMessage(), e);
        return null;
    }
    if (!isProgramCompatible(otherProgram)) {
        //the plugin became a consumer of the program above; release it since the
        //program will not be used
        otherProgram.release(plugin);
        return null;
    }
    openPrograms.add(otherProgram);
    return otherProgram;
}
//checks whether otherProgram has every context register named in the current parameters;
//shows an error for the first one that is missing
//at some point might be worth adding more restrictions
private boolean isProgramCompatible(Program otherProgram) {
    if (params == null) {
        //shouldn't happen
        throw new IllegalStateException("null params");
    }
    if (params.isRestrictedByContext()) {
        for (String name : params.getContextRegisterNames()) {
            if (otherProgram.getRegister(name) != null) {
                continue;
            }
            Msg.showError(this, null, "Error Applying Model", "Program " +
                otherProgram.getName() + " does not have a context register named " + name);
            return false;
        }
    }
    return true;
}
/**
 * Enables or disables the training parameter widgets together with the Train and
 * Restore Defaults buttons. The restrict checkbox and the modulus combo box are not touched.
 * @param b true to enable, false to disable
 */
private void setEnabled(boolean b) {
    initialBytesField.setEnabled(b);
    preBytesField.setEnabled(b);
    factorField.setEnabled(b);
    minimumSizeField.setEnabled(b);
    maxStartsField.setEnabled(b);
    trainButton.setEnabled(b);
    contextRegistersField.setEnabled(b);
    includeBeforeAndAfterBox.setEnabled(b);
    includeBitFeaturesBox.setEnabled(b);
    restoreDefaultsButton.setEnabled(b);
}
/**
 * Creates the Restore Defaults button, wires it to {@code restoreDefaults()} and adds it
 * to the dialog.
 */
private void addRestoreDefaultsButton() {
    JButton button = new JButton("Restore Defaults");
    button.setToolTipText("Restore training parameters to the default values");
    button.addActionListener(e -> restoreDefaults());
    restoreDefaultsButton = button;
    addButton(button);
}
/**
 * Creates the Hide Dialog button, which closes the dialog window, and adds it to the dialog.
 */
private void addHideDialogButton() {
    JButton hideDialogButton = new JButton("Hide Dialog");
    hideDialogButton.setToolTipText("Hide Dialog (does not cancel training or destroy models)");
    hideDialogButton.addActionListener(e -> close());
    addButton(hideDialogButton);
}
/**
 * Resets every training parameter widget to its default value, clears the context
 * register field and saves the resulting values via {@code setProperties()}.
 */
private void restoreDefaults() {
    initialBytesField.setText(DEFAULT_INITIAL_BYTES);
    preBytesField.setText(DEFAULT_PRE_BYTES);
    int defaultMinSize = Integer.parseInt(DEFAULT_MINIMUM_FUNCTION_SIZE);
    minimumSizeField.setValue(defaultMinSize);
    int defaultMaxStarts = Integer.parseInt(DEFAULT_MAX_STARTS);
    maxStartsField.setValue(defaultMaxStarts);
    factorField.setText(DEFAULT_FACTOR);
    includeBeforeAndAfterBox.setSelected(Boolean.parseBoolean(DEFAULT_INCLUDE_PF));
    includeBitFeaturesBox.setSelected(Boolean.parseBoolean(DEFAULT_INCLUDE_BIT_FEATURES));
    contextRegistersField.setText(null);
    setProperties();
}
/**
 * Saves the current values of the training parameter widgets as Preferences properties so
 * that they can be used as initial values the next time the dialog is built. A blank
 * context register field removes its property instead of storing a blank value.
 */
private void setProperties() {
    Preferences.setProperty(INITIAL_BYTES_PROPERTY, initialBytesField.getText());
    Preferences.setProperty(PRE_BYTES_PROPERTY, preBytesField.getText());
    Preferences.setProperty(MIN_FUNC_SIZE_PROPERTY, minimumSizeField.getText());
    Preferences.setProperty(MAX_STARTS_PROPERTY, maxStartsField.getText());
    Preferences.setProperty(FACTOR_PROPERTY, factorField.getText());
    if (StringUtils.isBlank(contextRegistersField.getText())) {
        Preferences.removeProperty(CONTEXT_REGISTER_PROPERTY);
    }
    else {
        Preferences.setProperty(CONTEXT_REGISTER_PROPERTY, contextRegistersField.getText());
    }
    Preferences.setProperty(INCLUDE_PRECEDING_AND_FOLLOWING_PROPERTY,
        Boolean.toString(includeBeforeAndAfterBox.isSelected()));
    //store under the property key the dialog reads back, not under the default value
    Preferences.setProperty(INCLUDE_BIT_FEATURES_PROPERTY,
        Boolean.toString(includeBitFeaturesBox.isSelected()));
}
}

View file

@ -0,0 +1,168 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import ghidra.program.model.address.Address;
import ghidra.program.model.listing.Program;
import ghidra.program.model.symbol.*;
/**
 * Represents a row in the table showing the probabilities that addresses are function starts.
 * <p>
 * A row is identified by its {@link Address}: {@link #equals(Object)} and
 * {@link #hashCode()} consider only the address, not the probability or any of the
 * mutable columns (interpretation and reference counts).
 */
public class FunctionStartRowObject {
    private final Address address;
    private final double probability;
    private Interpretation currentInter;
    private int numDataRefs;
    private int numUnconditionalFlowRefs;
    private int numConditionalFlowRefs;

    /**
     * Creates a row showing that {@link Address} {@code address} has probability
     * {@code probability} of being a function start. Use setter methods to set the other
     * entries in the row.
     * @param address address
     * @param probability prob of being function start
     */
    public FunctionStartRowObject(Address address, double probability) {
        this.address = address;
        this.probability = probability;
    }

    /**
     * Returns the address.
     * @return address
     */
    public Address getAddress() {
        return address;
    }

    /**
     * Returns the probability.
     * @return probability
     */
    public double getProbability() {
        return probability;
    }

    /**
     * Returns the {@link Interpretation}
     * @return interpretation, or {@code null} if it has not been set
     */
    public Interpretation getCurrentInterpretation() {
        return currentInter;
    }

    /**
     * Sets the {@link Interpretation}
     * @param inter interpretation
     */
    public void setCurrentInterpretation(Interpretation inter) {
        currentInter = inter;
    }

    /**
     * Returns the number of data references
     * @return num data refs
     */
    public int getNumDataRefs() {
        return numDataRefs;
    }

    /**
     * Sets the number of data references
     * @param numRefs num data refs
     */
    public void setNumDataRefs(int numRefs) {
        numDataRefs = numRefs;
    }

    /**
     * Returns the number of unconditional flow references
     * @return num unconditional flow refs
     */
    public int getNumUnconditionalFlowRefs() {
        return numUnconditionalFlowRefs;
    }

    /**
     * Sets the number of unconditional flow references
     * @param numRefs num unconditional flow refs
     */
    public void setNumUnconditionalFlowRefs(int numRefs) {
        numUnconditionalFlowRefs = numRefs;
    }

    /**
     * Sets the number of conditional flow references
     * @param numConditionalFlowRefs num conditional refs
     */
    public void setNumConditionalFlowRefs(int numConditionalFlowRefs) {
        this.numConditionalFlowRefs = numConditionalFlowRefs;
    }

    /**
     * Returns the number of conditional flow references
     * @return num conditional flow refs
     */
    public int getNumConditionalFlowRefs() {
        return numConditionalFlowRefs;
    }

    /**
     * Returns the hash code of the row's address, consistent with {@link #equals(Object)}.
     */
    @Override
    public int hashCode() {
        return address.hashCode();
    }

    /**
     * Two rows are equal exactly when they refer to the same address.
     * <p>
     * The comparison is made row-to-row. Comparing the address directly against
     * {@code o} (as an earlier version did) is asymmetric and never matches another row.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof FunctionStartRowObject)) {
            return false;
        }
        FunctionStartRowObject other = (FunctionStartRowObject) o;
        return address.equals(other.address);
    }

    /**
     * Determines and sets the number of data, conditional flow, and unconditional flow
     * references to the address corresponding to {@code rowObject}. References whose type
     * is neither a {@link DataRefType} nor a {@link FlowType} are not counted.
     * @param rowObject row
     * @param program source program
     */
    public static void setReferenceData(FunctionStartRowObject rowObject, Program program) {
        int numUnconditionalFlowRefs = 0;
        int numConditionalFlowRefs = 0;
        int numDataRefs = 0;
        ReferenceIterator refIter =
            program.getReferenceManager().getReferencesTo(rowObject.getAddress());
        while (refIter.hasNext()) {
            RefType type = refIter.next().getReferenceType();
            if (type instanceof DataRefType) {
                numDataRefs++;
            }
            else if (type instanceof FlowType) {
                if (type.isConditional()) {
                    numConditionalFlowRefs++;
                }
                else {
                    numUnconditionalFlowRefs++;
                }
            }
        }
        rowObject.setNumDataRefs(numDataRefs);
        rowObject.setNumUnconditionalFlowRefs(numUnconditionalFlowRefs);
        rowObject.setNumConditionalFlowRefs(numConditionalFlowRefs);
    }
}

View file

@ -0,0 +1,20 @@
/* ###
* IP: LICENSE
*/
package ghidra.machinelearning.functionfinding;
import ghidra.framework.plugintool.ServiceProvider;
import ghidra.program.model.address.Address;
import ghidra.program.model.listing.Program;
import ghidra.util.table.ProgramLocationTableRowMapper;
/**
 * A {@link ProgramLocationTableRowMapper} that converts a {@link FunctionStartRowObject}
 * into the {@link Address} stored in that row.
 */
public class FunctionStartRowObjectToAddressTableRowMapper
        extends ProgramLocationTableRowMapper<FunctionStartRowObject, Address> {

    /**
     * Returns the address of {@code rowObject}. The program and service provider are
     * not consulted.
     */
    @Override
    public Address map(FunctionStartRowObject rowObject, Program data,
            ServiceProvider serviceProvider) {
        Address rowAddress = rowObject.getAddress();
        return rowAddress;
    }
}

View file

@ -0,0 +1,32 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import ghidra.framework.plugintool.ServiceProvider;
import ghidra.program.model.listing.Function;
import ghidra.program.model.listing.Program;
import ghidra.util.table.ProgramLocationTableRowMapper;
/**
 * A {@link ProgramLocationTableRowMapper} that converts a {@link FunctionStartRowObject}
 * into the {@link Function} located at the row's address.
 */
public class FunctionStartRowObjectToFunctionTableRowMapper
        extends ProgramLocationTableRowMapper<FunctionStartRowObject, Function> {

    /**
     * Looks up, in the function manager of {@code data}, the function at the address of
     * {@code rowObject} and returns the result of that lookup unchanged.
     */
    @Override
    public Function map(FunctionStartRowObject rowObject, Program data,
            ServiceProvider serviceProvider) {
        Function functionAtRow =
            data.getFunctionManager().getFunctionAt(rowObject.getAddress());
        return functionAtRow;
    }
}

View file

@ -0,0 +1,31 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import ghidra.framework.plugintool.ServiceProvider;
import ghidra.program.model.listing.Program;
import ghidra.program.util.ProgramLocation;
import ghidra.util.table.ProgramLocationTableRowMapper;
/**
 * A {@link ProgramLocationTableRowMapper} that converts a {@link FunctionStartRowObject}
 * into a {@link ProgramLocation} at the row's address.
 */
public class FunctionStartRowObjectToProgramLocationTableRowMapper
        extends ProgramLocationTableRowMapper<FunctionStartRowObject, ProgramLocation> {

    /**
     * Creates a new {@link ProgramLocation} in {@code program} at the address of
     * {@code rowObject}. The service provider is not consulted.
     */
    @Override
    public ProgramLocation map(FunctionStartRowObject rowObject, Program program,
            ServiceProvider serviceProvider) {
        ProgramLocation rowLocation = new ProgramLocation(program, rowObject.getAddress());
        return rowLocation;
    }
}

View file

@ -0,0 +1,221 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import java.util.Map;
import java.util.Map.Entry;
import docking.widgets.table.AbstractDynamicTableColumn;
import docking.widgets.table.TableColumnDescriptor;
import ghidra.docking.settings.Settings;
import ghidra.framework.plugintool.PluginTool;
import ghidra.framework.plugintool.ServiceProvider;
import ghidra.program.model.address.Address;
import ghidra.program.model.address.AddressSet;
import ghidra.program.model.block.BasicBlockModel;
import ghidra.program.model.listing.Program;
import ghidra.util.datastruct.Accumulator;
import ghidra.util.exception.CancelledException;
import ghidra.util.table.AddressBasedTableModel;
import ghidra.util.task.TaskMonitor;
/**
 * An {@link AddressBasedTableModel} used to display information about addresses which are
 * likely to be function starts.
 */
public class FunctionStartTableModel extends AddressBasedTableModel<FunctionStartRowObject> {

    private final RandomForestRowObject modelRow;
    private final AddressSet addressesToClassify;
    private final boolean debug;
    private final BasicBlockModel blockModel;

    // computed by the first load and reused by every later reload
    private Map<Address, Double> addressToProbability;

    /**
     * Creates a table to display addresses likely to be function starts. If {@code debug}
     * is {@code true}, the table will contain a row for each address in
     * {@code addressesToClassify}. Otherwise it will only contain rows for addresses whose
     * associated probability is &gt;= 0.5
     * @param plugin tool that owns the table
     * @param program source program
     * @param toClassify addresses to search
     * @param modelRow trained model info
     * @param debug is table displaying debug data
     */
    public FunctionStartTableModel(PluginTool plugin, Program program, AddressSet toClassify,
            RandomForestRowObject modelRow, boolean debug) {
        super(program.getName(), plugin, program, null, false);
        this.modelRow = modelRow;
        this.addressesToClassify = toClassify;
        this.debug = debug;
        blockModel = new BasicBlockModel(program);
    }

    @Override
    public Address getAddress(int row) {
        return getRowObject(row).getAddress();
    }

    /**
     * Classifies the addresses (first call only) and then builds one row per classified
     * address, filling in its current interpretation and reference counts.
     * @param accumulator receives the rows
     * @param monitor task monitor, checked for cancellation once per row
     * @throws CancelledException if the monitor is cancelled
     */
    @Override
    protected void doLoad(Accumulator<FunctionStartRowObject> accumulator, TaskMonitor monitor)
            throws CancelledException {
        //table might change as users create functions/disassemble code, so this method
        //may be called repeatedly. However, the probabilities don't change, so only
        //compute the probability that an address is a function start once.
        if (addressToProbability == null) {
            FunctionStartClassifier classifier = new FunctionStartClassifier(program, modelRow,
                RandomForestFunctionFinderPlugin.FUNC_START);
            //if debug, we want to display all errors in the test set
            //if we left the prob threshold at .5 we would not see true function starts
            //that the model thinks are not function starts
            if (debug) {
                classifier.setProbabilityThreshold(0.0);
            }
            addressToProbability = classifier.classify(addressesToClassify, monitor);
        }
        for (Entry<Address, Double> entry : addressToProbability.entrySet()) {
            // the per-row work below only consults the monitor on some code paths, so
            // check explicitly to keep long loads cancellable
            monitor.checkCanceled();
            FunctionStartRowObject rowObject =
                new FunctionStartRowObject(entry.getKey(), entry.getValue());
            setInterpretation(rowObject, monitor);
            FunctionStartRowObject.setReferenceData(rowObject, program);
            accumulator.add(rowObject);
        }
    }

    @Override
    protected TableColumnDescriptor<FunctionStartRowObject> createTableColumnDescriptor() {
        TableColumnDescriptor<FunctionStartRowObject> descriptor = new TableColumnDescriptor<>();
        descriptor.addVisibleColumn(new AddressTableColumn());
        descriptor.addVisibleColumn(new ProbabilityTableColumn(), 0, false);
        descriptor.addVisibleColumn(new InterpretationTableColumn());
        descriptor.addVisibleColumn(new DataReferencesTableColumn());
        descriptor.addVisibleColumn(new UnconditionalFlowReferencesTableColumn());
        descriptor.addVisibleColumn(new ConditionalFlowReferencesTableColumn());
        return descriptor;
    }

    /**
     * Returns the {@link RandomForestRowObject} describing the model used by this table
     * @return row object
     */
    RandomForestRowObject getRandomForestRowObject() {
        return modelRow;
    }

    /**
     * Determines and sets the {@link Interpretation} of the address corresponding to
     * {@code rowObject}.
     * @param rowObject row of table
     * @param monitor monitor
     * @throws CancelledException if user cancels
     */
    void setInterpretation(FunctionStartRowObject rowObject, TaskMonitor monitor)
            throws CancelledException {
        Interpretation inter =
            Interpretation.getInterpretation(program, rowObject.getAddress(), blockModel, monitor);
        rowObject.setCurrentInterpretation(inter);
    }

    /** Column showing the address of a row. */
    private class AddressTableColumn
            extends AbstractDynamicTableColumn<FunctionStartRowObject, Address, Object> {
        @Override
        public String getColumnName() {
            return "Address";
        }

        @Override
        public Address getValue(FunctionStartRowObject rowObject, Settings settings, Object data,
                ServiceProvider services) throws IllegalArgumentException {
            return rowObject.getAddress();
        }
    }

    /** Column showing the probability that a row's address is a function start. */
    private class ProbabilityTableColumn
            extends AbstractDynamicTableColumn<FunctionStartRowObject, Double, Object> {
        @Override
        public String getColumnName() {
            return "Probability";
        }

        @Override
        public Double getValue(FunctionStartRowObject rowObject, Settings settings, Object data,
                ServiceProvider services) throws IllegalArgumentException {
            return rowObject.getProbability();
        }
    }

    /** Column showing the current {@link Interpretation} of a row's address. */
    private class InterpretationTableColumn
            extends AbstractDynamicTableColumn<FunctionStartRowObject, Interpretation, Object> {
        @Override
        public String getColumnName() {
            return "Interpretation";
        }

        @Override
        public Interpretation getValue(FunctionStartRowObject rowObject, Settings settings,
                Object data, ServiceProvider services) throws IllegalArgumentException {
            return rowObject.getCurrentInterpretation();
        }
    }

    /** Column showing the number of data references to a row's address. */
    private class DataReferencesTableColumn
            extends AbstractDynamicTableColumn<FunctionStartRowObject, Integer, Object> {
        @Override
        public String getColumnName() {
            return "Data Refs";
        }

        @Override
        public Integer getValue(FunctionStartRowObject rowObject, Settings settings, Object data,
                ServiceProvider services) throws IllegalArgumentException {
            return rowObject.getNumDataRefs();
        }
    }

    /** Column showing the number of unconditional flow references to a row's address. */
    private class UnconditionalFlowReferencesTableColumn
            extends AbstractDynamicTableColumn<FunctionStartRowObject, Integer, Object> {
        @Override
        public String getColumnName() {
            return "Unconditional Flow Refs";
        }

        @Override
        public Integer getValue(FunctionStartRowObject rowObject, Settings settings, Object data,
                ServiceProvider services) throws IllegalArgumentException {
            return rowObject.getNumUnconditionalFlowRefs();
        }
    }

    /** Column showing the number of conditional flow references to a row's address. */
    private class ConditionalFlowReferencesTableColumn
            extends AbstractDynamicTableColumn<FunctionStartRowObject, Integer, Object> {
        @Override
        public String getColumnName() {
            return "Conditional Flow Refs";
        }

        @Override
        public Integer getValue(FunctionStartRowObject rowObject, Settings settings, Object data,
                ServiceProvider services) throws IllegalArgumentException {
            return rowObject.getNumConditionalFlowRefs();
        }
    }
}

View file

@ -0,0 +1,172 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import java.awt.*;
import javax.swing.*;
import ghidra.app.services.GoToService;
import ghidra.framework.model.*;
import ghidra.framework.plugintool.ComponentProviderAdapter;
import ghidra.program.model.address.AddressSet;
import ghidra.program.model.listing.Program;
import ghidra.program.util.ChangeManager;
import ghidra.util.HelpLocation;
import ghidra.util.table.*;
/**
 * A {@link ProgramAssociatedComponentProviderAdapter} for displaying tables of addresses
 * that are likely function starts. The provider listens for changes to its program and
 * reloads the table when functions, code, or memory references change.
 */
public class FunctionStartTableProvider extends ProgramAssociatedComponentProviderAdapter
        implements DomainObjectListener {

    private final JComponent component;
    private final RandomForestFunctionFinderPlugin plugin;
    private final Program program;
    private final RandomForestRowObject modelRow;
    private final AddressSet toClassify;
    private final boolean debug;
    private final String subTitle;

    // created by buildTablePanel(), which the constructor invokes through build()
    private FunctionStartTableModel model;
    private GhidraTable startTable;
    private GhidraThreadedTablePanel<FunctionStartRowObject> tablePanel;

    /**
     * Constructs a table provider for the table of addresses to classify. If {@code debug}
     * is true, a debug version of the table will be created.
     *
     * @param plugin owning plugin
     * @param program program containing addresses to classify
     * @param toClassify addresses to classify
     * @param modelRow model to apply
     * @param debug whether to display debug version of table
     */
    public FunctionStartTableProvider(RandomForestFunctionFinderPlugin plugin, Program program,
            AddressSet toClassify, RandomForestRowObject modelRow, boolean debug) {
        super(
            debug ? "Debug: Test Set Errors in " + program.getDomainFile().getPathname()
                    : "Potential Functions in " + program.getDomainFile().getPathname(),
            plugin.getName(), program, plugin);
        this.program = program;
        this.plugin = plugin;
        this.modelRow = modelRow;
        this.toClassify = toClassify;
        this.debug = debug;
        subTitle = "Pre-bytes:" + modelRow.getNumPreBytes() + " Initial bytes:" +
            modelRow.getNumInitialBytes() + " Sampling Factor:" + modelRow.getSamplingFactor();
        component = build();
        // NOTE(review): the listener is never removed in this class; confirm that the
        // superclass (or the owning plugin) detaches it when the provider is closed.
        program.addListener(this);
        String anchor = debug ? "DebugModelTable" : "FunctionStartTable";
        setHelpLocation(new HelpLocation(plugin.getName(), anchor));
    }

    @Override
    public JComponent getComponent() {
        return component;
    }

    /**
     * Reloads the table, at most once per event, when the program was restored or when the
     * event contains a function, code, or memory-reference change. Events are ignored
     * while the provider is not visible.
     */
    @Override
    public void domainObjectChanged(DomainObjectChangedEvent ev) {
        if (!isVisible()) {
            return;
        }
        if (ev.containsEvent(DomainObject.DO_OBJECT_RESTORED) || containsTableAffectingChange(ev)) {
            model.reload();
            contextChanged();
        }
    }

    /**
     * Returns {@code true} if any record of {@code ev} is a change that can alter the
     * interpretation or reference counts shown in the table.
     */
    private static boolean containsTableAffectingChange(DomainObjectChangedEvent ev) {
        for (int i = 0; i < ev.numRecords(); ++i) {
            DomainObjectChangeRecord doRecord = ev.getChangeRecord(i);
            switch (doRecord.getEventType()) {
                case ChangeManager.DOCR_FUNCTION_ADDED:
                case ChangeManager.DOCR_FUNCTION_REMOVED:
                case ChangeManager.DOCR_CODE_ADDED:
                case ChangeManager.DOCR_CODE_REMOVED:
                case ChangeManager.DOCR_CODE_REPLACED:
                case ChangeManager.DOCR_MEM_REF_TYPE_CHANGED:
                case ChangeManager.DOCR_MEM_REFERENCE_ADDED:
                case ChangeManager.DOCR_MEM_REFERENCE_REMOVED:
                    return true;
                default:
                    break;
            }
        }
        return false;
    }

    /**
     * Returns the underlying {@link GhidraTable}
     * @return table
     */
    GhidraTable getTable() {
        return startTable;
    }

    /**
     * Returns the table model of this provider.
     * @return table model
     */
    FunctionStartTableModel getTableModel() {
        return model;
    }

    /**
     * Builds the top-level component: a panel whose center is the table panel.
     */
    private JComponent build() {
        JPanel panel = new JPanel(new BorderLayout());
        Component table = buildTablePanel();
        panel.add(table, BorderLayout.CENTER);
        return panel;
    }

    /**
     * Creates the table model, the threaded table showing it, and a filter panel below
     * the table. Also keeps the provider's sub-title in sync with the row counts.
     */
    private Component buildTablePanel() {
        model = new FunctionStartTableModel(plugin.getTool(), program, toClassify, modelRow, debug);
        tablePanel = new GhidraThreadedTablePanel<>(model, 1000);
        startTable = tablePanel.getTable();
        startTable.setName("Potential Functions in " + model.getProgram().getName());
        GoToService goToService = tool.getService(GoToService.class);
        if (goToService != null) {
            startTable.installNavigation(goToService, goToService.getDefaultNavigatable());
        }
        startTable.setNavigateOnSelectionEnabled(true);
        startTable.setAutoResizeMode(JTable.AUTO_RESIZE_SUBSEQUENT_COLUMNS);
        startTable.setPreferredScrollableViewportSize(new Dimension(900, 300));
        startTable.setRowSelectionAllowed(true);
        startTable.getSelectionModel().addListSelectionListener(e -> tool.contextChanged(this));
        model.addTableModelListener(e -> {
            // append "<n> items" (and the unfiltered total when a filter hides rows)
            int rowCount = model.getRowCount();
            int unfilteredCount = model.getUnfilteredRowCount();
            StringBuilder buffy = new StringBuilder();
            buffy.append(" ").append(rowCount).append(" items");
            if (rowCount != unfilteredCount) {
                buffy.append(" (of ").append(unfilteredCount).append(" )");
            }
            setSubTitle(subTitle + buffy.toString());
        });
        JPanel container = new JPanel(new BorderLayout());
        container.add(tablePanel, BorderLayout.CENTER);
        var tableFilterPanel = new GhidraTableFilterPanel<>(startTable, model);
        container.add(tableFilterPanel, BorderLayout.SOUTH);
        return container;
    }
}

View file

@ -0,0 +1,100 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import ghidra.program.model.address.*;
import ghidra.program.model.listing.*;
import ghidra.util.exception.CancelledException;
import ghidra.util.task.Task;
import ghidra.util.task.TaskMonitor;
/**
 * A {@link Task} for gathering addresses to feed to a function start classifier
 */
public class GetAddressesToClassifyTask extends Task {

    private final Program prog;
    private final long minUndefinedRangeSize;

    // starts out empty so the getters are safe to call before the task has run
    private AddressSet execNonFunc = new AddressSet();

    /**
     * Creates a {@link Task} that creates a set of addresses to check for function starts. The
     * {@code minUndefinedRangeSize} parameter determines how large a run of undefined bytes must
     * be to be checked for function starts: undefined ranges whose length is less than or equal
     * to this value are not checked.
     * @param prog source program
     * @param minUndefinedRangeSize minimum size of undefined range
     */
    public GetAddressesToClassifyTask(Program prog, long minUndefinedRangeSize) {
        super("Gathering Addresses to Classify", true, true, false, false);
        this.prog = prog;
        this.minUndefinedRangeSize = minUndefinedRangeSize;
    }

    /**
     * Computes the addresses to classify: memory that is both executable and
     * loaded-and-initialized, minus the body of every existing function, minus small
     * undefined ranges.
     * @param monitor task monitor
     * @throws CancelledException if the monitor is cancelled
     */
    @Override
    public void run(TaskMonitor monitor) throws CancelledException {
        AddressSetView executable = prog.getMemory().getExecuteSet();
        AddressSetView initialized = prog.getMemory().getLoadedAndInitializedAddressSet();
        execNonFunc = executable.intersect(initialized);

        monitor.initialize(prog.getFunctionManager().getFunctionCount());
        FunctionIterator fIter = prog.getFunctionManager().getFunctions(true);
        while (fIter.hasNext()) {
            monitor.checkCanceled();
            monitor.incrementProgress(1);
            Function func = fIter.next();
            execNonFunc = execNonFunc.subtract(func.getBody());
        }

        //remove small undefined ranges to avoid (for example) searching for
        //function starts in an address range of length 3 between two known
        //functions. "small" is controlled by a plugin option.
        AddressSetView undefinedRanges =
            prog.getListing().getUndefinedRanges(execNonFunc, true, monitor);
        AddressSet toRemove = new AddressSet();
        AddressRangeIterator iter = undefinedRanges.getAddressRanges(true);
        while (iter.hasNext()) {
            monitor.checkCanceled();
            AddressRange range = iter.next();
            if (range.getLength() <= minUndefinedRangeSize) {
                toRemove.add(range);
            }
        }
        execNonFunc = execNonFunc.subtract(toRemove);
    }

    /**
     * Returns the set of addresses to classify
     * @return addresses (empty if the task has not been run)
     */
    public AddressSet getAddressesToClassify() {
        return execNonFunc;
    }

    /**
     * Returns the subset of the addresses to classify consisting of all addresses
     * which are aligned relative to the given modulus.
     * @param modulus alignment modulus, must be positive
     * @return aligned addresses
     * @throws IllegalArgumentException if {@code modulus} is not positive
     */
    public AddressSet getAddressesToClassify(long modulus) {
        if (modulus <= 0) {
            throw new IllegalArgumentException("modulus must be positive: " + modulus);
        }
        AddressSet aligned = new AddressSet();
        for (Address a : execNonFunc.getAddresses(true)) {
            // offsets are unsigned values carried in a signed long, so use the unsigned
            // remainder to stay correct for offsets with the high bit set
            if (Long.remainderUnsigned(a.getOffset(), modulus) == 0) {
                aligned.add(a);
            }
        }
        return aligned;
    }
}

View file

@ -0,0 +1,99 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import ghidra.program.model.address.Address;
import ghidra.program.model.block.BasicBlockModel;
import ghidra.program.model.listing.*;
import ghidra.util.exception.CancelledException;
import ghidra.util.task.TaskMonitor;
/**
* An enum representing possible interpretations of addresses
* (e.g. data, undefined, function start ...)
*/
public enum Interpretation {
UNDEFINED("Undefined"),
DATA("Data"),
OFFCUT("Offcut"),
BLOCK_START("Block Start"),
WITHIN_BLOCK("Within Block"),
FUNCTION_START("Function Start"),
//possibly want to refine this to block starts within functions
//and within block within functions
FUNCTION_INTERIOR("Function Interior");
private String display;
Interpretation(String display) {
this.display = display;
}
@Override
public String toString() {
return display;
}
/**
* Get the {@link Interpretation} for the given address in the given program.
* @param program source program
* @param addr address
* @param monitor monitor
* @return interpretation of addr
* @throws CancelledException if user cancels monitor
*/
public static Interpretation getInterpretation(Program program, Address addr,
TaskMonitor monitor) throws CancelledException {
BasicBlockModel model = new BasicBlockModel(program);
return getInterpretation(program, addr, model, monitor);
}
/**
* Get the {@link Interpretation} for the given address in the given program. This
* method is intended to be called repeatedly in a loop, so it takes a {@link BasicBlockModel}
* as a parameter (which only need be created once).
* @param program source program
* @param addr address in question
* @param model block model
* @param monitor task model
* @return interpretation of addr
* @throws CancelledException if user cancels monitor
*/
public static Interpretation getInterpretation(Program program, Address addr,
BasicBlockModel model, TaskMonitor monitor) throws CancelledException {
CodeUnit cu = program.getListing().getCodeUnitContaining(addr);
if (cu instanceof Data) {
if (((Data) cu).isDefined()) {
return DATA;
}
return UNDEFINED;
}
if (program.getListing().getInstructionAt(addr) == null &&
program.getListing().getInstructionContaining(addr) != null) {
return OFFCUT;
}
if (program.getFunctionManager().getFunctionAt(addr) != null) {
return FUNCTION_START;
}
if (program.getFunctionManager().getFunctionContaining(addr) != null) {
return FUNCTION_INTERIOR;
}
if (model.getCodeBlockAt(addr, monitor) == null) {
return WITHIN_BLOCK;
}
return Interpretation.BLOCK_START;
}
}

View file

@ -0,0 +1,227 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import java.util.*;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.classification.Label;
import org.tribuo.impl.ArrayExample;
import ghidra.program.model.address.*;
import ghidra.program.model.listing.*;
import ghidra.program.model.mem.MemoryAccessException;
import ghidra.program.model.mem.MemoryBlock;
import ghidra.util.Msg;
import ghidra.util.exception.CancelledException;
import ghidra.util.task.TaskMonitor;
/**
* This is a utility class containing static methods used when creating training/test
* sets to train models to recognize function starts
*/
public class ModelTrainingUtils {
public static final int MAX_PRECEDING_CODE_UNIT_SIZE = 8;
public static final double ZERO = 0.0d;
public static final double ONE = 1.0d;
//utility class
private ModelTrainingUtils() {
}
/**
* Creates a feature vector consisting of byte-level and optionally bit-level features around
* {@code address}
* @param program source program
* @param address address
* @param numPreBytes number of bytes to use preceding {@code address}
* @param numInitialBytes number of bytes to use including and after address
* @param includeBitFeatures whether to include bit-level features
* @return feature vector
*/
public static List<Feature> getFeatureVector(Program program, Address address, int numPreBytes,
int numInitialBytes, boolean includeBitFeatures) {
MemoryBlock block = program.getMemory().getBlock(address);
byte[] preBytesArray = new byte[numPreBytes];
byte[] initialBytesArray = new byte[numInitialBytes];
List<Feature> trainingVector = new ArrayList<>();
try {
Address preStart = address.add(-numPreBytes);
block.getBytes(preStart, preBytesArray);
block.getBytes(address, initialBytesArray);
}
catch (MemoryAccessException | AddressOutOfBoundsException e) {
//most likely an exception means that you are trying to read beyond a block
//boundary. This will happen occasionally when the sliding window is near the
//begining or end of a block.
Msg.warn(RandomForestTrainingTask.class,
"MemoryAccessException at " + address.toString());
return trainingVector;
}
for (int i = 0; i < numPreBytes; i++) {
int currentByte = Byte.toUnsignedInt(preBytesArray[i]);
trainingVector.add(new Feature("pbyte_" + i, currentByte));
if (!includeBitFeatures) {
continue;
}
for (int bit = 7; bit >= 0; bit--) {
String featureName = "pbit_" + i + "_" + bit;
double val = ((currentByte & (1 << bit)) > 0) ? ONE : ZERO;
trainingVector.add(new Feature(featureName, val));
}
}
for (int i = 0; i < numInitialBytes; i++) {
int currentByte = Byte.toUnsignedInt(initialBytesArray[i]);
trainingVector.add(new Feature("ibyte_" + i, currentByte));
if (!includeBitFeatures) {
continue;
}
for (int bit = 7; bit >= 0; bit--) {
String featureName = "ibit_" + i + "_" + bit;
double val = ((currentByte & (1 << bit)) > 0) ? ONE : ZERO;
trainingVector.add(new Feature(featureName, val));
}
}
return trainingVector;
}
/**
* Returns an {@link AddressSet} constructed as follows: for each {@link Address} {@code addr}
* in {@code addresses}, add the {@link Address} of the {@link CodeUnit} returned by
* {@link Listing#getCodeUnitAfter(Address)}
* <p> Addresses which correspond to function starts are not added to the returned set.
* @param program source program
* @param addresses addresses to follow
* @param monitor monitor
* @return following addresses
* @throws CancelledException if the monitor is canceled
*/
public static AddressSet getFollowingAddresses(Program program, AddressSetView addresses,
TaskMonitor monitor) throws CancelledException {
AddressSet following = new AddressSet();
for (Address addr : addresses.getAddresses(true)) {
monitor.checkCanceled();
CodeUnit cu = program.getListing().getCodeUnitAfter(addr);
if (cu == null) {
continue;
}
if (program.getFunctionManager().getFunctionAt(cu.getAddress()) != null) {
Msg.warn(ModelTrainingUtils.class,
"Function start following " + addr.toString() + ", skipping...");
continue;
}
following.add(cu.getAddress());
}
return following;
}
/**
* Returns an {@link AddressSet} constructed as follows: for each {@link Address} {@code addr}
* in {@code addresses}, add the {@link Address} of the {@link CodeUnit} returned by
* {@link Listing#getCodeUnitBefore(Address)}
* <p> Addresses which correspond to function starts are not added to the returned set.
* Addresses of {@link CodeUnit}s which are more than
* {@link ModelTrainingUtils#MAX_PRECEDING_CODE_UNIT_SIZE} bytes away from the addresses they
* precede are also not added to the returned set.
* @param program source program
* @param addresses addresses to precede
* @param monitor monitor
* @return preceding addresses
* @throws CancelledException if the monitor is canceled
*/
public static AddressSet getPrecedingAddresses(Program program, AddressSetView addresses,
TaskMonitor monitor) throws CancelledException {
AddressSet preceding = new AddressSet();
for (Address addr : addresses.getAddresses(true)) {
monitor.checkCanceled();
CodeUnit cu = program.getListing().getCodeUnitBefore(addr);
if (cu == null) {
continue;
}
if (program.getFunctionManager().getFunctionAt(cu.getAddress()) != null) {
Msg.warn(ModelTrainingUtils.class,
"Function start preceding " + addr.toString() + ", skipping...");
continue;
}
if (addr.getOffset() - cu.getAddress().getOffset() > MAX_PRECEDING_CODE_UNIT_SIZE) {
continue;
}
preceding.add(cu.getAddress());
}
return preceding;
}
/**
 * Collects the addresses covered by defined data within the execute set of the memory of
 * {@code program} (i.e., {@code program.getMemory().getExecuteSet()}). Every address
 * occupied by a defined data item is included, not only the address where the item starts.
 * @param program source program
 * @param monitor task monitor, checked for cancellation once per data item
 * @return addresses covered by defined data in the execute set
 * @throws CancelledException if monitor is canceled
 */
public static AddressSet getDefinedData(Program program, TaskMonitor monitor)
        throws CancelledException {
    AddressSet covered = new AddressSet();
    DataIterator definedItems =
        program.getListing().getDefinedData(program.getMemory().getExecuteSet(), true);
    while (definedItems.hasNext()) {
        Data item = definedItems.next();
        monitor.checkCanceled();
        //the range spans from the first to the last byte of the data item
        Address first = item.getAddress();
        Address last = first.add(item.getLength() - 1);
        covered.add(new AddressRangeImpl(first, last));
    }
    return covered;
}
/**
 * Creates a feature vector for each address in {@code source} and pairs it with the
 * {@link Label} {@code label}. Addresses for which the generated feature vector is empty
 * are skipped, so the returned list can be shorter than {@code source}.
 * <p> The monitor is initialized to the number of addresses in {@code source} and advanced
 * by one for every address visited.
 *
 * @param program program
 * @param source input addresses
 * @param label label to apply
 * @param numPreBytes bytes before address to include
 * @param numInitialBytes bytes after and including addresses
 * @param includeBitFeatures whether to include bit-level features
 * @param monitor monitor
 * @return list of labeled vectors
 * @throws CancelledException if monitor is canceled
 */
public static List<Example<Label>> getVectorsFromAddresses(Program program,
        AddressSetView source, Label label, int numPreBytes, int numInitialBytes,
        boolean includeBitFeatures, TaskMonitor monitor) throws CancelledException {
    monitor.initialize(source.getNumAddresses());
    List<Example<Label>> labeled = new ArrayList<>();
    for (Address current : source.getAddresses(true)) {
        monitor.checkCanceled();
        monitor.incrementProgress(1L);
        List<Feature> features =
            getFeatureVector(program, current, numPreBytes, numInitialBytes, includeBitFeatures);
        if (!features.isEmpty()) {
            Example<Label> example = new ArrayExample<>(label, features);
            labeled.add(example);
        }
    }
    return labeled;
}
}

View file

@ -0,0 +1,63 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import ghidra.framework.plugintool.ComponentProviderAdapter;
import ghidra.program.model.listing.Program;
/**
 * Base class for the component providers created by {@link RandomForestFunctionFinderPlugin}
 * to display potential function starts in a given program (and related info).
 * Each provider remembers the {@link Program} it belongs to, which lets the plugin close
 * all components associated with a program when that program is closed.
 */
public abstract class ProgramAssociatedComponentProviderAdapter extends ComponentProviderAdapter {
    private final RandomForestFunctionFinderPlugin plugin;
    private final Program program;

    /**
     * Creates a {@link ComponentProviderAdapter} with an associated {@link Program} and
     * {@link RandomForestFunctionFinderPlugin}. The tool is taken from {@code plugin}.
     * @param name name
     * @param owner owner
     * @param program associated program
     * @param plugin plugin
     */
    public ProgramAssociatedComponentProviderAdapter(String name, String owner, Program program,
            RandomForestFunctionFinderPlugin plugin) {
        super(plugin.getTool(), name, owner);
        this.plugin = plugin;
        this.program = program;
        setTransient();
        setWindowMenuGroup("Search for Code and Functions");
    }

    /**
     * Returns the program this provider is associated with
     * @return the program
     */
    Program getProgram() {
        return program;
    }

    /**
     * Tells the plugin to stop tracking this provider, then closes the component.
     */
    @Override
    public void closeComponent() {
        plugin.removeProvider(this);
        super.closeComponent();
    }
}

View file

@ -0,0 +1,235 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import java.util.*;
import org.tribuo.classification.Label;
import aQute.bnd.unmodifiable.Lists;
import docking.action.builder.ActionBuilder;
import docking.tool.ToolConstants;
import ghidra.MiscellaneousPluginPackage;
import ghidra.app.context.NavigatableActionContext;
import ghidra.app.context.RestrictedAddressSetContext;
import ghidra.app.events.ProgramClosedPluginEvent;
import ghidra.app.events.ProgramLocationPluginEvent;
import ghidra.app.plugin.PluginCategoryNames;
import ghidra.app.plugin.ProgramPlugin;
import ghidra.app.services.GoToService;
import ghidra.framework.options.OptionsChangeListener;
import ghidra.framework.options.ToolOptions;
import ghidra.framework.plugintool.PluginInfo;
import ghidra.framework.plugintool.PluginTool;
import ghidra.framework.plugintool.util.PluginStatus;
import ghidra.program.model.listing.Program;
import ghidra.program.util.ProgramSelection;
import ghidra.util.HelpLocation;
import ghidra.util.Msg;
import ghidra.util.bean.opteditor.OptionsVetoException;
/**
 * A {@link ProgramPlugin} for training a model on the starts of known functions in a
 * program and then using that model to look for more functions (in the source program or
 * another program selected by the user).
 */
//@formatter:off
@PluginInfo(
    status = PluginStatus.UNSTABLE,
    packageName = MiscellaneousPluginPackage.NAME,
    category = PluginCategoryNames.ANALYSIS,
    shortDescription = "Function Finder",
    description = "Trains a random forest model to find function starts.",
    servicesRequired = { GoToService.class},
    eventsProduced = { ProgramLocationPluginEvent.class },
    eventsConsumed = { ProgramClosedPluginEvent.class}
)
//@formatter:on
public class RandomForestFunctionFinderPlugin extends ProgramPlugin
        implements OptionsChangeListener {
    public static final Label FUNC_START = new Label("S");
    public static final Label NON_START = new Label("N");
    private static final String ACTION_NAME = "Search for Code and Functions";
    private static final String MENU_PATH_ENTRY = "For Code and Functions...";
    private static final String TEST_SET_MAX_SIZE_OPTION_NAME = "Maximum size of test sets";
    static final Long TEST_SET_MAX_SIZE_DEFAULT = 1_000_000L;
    private Long testSetMax;
    private static final String MIN_UNDEFINED_RANGE_SIZE_OPTION_NAME =
        "Minimum Length of Undefined Range to Search";
    static final Long MIN_UNDEFINED_RANGE_SIZE_DEFAULT = 16L;
    private Long minUndefinedRangeSize;
    private FunctionStartRFParamsDialog paramsDialog;
    //this map is used to close all providers associated with a program p
    //when p is closed
    private final Map<Program, List<ProgramAssociatedComponentProviderAdapter>> programsToProviders;

    /**
     * Creates the plugin for the given tool.
     * @param tool tool for plugin
     */
    public RandomForestFunctionFinderPlugin(PluginTool tool) {
        super(tool, false, true);
        programsToProviders = new HashMap<>();
    }

    /**
     * Installs the search action and registers/reads the plugin options.
     */
    @Override
    public void init() {
        createActions();
        initOptions(getTool().getOptions("Random Forest Function Finder"));
    }

    /**
     * Validates and stores a changed option value. Both options must be positive; a
     * non-positive value is vetoed.
     * @param options options that changed
     * @param optionName name of the changed option
     * @param oldValue previous value
     * @param newValue proposed value
     * @throws OptionsVetoException if the proposed value is not positive
     */
    @Override
    public void optionsChanged(ToolOptions options, String optionName, Object oldValue,
            Object newValue) throws OptionsVetoException {
        switch (optionName) {
            case TEST_SET_MAX_SIZE_OPTION_NAME:
                Long newMax = (Long) newValue;
                if (newMax <= 0) {
                    //does this actually inform the user of the problem?
                    throw new OptionsVetoException(
                        TEST_SET_MAX_SIZE_OPTION_NAME + " must be positive!");
                }
                testSetMax = newMax;
                break;
            case MIN_UNDEFINED_RANGE_SIZE_OPTION_NAME:
                Long newMin = (Long) newValue;
                if (newMin <= 0) {
                    throw new OptionsVetoException(
                        MIN_UNDEFINED_RANGE_SIZE_OPTION_NAME + " must be positive!");
                }
                minUndefinedRangeSize = newMin;
                break;
            default:
                Msg.showError(this, null, "Unknown option", "Unknown option: " + optionName);
        }
    }

    /**
     * Record the existence of a {@link ProgramAssociatedComponentProviderAdapter} so that it can
     * be closed if its associated program is closed, and add it to the tool.
     * @param provider provider
     */
    void addProvider(ProgramAssociatedComponentProviderAdapter provider) {
        List<ProgramAssociatedComponentProviderAdapter> providers =
            programsToProviders.computeIfAbsent(provider.getProgram(), p -> new ArrayList<>());
        providers.add(provider);
        tool.addComponentProvider(provider, true);
    }

    /**
     * Remove the provider from the list of tracked providers. Does nothing if no providers
     * are tracked for the provider's program (e.g., the entry for that program has already
     * been removed).
     * @param provider provider
     */
    void removeProvider(ProgramAssociatedComponentProviderAdapter provider) {
        List<ProgramAssociatedComponentProviderAdapter> providers =
            programsToProviders.get(provider.getProgram());
        if (providers != null) {
            providers.remove(provider);
        }
    }

    /**
     * Sets the current selection. Cf. {@link FunctionStartRFParamsDialog#addGeneralActions}
     * @param selection new selection
     */
    void setSelection(ProgramSelection selection) {
        currentSelection = selection;
    }

    /**
     * Returns the maximum size of a test set. Users can set this value via a plugin option.
     * @return max size
     */
    Long getTestMaxSize() {
        return testSetMax;
    }

    /**
     * Returns the minimum size of an undefined range to search for function starts. Users
     * can set this value via a plugin option.
     * @return min undefined range size
     */
    Long getMinUndefinedRangeSize() {
        return minUndefinedRangeSize;
    }

    /**
     * Null out the dialog
     */
    void resetDialog() {
        paramsDialog = null;
    }

    /**
     * Closes every provider tracked for {@code p}, forgets the program, and dismisses the
     * parameters dialog if its training source is {@code p}.
     * @param p program being closed
     */
    @Override
    protected void programClosed(Program p) {
        //ProgramAssociatedComponentProviderAdapter.closeComponent modifies values of
        //programsToProviders, so make a copy to avoid a ConcurrentModificationException
        List<ProgramAssociatedComponentProviderAdapter> providersToClose =
            List.copyOf(programsToProviders.getOrDefault(p, Collections.emptyList()));
        for (ProgramAssociatedComponentProviderAdapter provider : providersToClose) {
            provider.closeComponent();
        }
        programsToProviders.remove(p);
        if (paramsDialog == null) {
            return;
        }
        if (!paramsDialog.getTrainingSource().equals(p)) {
            return;
        }
        paramsDialog.dismissCallback();
    }

    /**
     * Builds and installs the menu action that opens the parameters dialog.
     */
    private void createActions() {
        new ActionBuilder(ACTION_NAME, getName())
                .menuPath(ToolConstants.MENU_SEARCH, MENU_PATH_ENTRY)
                .menuGroup("search for", null)
                .description("Train models to search for function starts")
                .helpLocation(new HelpLocation(getName(), getName()))
                .withContext(NavigatableActionContext.class)
                .validContextWhen(c -> !(c instanceof RestrictedAddressSetContext))
                .supportsDefaultToolContext(true)
                .onAction(c -> {
                    displayDialog(c);
                })
                .buildAndInstall(tool);
    }

    /**
     * Shows the parameters dialog, creating it if necessary. An existing dialog whose
     * training source is not the current program is dismissed and replaced by a new one.
     * @param c context of the action invocation
     */
    private void displayDialog(NavigatableActionContext c) {
        if (paramsDialog == null) {
            paramsDialog = new FunctionStartRFParamsDialog(this);
        }
        if (!paramsDialog.getTrainingSource().equals(this.getCurrentProgram())) {
            paramsDialog.dismissCallback();
            paramsDialog = new FunctionStartRFParamsDialog(this);
        }
        tool.showDialog(paramsDialog, c.getComponentProvider());
    }

    /**
     * Registers the two plugin options with their defaults, reads their current values, and
     * starts listening for changes.
     * @param options tool options for this plugin
     */
    private void initOptions(ToolOptions options) {
        options.registerOption(TEST_SET_MAX_SIZE_OPTION_NAME, TEST_SET_MAX_SIZE_DEFAULT,
            new HelpLocation(getName(), "MaxTestSetSize"),
            "Maximum sizes for test sets (must be positive).");
        testSetMax = options.getLong(TEST_SET_MAX_SIZE_OPTION_NAME, TEST_SET_MAX_SIZE_DEFAULT);
        options.registerOption(MIN_UNDEFINED_RANGE_SIZE_OPTION_NAME,
            MIN_UNDEFINED_RANGE_SIZE_DEFAULT,
            new HelpLocation(getName(), "MinLengthUndefinedRange"),
            "Minimum Size of an Undefined AddressRange to search (must be positive).");
        minUndefinedRangeSize =
            options.getLong(MIN_UNDEFINED_RANGE_SIZE_OPTION_NAME, MIN_UNDEFINED_RANGE_SIZE_DEFAULT);
        options.addOptionsChangeListener(this);
    }
}

View file

@ -0,0 +1,210 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import java.math.*;
import java.util.Collections;
import java.util.List;
import org.tribuo.classification.Label;
import org.tribuo.ensemble.EnsembleModel;
import ghidra.program.model.address.AddressSet;
/**
 * Row object for a table in which each row describes one model trained to find function
 * starts. Part of the state (precision, recall, false positives) reports how the model
 * performed on its test set. The rest ({@code numPreBytes}, {@code numInitialBytes},
 * {@code includeBitLevelFeatures}, context registers) is kept because actions defined on
 * the table that apply the model must know how to generate the feature vectors the model
 * consumes.
 */
public class RandomForestRowObject {
    private final int numPreBytes;
    private final int numInitialBytes;
    private final int samplingFactor;
    private final boolean includeBitLevelFeatures;
    private final int[] confusionMatrix;
    private final EnsembleModel<Label> randomForest;
    private final AddressSet testErrors;
    private final AddressSet trainingPositive;
    //null when the corresponding denominator is zero
    private final BigDecimal precision;
    private final BigDecimal recall;
    private List<String> contextRegisters;
    private List<BigInteger> contextRegisterValues;

    /**
     * Constructs a row. Precision (TP / (TP + FP)) and recall (TP / (TP + FN)) are derived
     * from {@code confusionMatrix}, rounded to two decimal places; each is {@code null} when
     * its denominator is zero. The context register lists start out empty.
     * @param numPreBytes number of prebytes in vectors consumed by model
     * @param numInitialBytes number of initialBytes in vectors consumed by model
     * @param samplingFactor non-start to start sampling factor
     * @param confusionMatrix confusion matrix of model on test set
     * @param randomForest model
     * @param testErrors set of addresses in test set with errors
     * @param trainingPositive set of positive training examples (i.e. function starts)
     * @param includeBitLevelFeatures whether bit-level features were included in model
     */
    public RandomForestRowObject(int numPreBytes, int numInitialBytes, int samplingFactor,
            int[] confusionMatrix, EnsembleModel<Label> randomForest, AddressSet testErrors,
            AddressSet trainingPositive, boolean includeBitLevelFeatures) {
        this.numPreBytes = numPreBytes;
        this.numInitialBytes = numInitialBytes;
        this.samplingFactor = samplingFactor;
        this.includeBitLevelFeatures = includeBitLevelFeatures;
        this.randomForest = randomForest;
        this.testErrors = testErrors;
        this.trainingPositive = trainingPositive;
        this.confusionMatrix = confusionMatrix;
        this.contextRegisters = Collections.emptyList();
        this.contextRegisterValues = Collections.emptyList();

        int truePositives = confusionMatrix[RandomForestTrainingTask.TP];
        this.precision = twoPlaceRatio(truePositives,
            confusionMatrix[RandomForestTrainingTask.TP] +
                confusionMatrix[RandomForestTrainingTask.FP]);
        this.recall = twoPlaceRatio(truePositives,
            confusionMatrix[RandomForestTrainingTask.TP] +
                confusionMatrix[RandomForestTrainingTask.FN]);
    }

    /**
     * Divides {@code numerator} by {@code denominator}, rounding half-even to two decimal
     * places.
     * @param numerator numerator
     * @param denominator denominator
     * @return the quotient, or {@code null} if {@code denominator} is zero
     */
    private static BigDecimal twoPlaceRatio(int numerator, int denominator) {
        if (denominator == 0) {
            return null;
        }
        return new BigDecimal(numerator).divide(new BigDecimal(denominator), 2,
            RoundingMode.HALF_EVEN);
    }

    /**
     * Sets the context registers (and their values) the model is aware of. Copies of the
     * supplied lists are stored.
     * @param regList register names
     * @param valueList register values
     * @throws IllegalArgumentException if the two lists differ in size
     */
    public void setContextRegistersAndValues(List<String> regList, List<BigInteger> valueList) {
        if (regList.size() != valueList.size()) {
            throw new IllegalArgumentException(
                "Register list and value list must have the same size!");
        }
        contextRegisters = List.copyOf(regList);
        contextRegisterValues = List.copyOf(valueList);
    }

    /**
     * Returns a boolean indicating whether the model is aware of any context registers
     * @return {@code true} if at least one context register has been set
     */
    public boolean isContextRestricted() {
        return !contextRegisters.isEmpty();
    }

    /**
     * Returns the names of the context registers the model is aware of
     * @return context reg names
     */
    public List<String> getContextRegisterList() {
        return contextRegisters;
    }

    /**
     * Returns the list of values of context registers the model is aware of
     * @return context reg values
     */
    public List<BigInteger> getContextRegisterValues() {
        return contextRegisterValues;
    }

    /**
     * Returns the precision of the model on the test set
     * @return precision, or {@code null} if it could not be computed
     */
    public BigDecimal getPrecision() {
        return precision;
    }

    /**
     * Returns the recall of the model on the test set
     * @return recall, or {@code null} if it could not be computed
     */
    public BigDecimal getRecall() {
        return recall;
    }

    /**
     * Returns a boolean indicating whether bit-level features were included when training
     * the model
     * @return bit-level features used
     */
    public boolean getIncludeBitLevelFeatures() {
        return includeBitLevelFeatures;
    }

    /**
     * Returns the number of pre-bytes used when training the model
     * @return pre-bytes
     */
    public int getNumPreBytes() {
        return numPreBytes;
    }

    /**
     * Returns the sampling factor used when training the model
     * @return sampling factor
     */
    public int getSamplingFactor() {
        return samplingFactor;
    }

    /**
     * Returns the number of initial bytes used when training the model
     * @return num initial bytes
     */
    public int getNumInitialBytes() {
        return numInitialBytes;
    }

    /**
     * Returns the model
     * @return model
     */
    public EnsembleModel<Label> getRandomForest() {
        return randomForest;
    }

    /**
     * Returns the addresses in the test set where the model made an error
     * @return error set
     */
    public AddressSet getTestErrors() {
        return testErrors;
    }

    /**
     * Returns the number of false positives the model produces when classifying the test set.
     * @return num false positives
     */
    public int getNumFalsePositives() {
        return confusionMatrix[RandomForestTrainingTask.FP];
    }

    /**
     * Returns the set of function starts in the training set.
     * @return known starts
     */
    public AddressSet getTrainingPositives() {
        return trainingPositive;
    }
}

View file

@ -0,0 +1,162 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import java.math.BigDecimal;
import java.util.List;
import docking.widgets.table.AbstractDynamicTableColumn;
import docking.widgets.table.TableColumnDescriptor;
import docking.widgets.table.threaded.ThreadedTableModelStub;
import ghidra.docking.settings.Settings;
import ghidra.framework.plugintool.ServiceProvider;
import ghidra.util.datastruct.Accumulator;
import ghidra.util.exception.CancelledException;
import ghidra.util.task.TaskMonitor;
/**
 * Table model backing tables that show information about random forests trained to find
 * function starts. Each row is a {@link RandomForestRowObject}.
 */
public class RandomForestTableModel extends ThreadedTableModelStub<RandomForestRowObject> {
    private static final String MODEL_NAME = "Random Forest Evaluations";
    private final List<RandomForestRowObject> rowObjects;

    /**
     * Creates a table model
     * @param serviceProvider service provider
     * @param rowObjects rows of table
     */
    public RandomForestTableModel(ServiceProvider serviceProvider,
            List<RandomForestRowObject> rowObjects) {
        super(MODEL_NAME, serviceProvider);
        this.rowObjects = rowObjects;
    }

    /**
     * Loads the table by handing all row objects to the accumulator.
     */
    @Override
    protected void doLoad(Accumulator<RandomForestRowObject> accumulator, TaskMonitor monitor)
            throws CancelledException {
        accumulator.addAll(rowObjects);
    }

    /**
     * Declares the visible columns: pre-bytes, initial bytes, sampling factor, false
     * positives, precision and recall.
     */
    @Override
    protected TableColumnDescriptor<RandomForestRowObject> createTableColumnDescriptor() {
        TableColumnDescriptor<RandomForestRowObject> columns = new TableColumnDescriptor<>();
        columns.addVisibleColumn(new NumPreBytesColumn());
        columns.addVisibleColumn(new NumInitialBytesColumn());
        columns.addVisibleColumn(new SamplingFactorColumn());
        columns.addVisibleColumn(new FalsePositivesColumn(), 1, true);
        columns.addVisibleColumn(new PrecisionTableColumn());
        columns.addVisibleColumn(new RecallTableColumn(), 2, false);
        return columns;
    }

    /**
     * Column showing the number of pre-bytes of the row's model
     */
    class NumPreBytesColumn
            extends AbstractDynamicTableColumn<RandomForestRowObject, Integer, Object> {
        @Override
        public String getColumnName() {
            return "Pre-Bytes";
        }

        @Override
        public Integer getValue(RandomForestRowObject row, Settings settings, Object data,
                ServiceProvider provider) {
            return row.getNumPreBytes();
        }
    }

    /**
     * Column showing the number of initial bytes of the row's model
     */
    class NumInitialBytesColumn
            extends AbstractDynamicTableColumn<RandomForestRowObject, Integer, Object> {
        @Override
        public String getColumnName() {
            return "Initial Bytes";
        }

        @Override
        public Integer getValue(RandomForestRowObject row, Settings settings, Object data,
                ServiceProvider provider) {
            return row.getNumInitialBytes();
        }
    }

    /**
     * Column showing the sampling factor of the row's model
     */
    class SamplingFactorColumn
            extends AbstractDynamicTableColumn<RandomForestRowObject, Integer, Object> {
        @Override
        public String getColumnName() {
            return "Factor";
        }

        @Override
        public Integer getValue(RandomForestRowObject row, Settings settings, Object data,
                ServiceProvider provider) {
            return row.getSamplingFactor();
        }
    }

    /**
     * Column showing the precision of the row's model on its test set
     */
    class PrecisionTableColumn
            extends AbstractDynamicTableColumn<RandomForestRowObject, BigDecimal, Object> {
        @Override
        public String getColumnName() {
            return "Precision";
        }

        @Override
        public BigDecimal getValue(RandomForestRowObject row, Settings settings, Object data,
                ServiceProvider provider) {
            return row.getPrecision();
        }
    }

    /**
     * Column showing the recall of the row's model on its test set
     */
    class RecallTableColumn
            extends AbstractDynamicTableColumn<RandomForestRowObject, BigDecimal, Object> {
        @Override
        public String getColumnName() {
            return "Recall";
        }

        @Override
        public BigDecimal getValue(RandomForestRowObject row, Settings settings, Object data,
                ServiceProvider provider) {
            return row.getRecall();
        }
    }

    /**
     * Column showing the number of false positives of the row's model on its test set
     */
    class FalsePositivesColumn
            extends AbstractDynamicTableColumn<RandomForestRowObject, Integer, Object> {
        @Override
        public String getColumnName() {
            return "False Positives";
        }

        @Override
        public Integer getValue(RandomForestRowObject row, Settings settings, Object data,
                ServiceProvider provider) {
            return row.getNumFalsePositives();
        }
    }
}

View file

@ -0,0 +1,485 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Consumer;
import org.tribuo.*;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.dtree.CARTClassificationTrainer;
import org.tribuo.classification.ensemble.VotingCombiner;
import org.tribuo.common.tree.TreeModel;
import org.tribuo.dataset.DatasetView;
import org.tribuo.datasource.ListDataSource;
import org.tribuo.ensemble.EnsembleModel;
import org.tribuo.ensemble.WeightedEnsembleModel;
import org.tribuo.provenance.SimpleDataSourceProvenance;
import generic.concurrent.*;
import ghidra.program.model.address.*;
import ghidra.program.model.listing.Program;
import ghidra.program.util.ProgramSelection;
import ghidra.util.Msg;
import ghidra.util.exception.CancelledException;
import ghidra.util.task.Task;
import ghidra.util.task.TaskMonitor;
/**
* This {@link Task} is used to train and evaluate random forests.
*/
public class RandomForestTrainingTask extends Task {
//NUM_TREES should be odd to avoid nuances involving tiebreaking
public static final int NUM_TREES = 99;
//prefix of the task title; the constructor appends the program name to it
public static final String TITLE = "Training Model for ";
//NOTE(review): not referenced in the visible part of this class; presumably a cap on the
//size of evaluation sets - confirm against its users
public static final int MAX_EVAL_SET_SIZE = 50000;
//NOTE(review): presumably the name of the thread pool used while training - confirm
public static final String RANDOM_FOREST_TRAINING_THREADPOOL = "RandomForestTrainer";
//divisor used to convert System.nanoTime() differences to seconds in log messages
public static final double NANOSECONDS_PER_SECOND = 1000000000.0d;
//for the confusion matrix
//the constants below are indices into the int[] confusion matrix: true positives,
//false positives, true negatives and false negatives
public static final int TP = 0;
public static final int FP = 1;
public static final int TN = 2;
public static final int FN = 3;
//one slot per index above
public static final int CONFUSION_MATRIX_SIZE = 4;
//parameters controlling training (sampling factors, pre-bytes, initial bytes, ...)
private FunctionStartRFParams params;
//program the training data is taken from
private Program program;
//trainingSet is accessed from different threads during training
private Dataset<Label> trainingSet;
//receives a RandomForestRowObject for each model that is trained and evaluated
private Consumer<RandomForestRowObject> rowObjectConsumer;
//user-specified training addresses collected by setAdditional: function starts ...
private AddressSet additionalStarts;
//... and addresses that are not function starts
private AddressSet additionalNonStarts;
//maximum size of test sets; passed to TrainingAndTestData.reduceTestSetSize in run
private Long testSetMax;
/**
 * Creates a {@link Task} that trains (and evaluates) random forests for identifying
 * function starts. The task title is {@link #TITLE} followed by the name of
 * {@code program}. The sets of user-specified additional training addresses start out
 * empty.
 *
 * @param program source of training
 * @param params parameters controlling training
 * @param rowObjectConsumer consumes data about the trained models
 * @param testSetMax maximum size of test sets
 */
public RandomForestTrainingTask(Program program, FunctionStartRFParams params,
        Consumer<RandomForestRowObject> rowObjectConsumer, long testSetMax) {
    super(TITLE + program.getName(), true, true, false, false);
    this.params = params;
    this.program = program;
    this.testSetMax = testSetMax;
    this.rowObjectConsumer = rowObjectConsumer;
    this.additionalNonStarts = new AddressSet();
    this.additionalStarts = new AddressSet();
}
/**
 * Adds the addresses of a {@link ProgramSelection} to the training set. Addresses in the
 * selection where a function starts are recorded as positive examples; all other accepted
 * addresses are recorded as negative examples.
 *
 * <p>
 * Addresses which are not aligned to the instruction alignment of the program's language,
 * or which do not agree with the context register values in the {@code params} variable
 * of the constructor, are ignored.
 * @param selection selection to add; {@code null} is treated as an empty selection
 * @return number of aligned addresses conflicting with the context register data specified
 * in {@code params}
 */
public int setAdditional(ProgramSelection selection) {
    if (selection == null) {
        return 0;
    }
    int alignment = program.getLanguage().getInstructionAlignment();
    int conflicts = 0;
    for (Address candidate : selection.getAddresses(true)) {
        boolean aligned = candidate.getOffset() % alignment == 0;
        if (!aligned) {
            continue;
        }
        //only aligned addresses count as conflicts
        if (params.isRestrictedByContext() && !params.isContextCompatible(candidate)) {
            conflicts++;
            continue;
        }
        boolean isFunctionStart =
            program.getFunctionManager().getFunctionAt(candidate) != null;
        AddressSet destination = isFunctionStart ? additionalStarts : additionalNonStarts;
        destination.add(candidate);
    }
    return conflicts;
}
/**
 * Trains and evaluates one random forest for every combination of sampling factor,
 * number of pre-bytes and number of initial bytes in {@code params}. For each combination
 * that yields a non-empty training set, a {@link RandomForestRowObject} describing the
 * model is handed to the row object consumer, unless the monitor has been cancelled.
 * Sampling factors for which no training/test data could be created are skipped.
 * @param monitor task monitor
 * @throws CancelledException if the monitor is canceled
 */
@Override
public void run(TaskMonitor monitor) throws CancelledException {
    monitor.setIndeterminate(true);
    monitor.setMessage("Gathering function entries and interiors");
    //collect the function entries and interiors which satisfy the
    //requirements in params
    params.computeFuncEntriesAndInteriors(monitor);
    AddressSet entries = params.getFuncEntries();
    AddressSet interiors = params.getFuncInteriors();
    //defined data in executable sections will be added to the test set
    AddressSet executableData = ModelTrainingUtils.getDefinedData(program, monitor);
    for (Integer factor : params.getSamplingFactors()) {
        //split the addresses into training and test sets for this factor
        TrainingAndTestData split =
            getTrainingAndTestData(entries, interiors, executableData, factor, monitor);
        if (split == null) {
            continue;
        }
        split.reduceTestSetSize(testSetMax, monitor);
        monitor.setIndeterminate(false);
        for (Integer preBytes : params.getPreBytes()) {
            for (Integer initialBytes : params.getInitialBytes()) {
                String parameterSummary = String.format(
                    "Data Gathering Parameters: factor: %d preBytes: %s initialBytes: %d",
                    factor, preBytes, initialBytes);
                Msg.info(this, parameterSummary);

                //vectors for the training addresses: positives first, then negatives
                List<Example<Label>> vectors = new ArrayList<>();
                monitor.setMessage("Generating vectors for function entries");
                List<Example<Label>> positives = ModelTrainingUtils.getVectorsFromAddresses(
                    program, split.getTrainingPositive(),
                    RandomForestFunctionFinderPlugin.FUNC_START, preBytes, initialBytes,
                    params.getIncludeBitFeatures(), monitor);
                vectors.addAll(positives);
                monitor.setMessage("Generating vectors for function interiors");
                List<Example<Label>> negatives = ModelTrainingUtils.getVectorsFromAddresses(
                    program, split.getTrainingNegative(),
                    RandomForestFunctionFinderPlugin.NON_START, preBytes, initialBytes,
                    params.getIncludeBitFeatures(), monitor);
                vectors.addAll(negatives);

                //should only happen with very small training sets where MemoryAccessExceptions
                //were thrown during vector generation
                if (vectors.isEmpty()) {
                    String warning = String.format(
                        "No vectors were generated for supplied addresses. preBytes = %d, " +
                            "initialBytes = %d",
                        preBytes, initialBytes);
                    Msg.showWarn(this, null, "Empty Training Set", warning);
                    continue;
                }

                //train, then evaluate against the test set
                EnsembleModel<Label> forest = trainModel(vectors, monitor);
                AddressSet misclassified = new AddressSet();
                int[] confusion = evaluateModel(forest, split.getTestPositive(),
                    split.getTestNegative(), misclassified, preBytes, initialBytes, monitor);
                if (monitor.isCancelled()) {
                    continue;
                }

                //publish the result as a table row
                RandomForestRowObject row = new RandomForestRowObject(preBytes, initialBytes,
                    factor, confusion, forest, misclassified, split.getTrainingPositive(),
                    params.getIncludeBitFeatures());
                row.setContextRegistersAndValues(params.getContextRegisterNames(),
                    params.getContextRegisterVals());
                rowObjectConsumer.accept(row);
            }
        }
    }
}
/**
* Creates the training and test sets
* @param allEntries function entries
* @param allInteriors function interiors
* @param definedData defined data
* @param factor sampling factor
* @param monitor task monitor
* @return training and test sets
* @throws CancelledException if monitor is canceled
*/
TrainingAndTestData getTrainingAndTestData(AddressSet allEntries, AddressSet allInteriors,
AddressSet definedData, int factor, TaskMonitor monitor) throws CancelledException {
//if the user has specified addresses to use in training,
//don't allow those addresses to also be selected at random
AddressSet selectableEntries = allEntries.subtract(additionalStarts);
AddressSet selectableInteriors = allInteriors.subtract(additionalNonStarts);
AddressSet trainingPositive = new AddressSet(additionalStarts);
AddressSet trainingNegative = new AddressSet(additionalNonStarts);
//select function entries at random and add them to the training set
long numEntries =
(int) Math.min(params.getMaxStarts(), selectableEntries.getNumAddresses());
monitor.setIndeterminate(true);
monitor.setMessage("Selecting " + numEntries + " random function entries");
long start = System.nanoTime();
AddressSetView randomFuncEntries =
RandomSubsetUtils.randomSubset(selectableEntries, numEntries, monitor);
long end = System.nanoTime();
Msg.info(this, String.format("factor: %d elapsed selecting random entries: %g", factor,
(end - start) / NANOSECONDS_PER_SECOND));
trainingPositive = trainingPositive.union(randomFuncEntries);
if (trainingPositive.isEmpty()) {
Msg.showError(this, null, "Data Gathering Error", "No functions in training set");
return null;
}
//function entries that weren't selected are used for testing
AddressSet testPositive = selectableEntries.subtract(randomFuncEntries);
if (testPositive.isEmpty()) {
Msg.showWarn(this, null, "Test Set Warning",
"No function entries in test set for models with sampling factor " + factor);
}
//for the randomly-selected function entries, optionally add the immediately
//preceding and following code units to the training and test sets as negative examples
AddressSet immediatelyPrecedingTraining = new AddressSet();
AddressSet immediatelyFollowingTraining = new AddressSet();
AddressSet immediatelyPrecedingTest = new AddressSet();
AddressSet immediatelyFollowingTest = new AddressSet();
if (params.getIncludePrecedingAndFollowing()) {
immediatelyPrecedingTraining =
ModelTrainingUtils.getPrecedingAddresses(program, randomFuncEntries, monitor);
immediatelyFollowingTraining =
ModelTrainingUtils.getFollowingAddresses(program, randomFuncEntries, monitor);
immediatelyPrecedingTest =
ModelTrainingUtils.getPrecedingAddresses(program, testPositive, monitor);
immediatelyFollowingTest =
ModelTrainingUtils.getFollowingAddresses(program, testPositive, monitor);
//immediatelyPreceding and immediately following can intersect in
//certain cases; subtract the overlap
AddressSet before = immediatelyPrecedingTraining.union(immediatelyPrecedingTest);
AddressSet after = immediatelyFollowingTraining.union(immediatelyFollowingTest);
immediatelyPrecedingTraining = immediatelyPrecedingTraining.subtract(after);
immediatelyPrecedingTest = immediatelyPrecedingTest.subtract(after);
immediatelyFollowingTraining = immediatelyFollowingTraining.subtract(before);
immediatelyFollowingTest = immediatelyFollowingTest.subtract(before);
//Since immediatelyFollowingXand immediatelyPrecedingX are explicitly
//added to the training/test sets, remove them from the set of interiors which could be
//randomly selected for inclusion in the training set
selectableInteriors = selectableInteriors
.subtract(immediatelyFollowingTraining.union(immediatelyPrecedingTraining));
selectableInteriors = selectableInteriors
.subtract(immediatelyFollowingTest.union(immediatelyPrecedingTest));
//remove before and after from definedData
definedData = definedData.subtract(before.union(after));
}
trainingNegative = trainingNegative
.union(immediatelyPrecedingTraining.union(immediatelyFollowingTraining));
//now select random function interiors and add them to training set
monitor.setMessage(
"Selecting " + numEntries * factor + " random addresses within function interiors");
start = System.nanoTime();
AddressSetView randomFuncInteriors =
RandomSubsetUtils.randomSubset(selectableInteriors, numEntries * factor, monitor);
end = System.nanoTime();
Msg.info(this, String.format("factor: %d elapsed selecting random interiors: %g seconds",
factor, (end - start) / NANOSECONDS_PER_SECOND));
trainingNegative = trainingNegative.union(randomFuncInteriors);
if (trainingNegative.isEmpty()) {
Msg.showError(this, null, "Data Gathering Error",
"No function interiors in training set");
return null;
}
if (trainingPositive.intersects(trainingNegative)) {
Address first = trainingPositive.findFirstAddressInCommon(trainingNegative);
Msg.showWarn(this, null, "Overlap between Training Positive and Training Negative Sets",
"Example: " + first.toString());
}
AddressSet unusedInteriors = selectableInteriors.subtract(randomFuncInteriors);
AddressSet testNegative = unusedInteriors.union(definedData);
testNegative = testNegative.union(immediatelyPrecedingTest).union(immediatelyFollowingTest);
if (testNegative.isEmpty()) {
Msg.showWarn(this, null, "Test Set Warning",
"No function interiors in test set for models with sampling factor " + factor);
}
if (testPositive.intersects(testNegative)) {
Address first = testPositive.findFirstAddressInCommon(testNegative);
Msg.showWarn(this, null, "Overlapping Test Positive and Negative sets",
"Example: " + first.toString());
}
if ((trainingPositive.union(trainingNegative)
.intersects(testPositive.union(testNegative)))) {
Address first = trainingPositive.union(trainingNegative)
.findFirstAddressInCommon(testPositive.union(testNegative));
Msg.showWarn(this, null, "Overlapping Training and Test Sets",
"Example: " + first.toString());
}
return new TrainingAndTestData(trainingPositive, trainingNegative, testPositive,
testNegative);
}
/**
 * Trains a random forest to recognize function entries. {@code NUM_TREES} CART trainers
 * are created and run in parallel on a shared thread pool; each one trains a single tree
 * and the resulting trees are combined into a voting ensemble.
 * <p>
 * As a side effect, the {@code trainingSet} field is replaced with a dataset built from
 * {@code trainingData}.
 *
 * @param trainingData training vectors
 * @param monitor task monitor
 * @return the trained model, or {@code null} if training failed with an exception that
 * was not caused by cancellation
 * @throws CancelledException if monitor is canceled
 */
EnsembleModel<Label> trainModel(List<Example<Label>> trainingData, TaskMonitor monitor)
		throws CancelledException {
	LabelFactory lf = new LabelFactory();
	ListDataSource<Label> trainingSource = new ListDataSource<>(trainingData, lf,
		new SimpleDataSourceProvenance(program.getDomainFile().getPathname(), lf));
	trainingSet = new MutableDataset<>(trainingSource);

	//want to select from sqrt(num features) features at each split
	float featureFraction = (float) (1.0f / Math.sqrt(trainingSet.getFeatureMap().size()));
	List<CARTClassificationTrainer> trainers = new ArrayList<>();
	for (int i = 0; i < NUM_TREES; ++i) {
		monitor.checkCanceled();
		//Integer.MAX_VALUE: unlimited depth
		//each trainer gets its own random seed
		trainers.add(new CARTClassificationTrainer(Integer.MAX_VALUE, featureFraction,
			ThreadLocalRandom.current().nextLong()));
	}

	GThreadPool threadPool = GThreadPool.getSharedThreadPool(RANDOM_FOREST_TRAINING_THREADPOOL);
	monitor.initialize(NUM_TREES);
	monitor.setMessage("Training random forest");
	ConcurrentQBuilder<CARTClassificationTrainer, TreeModel<Label>> builder =
		new ConcurrentQBuilder<>();
	ConcurrentQ<CARTClassificationTrainer, TreeModel<Label>> q =
		builder.setThreadPool(threadPool)
				.setCollectResults(true)
				.setMonitor(monitor)
				.build(new SingleTreeTrainer());
	q.add(trainers);

	EnsembleModel<Label> randomForest = null;
	try {
		long start = System.nanoTime();
		var results = q.waitForResults();
		long end = System.nanoTime();
		Msg.info(this,
			String.format("Training time: %g seconds", (end - start) / NANOSECONDS_PER_SECOND));
		List<Model<Label>> trees = new ArrayList<>();
		for (var r : results) {
			trees.add(r.getResult());
		}
		randomForest = WeightedEnsembleModel.createEnsembleFromExistingModels("rf", trees,
			new VotingCombiner());
	}
	catch (Exception e) {
		//a cancellation surfaces here as an exception; report it as a CancelledException
		monitor.checkCanceled();
		//pass the throwable along so the stack trace and cause are not lost
		Msg.error(this, "Exception while training model: " + e.getMessage(), e);
	}
	return randomForest;
}
/**
 * Evaluates a model in two passes: first on the addresses in {@code testPositive}
 * (expected to be classified as function starts), then on the addresses in
 * {@code testNegative} (expected to be classified as non-starts). An exception in one
 * pass (other than cancellation) is logged and does not prevent the other pass or the
 * return of the (possibly incomplete) confusion matrix.
 *
 * @param randomForest model to evaluate
 * @param testPositive test set of function entries
 * @param testNegative test set of function interiors
 * @param errors set to place addresses with classifier errors
 * @param preBytes byte-count parameter forwarded to the evaluator callback
 * (presumably the number of bytes before an address used for features)
 * @param initialBytes byte-count parameter forwarded to the evaluator callback
 * (presumably the number of bytes starting at an address used for features)
 * @param monitor task monitor
 * @return confusion matrix (array of length 4 indexed by {@code TP}, {@code TN},
 * {@code FN}, and {@code FP})
 * @throws CancelledException if monitor is canceled
 */
int[] evaluateModel(EnsembleModel<Label> randomForest, AddressSet testPositive,
		AddressSet testNegative, AddressSet errors, int preBytes, int initialBytes,
		TaskMonitor monitor) throws CancelledException {
	int[] confusionMatrix = new int[4];
	long start = System.nanoTime();

	monitor.setMessage(
		"Evaluating model (step 1 of 2; " + testPositive.getNumAddresses() + " addresses)");
	runEvaluationPass(randomForest, testPositive, RandomForestFunctionFinderPlugin.FUNC_START,
		confusionMatrix, errors, preBytes, initialBytes,
		"Exception while evaluating model on known function starts: ", monitor);

	monitor.setMessage(
		"Evaluating model (step 2 of 2; " + testNegative.getNumAddresses() + " addresses)");
	runEvaluationPass(randomForest, testNegative, RandomForestFunctionFinderPlugin.NON_START,
		confusionMatrix, errors, preBytes, initialBytes,
		"Exception while evaluating model on known function interiors: ", monitor);

	long end = System.nanoTime();
	Msg.info(this,
		String.format("Evaluation time: %g seconds", (end - start) / NANOSECONDS_PER_SECOND));
	return confusionMatrix;
}

/**
 * Classifies every address of {@code testSet} in parallel and tallies the outcomes
 * against {@code target} into {@code confusionMatrix}.
 *
 * @param randomForest model to evaluate
 * @param testSet addresses to classify
 * @param target correct label for every address in {@code testSet}
 * @param confusionMatrix matrix to update
 * @param errors set to place addresses with classifier errors
 * @param preBytes byte-count parameter forwarded to the evaluator callback
 * @param initialBytes byte-count parameter forwarded to the evaluator callback
 * @param errorPrefix prefix of the message logged if the pass fails
 * @param monitor task monitor
 * @throws CancelledException if monitor is canceled
 */
private void runEvaluationPass(EnsembleModel<Label> randomForest, AddressSet testSet,
		Label target, int[] confusionMatrix, AddressSet errors, int preBytes,
		int initialBytes, String errorPrefix, TaskMonitor monitor) throws CancelledException {
	GThreadPool threadPool = GThreadPool.getSharedThreadPool(RANDOM_FOREST_TRAINING_THREADPOOL);
	ConcurrentQ<Address, Boolean> evalQ =
		new ConcurrentQBuilder<Address, Boolean>().setThreadPool(threadPool)
				.setCollectResults(true)
				.setMonitor(monitor)
				.build(new EnsembleEvaluatorCallback(randomForest, program, preBytes,
					initialBytes, params.getIncludeBitFeatures(), target));
	evalQ.add(testSet.getAddresses(true));
	try {
		Collection<QResult<Address, Boolean>> results = evalQ.waitForResults();
		updateConfusionMatrix(results, confusionMatrix, target, errors);
	}
	catch (Exception e) {
		//a cancellation surfaces here as an exception; report it as a CancelledException
		monitor.checkCanceled();
		//pass the throwable along so the stack trace and cause are not lost
		Msg.error(this, errorPrefix + e.getMessage(), e);
	}
}
/**
 * Tallies classifier results into the confusion matrix. A result of {@code true} is
 * counted as a correct classification for {@code target}; a result of {@code false} is
 * counted as an error and the corresponding address is added to {@code errors}. Results
 * that are {@code null} are ignored.
 *
 * @param results results of classifier
 * @param confusion confusion matrix
 * @param target correct answer
 * @param errors set to place addresses with classifier errors
 * @throws Exception if exception encountered during processing
 */
void updateConfusionMatrix(Collection<QResult<Address, Boolean>> results, int[] confusion,
		Label target, AddressSet errors) throws Exception {
	boolean expectingStart = target.equals(RandomForestFunctionFinderPlugin.FUNC_START);
	//which cells to bump depends on whether these addresses should be starts
	int correctSlot = expectingStart ? TP : TN;
	int incorrectSlot = expectingStart ? FN : FP;
	for (QResult<Address, Boolean> outcome : results) {
		Boolean verdict = outcome.getResult();
		if (verdict == null) {
			continue;
		}
		if (!verdict) {
			confusion[incorrectSlot]++;
			errors.add(outcome.getItem());
			continue;
		}
		confusion[correctSlot]++;
	}
}
/**
 * Creates a bootstrap view of the current training set having the same number of
 * examples as the training set, using a freshly generated random seed.
 *
 * @return bootstrap view of the training set
 */
private synchronized DatasetView<Label> getBag() {
	long seed = ThreadLocalRandom.current().nextLong();
	int bagSize = trainingSet.size();
	return DatasetView.createBootstrapView(trainingSet, bagSize, seed);
}
/**
 * Callback that trains a single tree on a bootstrap view of the training set and then
 * advances the monitor by one unit of progress.
 */
private class SingleTreeTrainer
		implements QCallback<CARTClassificationTrainer, TreeModel<Label>> {

	@Override
	public TreeModel<Label> process(CARTClassificationTrainer trainer, TaskMonitor monitor)
			throws Exception {
		TreeModel<Label> trained = trainer.train(getBag());
		monitor.incrementProgress(1);
		return trained;
	}
}
}

View file

@ -0,0 +1,115 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
import ghidra.program.model.address.*;
import ghidra.util.exception.CancelledException;
import ghidra.util.task.TaskMonitor;
/**
 * This is a utility class for generating random subsets of an {@link AddressSetView}.
 */
public class RandomSubsetUtils {

	private RandomSubsetUtils() {
		//utility class
	}

	/**
	 * This method generates a random subset of size {@code k} of {@code addresses}.
	 * <p>
	 * The parameter {@code k} can be of type {@code long}, but you will probably run out of heap
	 * space for large values.
	 *
	 * @param addresses addresses
	 * @param k size of random subset to generate
	 * @param monitor monitor
	 * @return random subset of size k
	 * @throws CancelledException if monitor is canceled
	 * @throws IllegalArgumentException if {@code k} is negative or larger than the number
	 * of addresses in {@code addresses}
	 */
	public static AddressSet randomSubset(AddressSetView addresses, long k, TaskMonitor monitor)
			throws CancelledException {
		//pick k distinct positions in [0, numAddresses) and walk them in increasing order
		List<Long> sortedRandom = generateRandomIntegerSubset(addresses.getNumAddresses(), k);
		Collections.sort(sortedRandom);
		AddressSet randomAddresses = new AddressSet();
		AddressIterator iter = addresses.getAddresses(true);
		//position of the address being visited; must be a long since the selected
		//positions are longs and a set can hold more than Integer.MAX_VALUE addresses
		long addressesVisited = 0;
		//index of the next selected position; also the number of addresses added so far
		int listIndex = 0;
		while (iter.hasNext() && listIndex < sortedRandom.size()) {
			monitor.checkCanceled();
			Address addr = iter.next();
			if (sortedRandom.get(listIndex).longValue() == addressesVisited) {
				randomAddresses.add(addr);
				listIndex += 1;
			}
			addressesVisited += 1;
		}
		return randomAddresses;
	}

	/**
	 * Generates of random subset of size {@code k} of the set [0,1,...n-1] by generating
	 * the first {@code k} entries of a random permutation. Only the entries of the
	 * permutation that differ from the identity are stored.
	 *
	 * @param n size of set (must be >= 0)
	 * @param k size of random subset (must be >= 0)
	 * @return list of indices of elements in random subset
	 * @throws IllegalArgumentException if {@code n} or {@code k} is negative or if
	 * {@code k} is larger than {@code n}
	 */
	public static List<Long> generateRandomIntegerSubset(long n, long k) {
		if (n < 0) {
			throw new IllegalArgumentException("n cannot be negative");
		}
		if (k < 0) {
			throw new IllegalArgumentException("k cannot be negative");
		}
		if (n < k) {
			throw new IllegalArgumentException(
				"size of subset (" + k + ") cannot be larger than size of set (" + n + ")");
		}
		Map<Long, Long> permutation = new HashMap<>();
		//i < k <= n here, so the range [i, n) is never empty
		for (long i = 0; i < k; ++i) {
			swap(permutation, i, ThreadLocalRandom.current().nextLong(i, n));
		}
		List<Long> random = new ArrayList<>();
		for (long i = 0; i < k; i++) {
			random.add(permutation.getOrDefault(i, i));
		}
		return random;
	}

	/**
	 * Updates a Map<Long,Long> treated as a permutation p to produce a new permutation p'
	 * such that p'(i) = p(j) and p'(j) = p(i). For i not in the keySet of the map, it is
	 * assumed that p(i) = i.
	 *
	 * @param permutation permutation map
	 * @param i index
	 * @param j index
	 */
	public static void swap(Map<Long, Long> permutation, long i, long j) {
		if (i == j) {
			return;
		}
		long ith = permutation.getOrDefault(i, i);
		long jth = permutation.getOrDefault(j, j);
		permutation.put(i, jth);
		permutation.put(j, ith);
	}
}

View file

@ -0,0 +1,91 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import java.util.List;
import docking.ActionContext;
import docking.action.DockingAction;
import docking.action.MenuData;
import ghidra.program.model.address.Address;
import ghidra.program.model.listing.Program;
import ghidra.util.HelpLocation;
import ghidra.util.table.GhidraTable;
/**
 * A {@link DockingAction} that, for the single selected row of a table of potential
 * function starts, shows the function starts in the training set that are most similar
 * to it.
 */
public class ShowSimilarStartsAction extends DockingAction {
	private static final String ACTION_NAME = "ShowSimilarStartsAction";
	private static final String MENU_TEXT = "Show Similar Function Starts";
	private static final int NUM_NEIGHBORS = 10;

	private RandomForestFunctionFinderPlugin plugin;
	private Program program;
	private GhidraTable table;
	private FunctionStartTableModel model;
	private RandomForestRowObject modelAndParams;
	private SimilarStartsFinder finder;

	/**
	 * Constructs an action to display similar function starts
	 * @param plugin plugin
	 * @param program source program
	 * @param table table
	 * @param model table with action
	 */
	public ShowSimilarStartsAction(RandomForestFunctionFinderPlugin plugin, Program program,
			GhidraTable table, FunctionStartTableModel model) {
		super(ACTION_NAME, plugin.getName());
		this.plugin = plugin;
		this.program = program;
		this.table = table;
		this.model = model;
		this.modelAndParams = model.getRandomForestRowObject();
		configurePresentation();
		this.finder = new SimilarStartsFinder(program, modelAndParams);
	}

	@Override
	public boolean isAddToPopup(ActionContext context) {
		return true;
	}

	@Override
	public boolean isEnabledForContext(ActionContext context) {
		//the action only makes sense for exactly one potential start
		int selectedCount = table.getSelectedRowCount();
		return selectedCount == 1;
	}

	@Override
	public void actionPerformed(ActionContext context) {
		int selectedRow = table.getSelectedRow();
		Address potential = model.getAddress(selectedRow);
		List<SimilarStartRowObject> closeNeighbors =
			finder.getSimilarFunctionStarts(potential, NUM_NEIGHBORS);
		plugin.addProvider(new SimilarStartsTableProvider(plugin, program, potential,
			closeNeighbors, modelAndParams));
	}

	/**
	 * Sets the popup menu entry, description, and help location of this action
	 */
	private void configurePresentation() {
		String[] menuPath = new String[] { MENU_TEXT };
		setPopupMenuData(new MenuData(menuPath));
		setDescription(
			"Displays the most similar function starts in the training set to the given " +
				"potential start.");
		setHelpLocation(new HelpLocation(plugin.getName(), ACTION_NAME));
	}
}

View file

@ -0,0 +1,27 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import ghidra.program.model.address.Address;
/**
 * Record holding random forest proximity information for one address: the address
 * itself together with the number of trees in the forest that agree on it.
 *
 * @param funcStart address
 * @param numAgreements number of agreeing trees
 */
public record SimilarStartRowObject(Address funcStart, int numAgreements) {
	//no members beyond the record components
}

View file

@ -0,0 +1,132 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import java.util.*;
import java.util.Map.Entry;
import org.tribuo.Feature;
import org.tribuo.Model;
import org.tribuo.classification.Label;
import org.tribuo.common.tree.Node;
import org.tribuo.common.tree.TreeModel;
import org.tribuo.ensemble.EnsembleModel;
import org.tribuo.impl.ArrayExample;
import org.tribuo.math.la.SparseVector;
import ghidra.program.model.address.*;
import ghidra.program.model.listing.Program;
/**
 * Given a potential function start {@code S} and a random forest trained to recognize
 * function starts, this class finds the function starts in the training set which are
 * most similar to {@code S}. Similarity is measured by proximity in the random forest:
 * the number of trees in which the two feature vectors end up in the same leaf node.
 */
public class SimilarStartsFinder {
	private Program program;
	private RandomForestRowObject modelAndParams;
	private EnsembleModel<Label> randomForest;
	private int preBytes;
	private int initialBytes;
	private boolean includeBitFeatures;
	//for each training-set function start, the leaf reached in each tree of the forest
	private Map<Address, List<Node<Label>>> startsToLeafList = new HashMap<>();

	/**
	 * Creates a {@link SimilarStartsFinder} for the given program and model. The leaf
	 * nodes reached by every function start in the training set are computed eagerly.
	 * @param program program
	 * @param modelAndParams model and params
	 */
	public SimilarStartsFinder(Program program, RandomForestRowObject modelAndParams) {
		this.program = program;
		this.modelAndParams = modelAndParams;
		this.randomForest = modelAndParams.getRandomForest();
		this.preBytes = modelAndParams.getNumPreBytes();
		this.initialBytes = modelAndParams.getNumInitialBytes();
		this.includeBitFeatures = modelAndParams.getIncludeBitLevelFeatures();
		computeLeafNodeLists();
	}

	/**
	 * Finds the functions starts in the training set that are most similar to {@code potential}
	 * according to the model
	 * @param potential address of potential start
	 * @param numStarts max size of returned list
	 * @return similar starts (in descending order of similarity)
	 */
	public List<SimilarStartRowObject> getSimilarFunctionStarts(Address potential, int numStarts) {
		List<Node<Label>> potentialLeaves = getLeafNodes(potential);
		List<SimilarStartRowObject> ranked = new ArrayList<>(startsToLeafList.size());
		startsToLeafList.forEach((start, leaves) -> ranked
				.add(new SimilarStartRowObject(start, countAgreements(potentialLeaves, leaves))));
		//most agreements first
		ranked.sort(Comparator.comparingInt(SimilarStartRowObject::numAgreements).reversed());
		int limit = Math.min(numStarts, ranked.size());
		return ranked.subList(0, limit);
	}

	/**
	 * Counts the trees of the forest in which the two lists record equal leaf nodes
	 * @param first leaf nodes of one address, indexed by tree
	 * @param second leaf nodes of another address, indexed by tree
	 * @return number of trees with equal leaf nodes
	 */
	private int countAgreements(List<Node<Label>> first, List<Node<Label>> second) {
		int agreements = 0;
		for (int tree = 0; tree < randomForest.getNumModels(); ++tree) {
			if (first.get(tree).equals(second.get(tree))) {
				agreements += 1;
			}
		}
		return agreements;
	}

	/**
	 * For each function start in the training set and tree in the random forest,
	 * run the corresponding feature vector down the tree and record the leaf node
	 * reached.
	 */
	private void computeLeafNodeLists() {
		AddressSet knownStarts = modelAndParams.getTrainingPositives();
		for (Address start : knownStarts.getAddresses(true)) {
			startsToLeafList.put(start, getLeafNodes(start));
		}
	}

	/**
	 * Creates a feature vector for {@code addr}, runs it down each tree in the forest,
	 * and records the leaf node reached.
	 * @param addr potential function start
	 * @return list of leaf nodes, one per tree in the order of the forest's models
	 */
	List<Node<Label>> getLeafNodes(Address addr) {
		List<Feature> features = ModelTrainingUtils.getFeatureVector(program, addr, preBytes,
			initialBytes, includeBitFeatures);
		ArrayExample<Label> example =
			new ArrayExample<>(RandomForestFunctionFinderPlugin.FUNC_START, features);
		SparseVector vec =
			SparseVector.createSparseVector(example, randomForest.getFeatureIDMap(), false);

		List<Node<Label>> reached = new ArrayList<>(randomForest.getNumModels());
		for (Model<Label> member : randomForest.getModels()) {
			Node<Label> current = ((TreeModel<Label>) member).getRoot();
			//descend until a leaf is reached
			for (; !current.isLeaf(); current = current.getNextNode(vec)) {
				//traversal happens in the loop header
			}
			reached.add(current);
		}
		return reached;
	}
}

View file

@ -0,0 +1,177 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import java.util.List;
import javax.swing.JTable;
import javax.swing.table.TableModel;
import docking.widgets.table.AbstractDynamicTableColumn;
import docking.widgets.table.TableColumnDescriptor;
import ghidra.docking.settings.Settings;
import ghidra.framework.plugintool.PluginTool;
import ghidra.framework.plugintool.ServiceProvider;
import ghidra.program.model.address.Address;
import ghidra.program.model.listing.Program;
import ghidra.program.model.mem.MemoryAccessException;
import ghidra.util.datastruct.Accumulator;
import ghidra.util.exception.CancelledException;
import ghidra.util.table.AddressBasedTableModel;
import ghidra.util.table.column.AbstractGColumnRenderer;
import ghidra.util.table.column.GColumnRenderer;
import ghidra.util.task.TaskMonitor;
/**
 * Table model for a table displaying the closest function starts from the training set
 * to a given potential start.
 */
public class SimilarStartsTableModel extends AddressBasedTableModel<SimilarStartRowObject> {

	//shown in the byte string column when the bytes around an address can't be read
	private static final String UNREADABLE_BYTES = "??";

	private Address potentialStart;
	private List<SimilarStartRowObject> rows;
	private RandomForestRowObject randomForestRow;

	/**
	 * Construct a table model for a table to display the closest function starts to
	 * a potential function start
	 * @param plugin owning tool
	 * @param program program
	 * @param potentialStart address of potential start
	 * @param rows similar function starts
	 * @param randomForestRow model and params
	 */
	public SimilarStartsTableModel(PluginTool plugin, Program program, Address potentialStart,
			List<SimilarStartRowObject> rows, RandomForestRowObject randomForestRow) {
		//NOTE(review): "test" looks like a placeholder model name - confirm before changing
		super("test", plugin, program, null, false);
		this.potentialStart = potentialStart;
		this.rows = rows;
		this.randomForestRow = randomForestRow;
	}

	@Override
	public Address getAddress(int row) {
		return getRowObject(row).funcStart();
	}

	@Override
	protected void doLoad(Accumulator<SimilarStartRowObject> accumulator, TaskMonitor monitor)
			throws CancelledException {
		//add a special row corresponding to the potential function start
		//want it in the table to facilitate (visual) byte string comparisons
		//it is given the maximum possible number of agreements (one per tree)
		accumulator.add(new SimilarStartRowObject(potentialStart,
			randomForestRow.getRandomForest().getNumModels()));
		accumulator.addAll(rows);
	}

	@Override
	protected TableColumnDescriptor<SimilarStartRowObject> createTableColumnDescriptor() {
		TableColumnDescriptor<SimilarStartRowObject> descriptor = new TableColumnDescriptor<>();
		descriptor.addVisibleColumn(new AddressTableColumn());
		//default sort: descending similarity
		descriptor.addVisibleColumn(new SimilarityTableColumn(), 1, false);
		descriptor.addVisibleColumn(new ByteStringTableColumn());
		return descriptor;
	}

	/**
	 * Appends the bytes in {@code bytes[from, to)} to {@code sb} as two-digit lowercase
	 * hex values, each followed by a space
	 * @param sb destination
	 * @param bytes source bytes
	 * @param from first index (inclusive)
	 * @param to last index (exclusive)
	 */
	private static void appendHex(StringBuilder sb, byte[] bytes, int from, int to) {
		for (int i = from; i < to; ++i) {
			sb.append(String.format("%02x ", bytes[i] & 0xff));
		}
	}

	/**
	 * Column showing the address of a row; the potential start is surrounded by asterisks
	 */
	private class AddressTableColumn
			extends AbstractDynamicTableColumn<SimilarStartRowObject, String, Object> {

		@Override
		public String getColumnName() {
			return "Address";
		}

		@Override
		public String getValue(SimilarStartRowObject rowObject, Settings settings, Object data,
				ServiceProvider services) throws IllegalArgumentException {
			String addrString = rowObject.funcStart().toString();
			//address corresponding to the potential start should stand out in the table
			//so surround it with asterisks
			if (rowObject.funcStart().equals(potentialStart)) {
				addrString = "*" + addrString + "*";
			}
			return addrString;
		}
	}

	/**
	 * Column showing the proportion of trees in the forest that agree for a row
	 */
	private class SimilarityTableColumn
			extends AbstractDynamicTableColumn<SimilarStartRowObject, Double, Object> {

		@Override
		public String getColumnName() {
			return "Similarity";
		}

		@Override
		public Double getValue(SimilarStartRowObject rowObject, Settings settings, Object data,
				ServiceProvider services) throws IllegalArgumentException {
			return rowObject.numAgreements() * 1.0 /
				randomForestRow.getRandomForest().getNumModels();
		}
	}

	/**
	 * Column showing, in hex, the bytes immediately before a row's address, then a
	 * "*" marker, then the bytes starting at the address
	 */
	private class ByteStringTableColumn
			extends AbstractDynamicTableColumn<SimilarStartRowObject, String, Object> {

		@Override
		public String getColumnName() {
			return "Byte String";
		}

		@Override
		public GColumnRenderer<String> getColumnRenderer() {
			//fixed-width font so that byte strings in different rows line up
			final GColumnRenderer<String> monospacedRenderer = new AbstractGColumnRenderer<>() {
				@Override
				protected void configureFont(JTable table, TableModel model, int column) {
					setFont(getFixedWidthFont());
				}

				@Override
				public String getFilterString(String t, Settings settings) {
					return t;
				}
			};
			return monospacedRenderer;
		}

		@Override
		public String getValue(SimilarStartRowObject rowObject, Settings settings, Object data,
				ServiceProvider services) throws IllegalArgumentException {
			Address funcStart = rowObject.funcStart();
			int numPre = randomForestRow.getNumPreBytes();
			int numInitial = randomForestRow.getNumInitialBytes();
			byte[] bytes = new byte[numPre + numInitial];
			try {
				program.getMemory().getBytes(funcStart.subtract(numPre), bytes);
			}
			catch (MemoryAccessException |
					ghidra.program.model.address.AddressOutOfBoundsException e) {
				//Address.subtract throws the (unchecked) AddressOutOfBoundsException when
				//funcStart is too close to the beginning of its address space; treat that
				//like unreadable memory instead of failing the cell
				return UNREADABLE_BYTES;
			}
			StringBuilder sb = new StringBuilder();
			appendHex(sb, bytes, 0, numPre);
			//marks the position of the address itself
			sb.append("* ");
			appendHex(sb, bytes, numPre, numPre + numInitial);
			return sb.toString();
		}
	}
}

View file

@ -0,0 +1,87 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import java.awt.BorderLayout;
import java.awt.Dimension;
import java.util.List;
import javax.swing.*;
import ghidra.app.services.GoToService;
import ghidra.program.model.address.Address;
import ghidra.program.model.listing.Program;
import ghidra.util.HelpLocation;
import ghidra.util.table.GhidraTable;
import ghidra.util.table.GhidraThreadedTablePanel;
/**
 * Component provider showing a table of the function starts in the training set that
 * are closest to a potential function start
 */
public class SimilarStartsTableProvider extends ProgramAssociatedComponentProviderAdapter {
	private static final int TABLE_WIDTH = 900;
	private static final int TABLE_HEIGHT = 300;

	private Program program;
	private Address potentialStart;
	private List<SimilarStartRowObject> rows;
	private RandomForestRowObject randomForestRow;
	private JComponent component;

	/**
	 * Create a table provider
	 * @param plugin owning plugin
	 * @param program program being searched
	 * @param potentialStart address of potential start
	 * @param rows closest potential starts
	 * @param randomForestRow model and params
	 */
	public SimilarStartsTableProvider(RandomForestFunctionFinderPlugin plugin, Program program,
			Address potentialStart, List<SimilarStartRowObject> rows,
			RandomForestRowObject randomForestRow) {
		super(program.getName() + ": Similar Function Starts", plugin.getName(), program, plugin);
		this.program = program;
		this.potentialStart = potentialStart;
		this.rows = rows;
		this.randomForestRow = randomForestRow;
		setSubTitle("Function Starts Similar to " + potentialStart.toString());
		build();
		setHelpLocation(new HelpLocation(plugin.getName(), "SimilarStartsTable"));
	}

	@Override
	public JComponent getComponent() {
		return component;
	}

	/**
	 * Creates the table panel and installs it as this provider's component
	 */
	private void build() {
		JPanel mainPanel = new JPanel(new BorderLayout());
		component = mainPanel;

		GhidraThreadedTablePanel<SimilarStartRowObject> tablePanel =
			new GhidraThreadedTablePanel<>(new SimilarStartsTableModel(tool, program,
				potentialStart, rows, randomForestRow), 1000);
		GhidraTable startsTable = tablePanel.getTable();
		String tableName =
			program.getName() + ": Known Starts Similar to " + potentialStart.toString();
		startsTable.setName(tableName);

		//navigation can only be installed when the tool provides a GoToService
		GoToService goTo = tool.getService(GoToService.class);
		if (goTo != null) {
			startsTable.installNavigation(goTo, goTo.getDefaultNavigatable());
		}
		startsTable.setNavigateOnSelectionEnabled(true);
		startsTable.setAutoResizeMode(JTable.AUTO_RESIZE_SUBSEQUENT_COLUMNS);
		startsTable
				.setPreferredScrollableViewportSize(new Dimension(TABLE_WIDTH, TABLE_HEIGHT));

		mainPanel.add(tablePanel, BorderLayout.CENTER);
	}
}

View file

@ -0,0 +1,88 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import ghidra.program.model.address.AddressSet;
import ghidra.util.exception.CancelledException;
import ghidra.util.task.TaskMonitor;
/**
 * Holder for the four {@link AddressSet}s (positive and negative examples for training
 * and for testing) used while training and evaluating a model
 */
public class TrainingAndTestData {
	private AddressSet trainingPositive;
	private AddressSet trainingNegative;
	private AddressSet testPositive;
	private AddressSet testNegative;

	/**
	 * Creates a container for the given sets
	 * @param trainingPositive positive examples for training
	 * @param trainingNegative negative examples for training
	 * @param testPositive positive examples for testing
	 * @param testNegative negative examples for testing
	 */
	public TrainingAndTestData(AddressSet trainingPositive, AddressSet trainingNegative,
			AddressSet testPositive, AddressSet testNegative) {
		this.trainingPositive = trainingPositive;
		this.trainingNegative = trainingNegative;
		this.testPositive = testPositive;
		this.testNegative = testNegative;
	}

	/**
	 * Returns the {@link AddressSet} of positive examples for training
	 * @return training positive
	 */
	public AddressSet getTrainingPositive() {
		return trainingPositive;
	}

	/**
	 * Returns the {@link AddressSet} of negative examples for training
	 * @return training negative
	 */
	public AddressSet getTrainingNegative() {
		return trainingNegative;
	}

	/**
	 * Returns the {@link AddressSet} of positive examples for testing
	 * @return test positive
	 */
	public AddressSet getTestPositive() {
		return testPositive;
	}

	/**
	 * Returns the {@link AddressSet} of negative examples for testing
	 * @return test negative
	 */
	public AddressSet getTestNegative() {
		return testNegative;
	}

	/**
	 * Caps the sizes of the two test sets. If {@code testPositive} or {@code testNegative}
	 * contains more than {@code max} addresses, it is replaced with a random subset of
	 * itself of size {@code max}. The training sets are not modified.
	 *
	 * @param max max size of each set
	 * @param monitor task monitor
	 * @throws CancelledException if the monitor is canceled
	 */
	public void reduceTestSetSize(long max, TaskMonitor monitor) throws CancelledException {
		testPositive = capSize(testPositive, max, monitor);
		testNegative = capSize(testNegative, max, monitor);
	}

	/**
	 * Returns {@code set} itself if it has at most {@code max} addresses and otherwise a
	 * random subset of it of size {@code max}
	 * @param set set to cap
	 * @param max max size
	 * @param monitor task monitor
	 * @return {@code set} or a random subset of it
	 * @throws CancelledException if the monitor is canceled
	 */
	private static AddressSet capSize(AddressSet set, long max, TaskMonitor monitor)
			throws CancelledException {
		if (set.getNumAddresses() <= max) {
			return set;
		}
		return RandomSubsetUtils.randomSubset(set, max, monitor);
	}
}

View file

@ -0,0 +1,2 @@
The "src/resources/images" directory is intended to hold all image/icon files used by
this extension.

View file

@ -0,0 +1,242 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import static org.junit.Assert.*;
import java.util.Iterator;
import java.util.List;
import org.junit.Before;
import org.junit.Test;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.classification.Label;
import ghidra.program.database.ProgramBuilder;
import ghidra.program.model.address.Address;
import ghidra.program.model.address.AddressSet;
import ghidra.program.model.data.DataType;
import ghidra.program.model.data.IntegerDataType;
import ghidra.program.model.listing.Program;
import ghidra.program.model.mem.MemoryBlock;
import ghidra.test.AbstractProgramBasedTest;
import ghidra.util.exception.CancelledException;
import ghidra.util.task.TaskMonitor;
/**
 * Tests the feature-vector and address-gathering helpers of {@link ModelTrainingUtils}
 * against a small hand-built ARM program that mixes ARM code, Thumb code and defined data.
 */
public class DataGatheringUtilsTest extends AbstractProgramBasedTest {

	private static final String BASE_ADDRESS = "0x10000";
	private static final String ADD_R0_R1_R2_ARM = "02 00 81 e0";
	private static final String SUB_R4_R5_R6_ARM = "06 40 45 e0";
	private static final String ADD_R0_R1_THUMB = "08 44";

	@Before
	public void setUp() throws Exception {
		initialize();
	}

	@Override
	protected Program getProgram() throws Exception {
		return buildProgram();
	}

	/**
	 * Builds the test program: one executable {@code .text} block at
	 * {@link #BASE_ADDRESS} holding four undefined bytes, two ARM instructions, two Thumb
	 * instructions (TMode set to 1 over their range) and three integer data items.
	 */
	private Program buildProgram() throws Exception {
		ProgramBuilder programBuilder =
			new ProgramBuilder("DataGatheringUtilsTest", ProgramBuilder._ARM);
		MemoryBlock textBlock = programBuilder.createMemory(".text", BASE_ADDRESS, 0x100);
		programBuilder.setExecute(textBlock, true);

		//undefined
		programBuilder.setBytes(BASE_ADDRESS, "00 01 02 03", false);

		//arm instructions
		programBuilder.setBytes("0x10004", ADD_R0_R1_R2_ARM, true);
		programBuilder.setBytes("0x10008", SUB_R4_R5_R6_ARM, true);

		//thumb instructions
		programBuilder.setRegisterValue("TMode", "0x1000c", "0x1000f", 1);
		programBuilder.setBytes("0x1000c", ADD_R0_R1_THUMB, true);
		programBuilder.setBytes("0x1000e", ADD_R0_R1_THUMB, true);

		//defined data: {address, bytes}
		DataType intType = new IntegerDataType();
		String[][] dataItems = {
			{ "0x10020", "00 01 02 03" },
			{ "0x10024", "04 05 06 07" },
			{ "0x10030", "08 09 0a 0b" } };
		for (String[] item : dataItems) {
			programBuilder.setBytes(item[0], item[1], false);
			programBuilder.applyDataType(item[0], intType);
		}
		return programBuilder.getProgram();
	}

	/**
	 * Returns the address at {@code offset} in the program's default address space.
	 */
	private Address defaultSpaceAddr(long offset) {
		return program.getAddressFactory().getDefaultAddressSpace().getAddress(offset);
	}

	@Test
	public void testGetByteValues() {
		List<Feature> features =
			ModelTrainingUtils.getFeatureVector(program, defaultSpaceAddr(0x10008), 1, 1, false);
		assertEquals(2, features.size());

		Feature preByte = features.get(0);
		assertEquals("pbyte_0", preByte.getName());
		assertEquals(224.0d, preByte.getValue(), 0.0);

		Feature initialByte = features.get(1);
		assertEquals("ibyte_0", initialByte.getName());
		assertEquals(6d, initialByte.getValue(), 0.0);
	}

	@Test
	public void testGetBitValues() {
		List<Feature> features =
			ModelTrainingUtils.getFeatureVector(program, defaultSpaceAddr(0x10008), 1, 1, true);
		assertEquals(18, features.size());

		//names are checked first, in list order
		String[] expectedNames = {
			"pbyte_0",
			"pbit_0_7", "pbit_0_6", "pbit_0_5", "pbit_0_4",
			"pbit_0_3", "pbit_0_2", "pbit_0_1", "pbit_0_0",
			"ibyte_0",
			"ibit_0_7", "ibit_0_6", "ibit_0_5", "ibit_0_4",
			"ibit_0_3", "ibit_0_2", "ibit_0_1", "ibit_0_0" };
		for (int i = 0; i < expectedNames.length; i++) {
			assertEquals(expectedNames[i], features.get(i).getName());
		}

		//then the values: e0 06
		final double one = ModelTrainingUtils.ONE;
		final double zero = ModelTrainingUtils.ZERO;
		double[] expectedValues = {
			224.0d,
			one, one, one, zero, zero, zero, zero, zero,
			6d,
			zero, zero, zero, zero, zero, one, one, zero };
		for (int i = 0; i < expectedValues.length; i++) {
			assertEquals(expectedValues[i], features.get(i).getValue(), 0.0);
		}
	}

	@Test
	public void testGetFollowingAddresses() throws CancelledException {
		AddressSet sources = new AddressSet(defaultSpaceAddr(0x10004));
		sources.add(defaultSpaceAddr(0x1000c));

		AddressSet following =
			ModelTrainingUtils.getFollowingAddresses(program, sources, TaskMonitor.DUMMY);

		assertEquals(2, following.getNumAddresses());
		assertTrue(following.contains(defaultSpaceAddr(0x10008)));
		assertTrue(following.contains(defaultSpaceAddr(0x1000e)));
	}

	@Test
	public void testGetPrecedingAddresses() throws CancelledException {
		AddressSet sources = new AddressSet(defaultSpaceAddr(0x10004));
		sources.add(defaultSpaceAddr(0x1000c));

		AddressSet preceding =
			ModelTrainingUtils.getPrecedingAddresses(program, sources, TaskMonitor.DUMMY);

		assertEquals(2, preceding.getNumAddresses());
		assertTrue(preceding.contains(defaultSpaceAddr(0x10003)));
		assertTrue(preceding.contains(defaultSpaceAddr(0x10008)));
	}

	@Test
	public void testGetDefinedData() throws CancelledException {
		AddressSet definedData = ModelTrainingUtils.getDefinedData(program, TaskMonitor.DUMMY);

		//three 4-byte integers
		assertEquals(12, definedData.getNumAddresses());
		assertTrue(definedData.contains(defaultSpaceAddr(0x10020), defaultSpaceAddr(0x10027)));
		assertTrue(definedData.contains(defaultSpaceAddr(0x10030), defaultSpaceAddr(0x10033)));
	}

	@Test
	public void testGetVectorsFromAddresses() throws CancelledException {
		AddressSet single = new AddressSet(defaultSpaceAddr(0x10008));
		checkVectorsForLabel(single, RandomForestFunctionFinderPlugin.FUNC_START);
		checkVectorsForLabel(single, RandomForestFunctionFinderPlugin.NON_START);
	}

	/**
	 * Requests vectors for {@code addresses} tagged with {@code label} and verifies that
	 * exactly one example comes back, that its output is {@code label}, and that its
	 * features carry the values expected at 0x10008.
	 */
	private void checkVectorsForLabel(AddressSet addresses, Label label)
			throws CancelledException {
		List<Example<Label>> vectors = ModelTrainingUtils.getVectorsFromAddresses(program,
			addresses, label, 1, 1, true, TaskMonitor.DUMMY);
		assertEquals(1, vectors.size());
		Example<Label> vector = vectors.get(0);
		assertEquals(label, vector.getOutput());
		testFeatureVector(vector.iterator());
	}

	/**
	 * Checks every feature produced by {@code iter} against the value expected for its
	 * name (pre-byte e0, initial byte 06); fails on any name that is not recognized.
	 */
	private void testFeatureVector(Iterator<Feature> iter) {
		while (iter.hasNext()) {
			Feature feature = iter.next();
			String name = feature.getName();
			double expected;
			switch (name) {
				case "pbyte_0":
					expected = 224d;
					break;
				case "ibyte_0":
					expected = 6d;
					break;
				case "pbit_0_7":
				case "pbit_0_6":
				case "pbit_0_5":
				case "ibit_0_2":
				case "ibit_0_1":
					expected = ModelTrainingUtils.ONE;
					break;
				case "pbit_0_4":
				case "pbit_0_3":
				case "pbit_0_2":
				case "pbit_0_1":
				case "pbit_0_0":
				case "ibit_0_7":
				case "ibit_0_6":
				case "ibit_0_5":
				case "ibit_0_4":
				case "ibit_0_3":
				case "ibit_0_0":
					expected = ModelTrainingUtils.ZERO;
					break;
				default:
					fail("Unknown feature name: " + name);
					return; //not reached, fail() always throws
			}
			assertEquals(expected, feature.getValue(), 0.0);
		}
	}
}

View file

@ -0,0 +1,152 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import static org.junit.Assert.*;
import java.util.ArrayList;
import java.util.List;
import org.junit.Test;
import org.tribuo.*;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.baseline.DummyClassifierTrainer;
import org.tribuo.classification.ensemble.VotingCombiner;
import org.tribuo.datasource.ListDataSource;
import org.tribuo.ensemble.EnsembleModel;
import org.tribuo.ensemble.WeightedEnsembleModel;
import org.tribuo.provenance.SimpleDataSourceProvenance;
import ghidra.program.database.ProgramBuilder;
import ghidra.program.database.ProgramDB;
import ghidra.program.model.address.Address;
import ghidra.program.model.address.AddressSet;
import ghidra.program.model.listing.Program;
import ghidra.test.AbstractProgramBasedTest;
import ghidra.test.ClassicSampleX86ProgramBuilder;
import ghidra.util.task.TaskMonitor;
/**
 * Tests how {@link EnsembleEvaluatorCallback} resolves the votes of an ensemble made of
 * constant ("dummy") classifiers, including tied votes.
 */
public class EnsembleEvaluatorTest extends AbstractProgramBasedTest {

	private ProgramBuilder builder;

	@Override
	protected Program getProgram() throws Exception {
		builder = new ClassicSampleX86ProgramBuilder();
		ProgramDB p = builder.getProgram();
		return p;
	}

	/**
	 * Builds ensembles with different mixes and orderings of "always FUNC_START" and
	 * "always NON_START" models and checks the evaluator's answer for each one.
	 *
	 * @throws Exception if program initialization or evaluation fails
	 */
	@Test
	public void basicEvaluatorTest() throws Exception {
		initialize();
		DummyClassifierTrainer dummyStartTrainer = DummyClassifierTrainer
				.createConstantTrainer(RandomForestFunctionFinderPlugin.FUNC_START.getLabel());
		DummyClassifierTrainer dummyNonStartTrainer = DummyClassifierTrainer
				.createConstantTrainer(RandomForestFunctionFinderPlugin.NON_START.getLabel());

		//just need an address with 1 defined preByte and 1 defined byte
		Address testStart = program.getSymbolTable().getSymbols("entry").next().getAddress().add(1);
		Address testNonStart = testStart.add(1);
		AddressSet starts = new AddressSet();
		AddressSet nonStarts = new AddressSet();
		starts.add(testStart);
		nonStarts.add(testNonStart);

		List<Example<Label>> trainingData = new ArrayList<>();
		trainingData.addAll(ModelTrainingUtils.getVectorsFromAddresses(program, starts,
			RandomForestFunctionFinderPlugin.FUNC_START, 1, 1, true, TaskMonitor.DUMMY));
		trainingData.addAll(ModelTrainingUtils.getVectorsFromAddresses(program, nonStarts,
			RandomForestFunctionFinderPlugin.NON_START, 1, 1, true, TaskMonitor.DUMMY));
		LabelFactory lf = new LabelFactory();
		ListDataSource<Label> trainingSource =
			new ListDataSource<Label>(trainingData, lf, new SimpleDataSourceProvenance("test", lf));
		MutableDataset<Label> trainingSet = new MutableDataset<>(trainingSource);

		//20 models alternating yes/no, yes vote first: 10 yes vs 10 no
		List<Model<Label>> models = new ArrayList<>();
		addRounds(models, trainingSet, 10, dummyStartTrainer, dummyNonStartTrainer);
		assertTrue(ensembleVotesStart(models, testStart));

		//20 models alternating no/yes, no vote first: 10 no vs 10 yes
		//ties should go to yes
		models = new ArrayList<>();
		addRounds(models, trainingSet, 10, dummyNonStartTrainer, dummyStartTrainer);
		assertTrue(ensembleVotesStart(models, testStart));

		//10 models, 4 yes votes followed by 6 no votes
		models = new ArrayList<>();
		addRounds(models, trainingSet, 4, dummyStartTrainer);
		addRounds(models, trainingSet, 6, dummyNonStartTrainer);
		assertFalse(ensembleVotesStart(models, testStart));

		//21 models, 20 alternating no/yes plus one final yes: 11 yes vs 10 no
		models = new ArrayList<>();
		addRounds(models, trainingSet, 10, dummyNonStartTrainer, dummyStartTrainer);
		addRounds(models, trainingSet, 1, dummyStartTrainer);
		assertTrue(ensembleVotesStart(models, testStart));

		//11 models, 5 yes votes followed by 6 no votes
		models = new ArrayList<>();
		addRounds(models, trainingSet, 5, dummyStartTrainer);
		addRounds(models, trainingSet, 6, dummyNonStartTrainer);
		assertFalse(ensembleVotesStart(models, testStart));
	}

	/**
	 * Appends {@code rounds} rounds of freshly trained models to {@code models}; each round
	 * trains one model per trainer, in the order the trainers are given.
	 */
	private static void addRounds(List<Model<Label>> models, MutableDataset<Label> trainingSet,
			int rounds, DummyClassifierTrainer... trainersInOrder) {
		for (int i = 0; i < rounds; i++) {
			for (DummyClassifierTrainer trainer : trainersInOrder) {
				models.add(trainer.train(trainingSet));
			}
		}
	}

	/**
	 * Wraps {@code models} in a voting ensemble and returns what an
	 * {@link EnsembleEvaluatorCallback} targeting FUNC_START reports for {@code address}.
	 */
	private Boolean ensembleVotesStart(List<Model<Label>> models, Address address)
			throws Exception {
		EnsembleModel<Label> ensemble = WeightedEnsembleModel
				.createEnsembleFromExistingModels("test", models, new VotingCombiner());
		EnsembleEvaluatorCallback evaluator = new EnsembleEvaluatorCallback(ensemble, program, 1,
			1, true, RandomForestFunctionFinderPlugin.FUNC_START);
		return evaluator.process(address, TaskMonitor.DUMMY);
	}
}

View file

@ -0,0 +1,186 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import static org.junit.Assert.*;
import org.junit.Before;
import org.junit.Test;
import ghidra.program.database.ProgramBuilder;
import ghidra.program.model.address.*;
import ghidra.program.model.listing.Program;
import ghidra.program.model.mem.MemoryBlock;
import ghidra.test.AbstractProgramBasedTest;
import ghidra.util.exception.CancelledException;
import ghidra.util.task.TaskMonitor;
/**
 * Program-based tests for {@link FunctionStartRFParams}: context-register compatibility
 * checks and the computation of function entry and interior address sets, run against a
 * small ARM program containing two ARM-mode and two Thumb-mode functions.
 */
public class FunctionStartRFParamsProgramBasedTest extends AbstractProgramBasedTest {

	private final static String BASE_ADDRESS = "0x10000";
	private final static String ADD_R0_R1_R2_ARM = "02 00 81 e0";
	private final static String BX_LR_ARM = "1e ff 2f e1";
	private final static String ADD_R0_R1_THUMB = "08 44";
	private final static String BX_LR_THUMB = "70 47";

	@Before
	public void setUp() throws Exception {
		initialize();
	}

	@Override
	protected Program getProgram() throws Exception {
		return buildProgram();
	}

	/**
	 * Builds the test program: four undefined bytes followed by a small and a larger ARM
	 * function, then (with TMode set to 1) a small and a larger Thumb function.
	 */
	private Program buildProgram() throws Exception {
		ProgramBuilder builder = new ProgramBuilder("DataGatheringUtilsTest", ProgramBuilder._ARM);
		MemoryBlock block = builder.createMemory(".text", BASE_ADDRESS, 0x100);
		builder.setExecute(block, true);

		//undefined
		builder.setBytes(BASE_ADDRESS, "00 01 02 03", false);

		//small arm function
		builder.setBytes("0x10004", ADD_R0_R1_R2_ARM, true);
		builder.setBytes("0x10008", BX_LR_ARM, true);
		builder.createFunction("0x10004");

		//larger arm function
		builder.setBytes("0x1000c", ADD_R0_R1_R2_ARM, true);
		builder.setBytes("0x10010", ADD_R0_R1_R2_ARM, true);
		builder.setBytes("0x10014", ADD_R0_R1_R2_ARM, true);
		builder.setBytes("0x10018", ADD_R0_R1_R2_ARM, true);
		builder.setBytes("0x1001c", BX_LR_ARM, true);
		builder.createFunction("0x1000c");

		builder.setRegisterValue("TMode", "0x10020", "0x10036", 1);

		//small thumb function
		builder.setBytes("0x10020", ADD_R0_R1_THUMB, true);
		builder.setBytes("0x10022", BX_LR_THUMB, true);
		builder.createFunction("0x10020");

		//larger thumb function
		builder.setBytes("0x10024", ADD_R0_R1_THUMB, true);
		builder.setBytes("0x10026", ADD_R0_R1_THUMB, true);
		builder.setBytes("0x10028", ADD_R0_R1_THUMB, true);
		builder.setBytes("0x1002a", ADD_R0_R1_THUMB, true);
		builder.setBytes("0x1002c", ADD_R0_R1_THUMB, true);
		builder.setBytes("0x1002e", ADD_R0_R1_THUMB, true);
		builder.setBytes("0x10030", ADD_R0_R1_THUMB, true);
		builder.setBytes("0x10032", BX_LR_THUMB, true);
		builder.createFunction("0x10024");
		return builder.getProgram();
	}

	/**
	 * With no register constraint both functions are compatible; after requiring TMode=1
	 * only the Thumb function is.
	 */
	@Test
	public void testCheckContextRegisters() {
		FunctionStartRFParams params = new FunctionStartRFParams(program);
		AddressSpace defaultSpace = program.getAddressFactory().getDefaultAddressSpace();
		Address armFunc = defaultSpace.getAddress(0x1000c);
		Address thumbFunc = defaultSpace.getAddress(0x10020);

		assertTrue(params.isContextCompatible(armFunc));
		assertTrue(params.isContextCompatible(thumbFunc));

		params.setRegistersAndValues("TMode=1");
		assertFalse(params.isContextCompatible(armFunc));
		assertTrue(params.isContextCompatible(thumbFunc));
	}

	/**
	 * Checks the entry and interior sets computed with default parameters (all four
	 * functions), and then with a minimum function size of 10 restricted to TMode=0 and to
	 * TMode=1 (only the larger ARM, respectively Thumb, function).
	 *
	 * @throws CancelledException if the computation is cancelled
	 */
	@Test
	public void testComputeEntriesAndInteriors() throws CancelledException {
		AddressSpace defaultSpace = program.getAddressFactory().getDefaultAddressSpace();

		//instruction alignment for the processor is 2, even in ARM mode
		//so we need to adjust the arm interiors
		Address smallArmEntry = defaultSpace.getAddress(0x10004);
		AddressSet smallArmInterior = interiorSet(defaultSpace, 0x10006, 0x10008, 0x1000a);

		Address largeArmEntry = defaultSpace.getAddress(0x1000c);
		AddressSet largeArmInterior = interiorSet(defaultSpace, 0x1000e, 0x10010, 0x10012,
			0x10014, 0x10016, 0x10018, 0x1001a, 0x1001c, 0x1001e);

		Address smallThumbEntry = defaultSpace.getAddress(0x10020);
		AddressSet smallThumbInterior = interiorSet(defaultSpace, 0x10022);

		Address largeThumbEntry = defaultSpace.getAddress(0x10024);
		AddressSet largeThumbInterior = interiorSet(defaultSpace, 0x10026, 0x10028, 0x1002a,
			0x1002c, 0x1002e, 0x10030, 0x10032);

		//default parameters: every function contributes
		FunctionStartRFParams params = new FunctionStartRFParams(program);
		params.computeFuncEntriesAndInteriors(TaskMonitor.DUMMY);
		AddressSet entries = params.getFuncEntries();
		AddressSet interiors = params.getFuncInteriors();
		assertFalse(entries.intersects(interiors));
		assertEquals(4, entries.getNumAddresses());
		assertTrue(entries.contains(smallArmEntry));
		assertTrue(entries.contains(largeArmEntry));
		assertTrue(entries.contains(smallThumbEntry));
		assertTrue(entries.contains(largeThumbEntry));
		assertEquals(
			smallThumbInterior.getNumAddresses() + largeThumbInterior.getNumAddresses() +
				smallArmInterior.getNumAddresses() + largeArmInterior.getNumAddresses(),
			interiors.getNumAddresses());
		assertTrue(interiors.contains(smallThumbInterior.union(largeThumbInterior)
				.union(smallArmInterior)
				.union(largeArmInterior)));

		//min size 10 and a fixed TMode: only one function qualifies in each mode
		assertOnlyFunction("TMode=0", largeArmEntry, largeArmInterior);
		assertOnlyFunction("TMode=1", largeThumbEntry, largeThumbInterior);
	}

	/**
	 * Returns an {@link AddressSet} holding the addresses at the given offsets of
	 * {@code space}.
	 */
	private static AddressSet interiorSet(AddressSpace space, long... offsets) {
		AddressSet set = new AddressSet();
		for (long offset : offsets) {
			set.add(space.getAddress(offset));
		}
		return set;
	}

	/**
	 * Computes entries and interiors using a minimum function size of 10 and the given
	 * context-register constraint, then verifies that {@code expectedEntry} is the only
	 * entry and that the interiors are exactly {@code expectedInterior}.
	 */
	private void assertOnlyFunction(String registersAndValues, Address expectedEntry,
			AddressSet expectedInterior) throws CancelledException {
		FunctionStartRFParams params = new FunctionStartRFParams(program);
		params.setMinFuncSize(10);
		params.setRegistersAndValues(registersAndValues);
		params.computeFuncEntriesAndInteriors(TaskMonitor.DUMMY);
		AddressSet entries = params.getFuncEntries();
		AddressSet interiors = params.getFuncInteriors();

		assertFalse(entries.intersects(interiors));
		assertEquals(1, entries.getNumAddresses());
		assertTrue(entries.contains(expectedEntry));
		assertTrue(interiors.contains(expectedInterior));
		//expected value first so a failure message reads correctly
		assertEquals(expectedInterior.getNumAddresses(), interiors.getNumAddresses());
	}
}

View file

@ -0,0 +1,361 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import static org.junit.Assert.*;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import org.junit.Before;
import org.junit.Test;
import org.tribuo.*;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.baseline.DummyClassifierTrainer;
import org.tribuo.classification.ensemble.VotingCombiner;
import org.tribuo.datasource.ListDataSource;
import org.tribuo.ensemble.EnsembleModel;
import org.tribuo.ensemble.WeightedEnsembleModel;
import org.tribuo.provenance.SimpleDataSourceProvenance;
import generic.concurrent.QResult;
import ghidra.program.database.ProgramBuilder;
import ghidra.program.database.ProgramDB;
import ghidra.program.model.address.Address;
import ghidra.program.model.address.AddressSet;
import ghidra.program.model.listing.Function;
import ghidra.program.model.listing.Program;
import ghidra.program.util.ProgramSelection;
import ghidra.test.AbstractProgramBasedTest;
import ghidra.test.ClassicSampleX86ProgramBuilder;
import ghidra.util.exception.CancelledException;
import ghidra.util.task.TaskLauncher;
import ghidra.util.task.TaskMonitor;
public class RandomForestTrainingTaskTest extends AbstractProgramBasedTest {
private ProgramBuilder builder;
private FunctionStartRFParams params;
/**
 * Calls {@code initialize()} before each test.
 * NOTE(review): {@code initialize()} is not defined in the visible part of this class;
 * presumably it is inherited from {@link AbstractProgramBasedTest} - confirm.
 *
 * @throws Exception if {@code initialize()} fails
 */
@Before
public void setup() throws Exception {
initialize();
}
/**
 * Supplies the program under test, taken from a new
 * {@link ClassicSampleX86ProgramBuilder}; the builder is kept in the {@code builder}
 * field.
 */
@Override
protected Program getProgram() throws Exception {
	ClassicSampleX86ProgramBuilder sampleBuilder = new ClassicSampleX86ProgramBuilder();
	builder = sampleBuilder;
	ProgramDB sampleProgram = builder.getProgram();
	return sampleProgram;
}
/**
 * Feeds {@code updateConfusionMatrix} one {@code true} and one {@code false} result for
 * each of the two labels and checks that the false-result address of each call lands in
 * the error set and that every cell of the confusion matrix ends up at 1.
 */
@Test
public void testUpdateConfusionMatrix() throws Exception {
	//don't need any of the fields of RandomForestTrainingTask for this test
	RandomForestTrainingTask task = new RandomForestTrainingTask(program, null, null,
		RandomForestFunctionFinderPlugin.TEST_SET_MAX_SIZE_DEFAULT);

	//four consecutive addresses: one per combination of label and result
	Address startWithTrue = program.getSymbolTable().getSymbols("entry").next().getAddress();
	Address startWithFalse = startWithTrue.add(1);
	Address nonStartWithFalse = startWithFalse.add(1);
	Address nonStartWithTrue = nonStartWithFalse.add(1);

	int[] matrix = new int[RandomForestTrainingTask.CONFUSION_MATRIX_SIZE];
	AddressSet errors = new AddressSet();

	//results for examples labeled FUNC_START
	List<QResult<Address, Boolean>> startResults = new ArrayList<>();
	startResults.add(new QResult<>(startWithTrue, CompletableFuture.completedFuture(true)));
	startResults.add(new QResult<>(startWithFalse, CompletableFuture.completedFuture(false)));
	task.updateConfusionMatrix(startResults, matrix,
		RandomForestFunctionFinderPlugin.FUNC_START, errors);
	assertEquals(1, errors.getNumAddresses());
	assertTrue(errors.contains(startWithFalse));

	//results for examples labeled NON_START, accumulated into the same matrix
	errors.clear();
	List<QResult<Address, Boolean>> nonStartResults = new ArrayList<>();
	nonStartResults
			.add(new QResult<>(nonStartWithFalse, CompletableFuture.completedFuture(false)));
	nonStartResults
			.add(new QResult<>(nonStartWithTrue, CompletableFuture.completedFuture(true)));
	task.updateConfusionMatrix(nonStartResults, matrix,
		RandomForestFunctionFinderPlugin.NON_START, errors);
	assertEquals(1, errors.getNumAddresses());
	assertTrue(errors.contains(nonStartWithFalse));

	assertEquals(1, matrix[RandomForestTrainingTask.TP]);
	assertEquals(1, matrix[RandomForestTrainingTask.FP]);
	assertEquals(1, matrix[RandomForestTrainingTask.TN]);
	assertEquals(1, matrix[RandomForestTrainingTask.FN]);
}
/**
 * Evaluates two constant ensembles (always FUNC_START, then always NON_START) on one
 * start address and one non-start address and checks the reported errors and confusion
 * matrix for each.
 */
@Test
public void testEvaluateModel() throws CancelledException {
	//just need an address with 1 defined preByte and 1 defined byte
	Address startAddr =
		program.getSymbolTable().getSymbols("entry").next().getAddress().add(1);
	Address nonStartAddr = startAddr.add(1);
	AddressSet starts = new AddressSet();
	starts.add(startAddr);
	AddressSet nonStarts = new AddressSet();
	nonStarts.add(nonStartAddr);

	//data set the dummy trainers are trained on
	List<Example<Label>> examples = new ArrayList<>();
	examples.addAll(ModelTrainingUtils.getVectorsFromAddresses(program, starts,
		RandomForestFunctionFinderPlugin.FUNC_START, 1, 1, true, TaskMonitor.DUMMY));
	examples.addAll(ModelTrainingUtils.getVectorsFromAddresses(program, nonStarts,
		RandomForestFunctionFinderPlugin.NON_START, 1, 1, true, TaskMonitor.DUMMY));
	LabelFactory labelFactory = new LabelFactory();
	ListDataSource<Label> dataSource = new ListDataSource<Label>(examples, labelFactory,
		new SimpleDataSourceProvenance("test", labelFactory));
	MutableDataset<Label> dataset = new MutableDataset<>(dataSource);

	params = new FunctionStartRFParams(program);
	params.setIncludeBitFeatures(true);
	RandomForestTrainingTask task = new RandomForestTrainingTask(program, params, null,
		RandomForestFunctionFinderPlugin.TEST_SET_MAX_SIZE_DEFAULT);

	//ensemble of two models which always report FUNC_START
	DummyClassifierTrainer alwaysStart = DummyClassifierTrainer
			.createConstantTrainer(RandomForestFunctionFinderPlugin.FUNC_START.getLabel());
	List<Model<Label>> startModels = new ArrayList<>();
	startModels.add(alwaysStart.train(dataset));
	startModels.add(alwaysStart.train(dataset));
	WeightedEnsembleModel<Label> startEnsemble = WeightedEnsembleModel
			.createEnsembleFromExistingModels("test", startModels, new VotingCombiner());

	AddressSet startErrors = new AddressSet();
	int[] startConfusion = task.evaluateModel(startEnsemble, starts, nonStarts, startErrors,
		1, 1, TaskMonitor.DUMMY);
	assertEquals(1, startErrors.getNumAddresses());
	assertTrue(startErrors.contains(nonStartAddr));
	assertEquals(1, startConfusion[RandomForestTrainingTask.TP]);
	assertEquals(1, startConfusion[RandomForestTrainingTask.FP]);
	assertEquals(0, startConfusion[RandomForestTrainingTask.TN]);
	assertEquals(0, startConfusion[RandomForestTrainingTask.FN]);

	//ensemble of two models which always report NON_START
	DummyClassifierTrainer alwaysNonStart = DummyClassifierTrainer
			.createConstantTrainer(RandomForestFunctionFinderPlugin.NON_START.getLabel());
	List<Model<Label>> nonStartModels = new ArrayList<>();
	nonStartModels.add(alwaysNonStart.train(dataset));
	nonStartModels.add(alwaysNonStart.train(dataset));
	WeightedEnsembleModel<Label> nonStartEnsemble = WeightedEnsembleModel
			.createEnsembleFromExistingModels("test", nonStartModels, new VotingCombiner());

	AddressSet nonStartErrors = new AddressSet();
	int[] nonStartConfusion = task.evaluateModel(nonStartEnsemble, starts, nonStarts,
		nonStartErrors, 1, 1, TaskMonitor.DUMMY);
	assertEquals(1, nonStartErrors.getNumAddresses());
	assertTrue(nonStartErrors.contains(startAddr));
	assertEquals(0, nonStartConfusion[RandomForestTrainingTask.TP]);
	assertEquals(0, nonStartConfusion[RandomForestTrainingTask.FP]);
	assertEquals(1, nonStartConfusion[RandomForestTrainingTask.TN]);
	assertEquals(1, nonStartConfusion[RandomForestTrainingTask.FN]);
}
/**
 * Trains an ensemble on a trivial two-example data set and verifies that it contains
 * {@code RandomForestTrainingTask.NUM_TREES} models.
 */
@Test
public void testTrainModel() throws CancelledException {
	//just need an address with 1 defined preByte and 1 defined byte
	Address startAddr =
		program.getSymbolTable().getSymbols("entry").next().getAddress().add(1);
	Address nonStartAddr = startAddr.add(1);
	AddressSet starts = new AddressSet();
	starts.add(startAddr);
	AddressSet nonStarts = new AddressSet();
	nonStarts.add(nonStartAddr);

	//one example per label
	List<Example<Label>> examples = new ArrayList<>();
	examples.addAll(ModelTrainingUtils.getVectorsFromAddresses(program, starts,
		RandomForestFunctionFinderPlugin.FUNC_START, 1, 1, true, TaskMonitor.DUMMY));
	examples.addAll(ModelTrainingUtils.getVectorsFromAddresses(program, nonStarts,
		RandomForestFunctionFinderPlugin.NON_START, 1, 1, true, TaskMonitor.DUMMY));

	//don't need any of the fields of RandomForestTrainingTask for this test
	RandomForestTrainingTask task = new RandomForestTrainingTask(program, null, null,
		RandomForestFunctionFinderPlugin.TEST_SET_MAX_SIZE_DEFAULT);
	EnsembleModel<Label> trained = task.trainModel(examples, TaskMonitor.DUMMY);

	assertEquals(RandomForestTrainingTask.NUM_TREES, trained.getNumModels());
}
/**
 * Runs a training task, then asks a {@link SimilarStartsFinder} built from the single
 * resulting row for the 7 function starts most similar to "entry". Only the shape of the
 * answer is verified: its size, that every element is a function start, and that the
 * agreement counts never increase along the list.
 */
@Test
public void testSimilarStartsFinder() {
	List<Integer> fiveOnly = new ArrayList<>();
	fiveOnly.add(5);

	params = new FunctionStartRFParams(program);
	params.setFactors(fiveOnly);
	params.setIncludeBitFeatures(true);
	params.setIncludePrecedingAndFollowing(true);
	params.setInitialBytes(fiveOnly);
	params.setMaxStarts(100);
	params.setMinFuncSize(16);
	params.setPreBytes(fiveOnly);

	//the task reports each trained model as a row through the callback
	List<RandomForestRowObject> rows = new ArrayList<>();
	RandomForestTrainingTask task =
		new RandomForestTrainingTask(program, params, rows::add, 100);
	TaskLauncher.launchModal("test", task);
	assertEquals(1, rows.size());

	SimilarStartsFinder finder = new SimilarStartsFinder(program, rows.get(0));
	Address entryAddr = program.getSymbolTable().getSymbols("entry").next().getAddress();
	List<SimilarStartRowObject> similar = finder.getSimilarFunctionStarts(entryAddr, 7);

	//just verify that the number of elements is correct, each element is a function start,
	//and that the list is in descending order.
	assertEquals(7, similar.size());
	for (SimilarStartRowObject row : similar) {
		assertNotNull(program.getFunctionManager().getFunctionAt(row.funcStart()));
	}
	int previous = similar.get(0).numAgreements();
	for (SimilarStartRowObject row : similar.subList(1, 7)) {
		int current = row.numAgreements();
		assertTrue(previous >= current);
		previous = current;
	}
}
/**
 * Checks how {@code getTrainingAndTestData} splits known sets of addresses when the
 * maximum number of starts is 5 and the last numeric argument is 2.
 * <p>
 * The inputs are three adjacent, non-overlapping runs beginning at the "entry" symbol:
 * 10 function entries, 15 function interiors and 5 defined-data addresses.
 *
 * @throws CancelledException declared by the calls under test; a dummy monitor is used
 */
@Test
public void getTrainingAndTestDataBasicTest() throws CancelledException {
    params = new FunctionStartRFParams(program);
    params.setMaxStarts(5);

    Address begin = program.getSymbolTable().getSymbols("entry").next().getAddress();

    //offsets 0..9 are entries, 10..24 are interiors, 25..29 are defined data
    AddressSet entries = new AddressSet();
    entries.add(begin, begin.add(9));
    AddressSet interiors = new AddressSet();
    interiors.add(begin.add(10), begin.add(24));
    AddressSet definedData = new AddressSet();
    definedData.add(begin.add(25), begin.add(29));

    RandomForestTrainingTask trainingTask = new RandomForestTrainingTask(program, params,
        null, RandomForestFunctionFinderPlugin.TEST_SET_MAX_SIZE_DEFAULT);
    TrainingAndTestData split = trainingTask.getTrainingAndTestData(entries, interiors,
        definedData, 2, TaskMonitor.DUMMY);

    //sizes
    //5 function starts chosen from 10 possible
    assertEquals(5, split.getTrainingPositive().getNumAddresses());
    //5*2 interiors chosen from 15 possible
    assertEquals(10, split.getTrainingNegative().getNumAddresses());
    //5 function starts were not chosen
    assertEquals(5, split.getTestPositive().getNumAddresses());
    //5 interiors were not chosen + 5 defined data
    assertEquals(10, split.getTestNegative().getNumAddresses());

    //disjointness: test vs. training, then positive vs. negative within each
    AddressSet allTest = split.getTestPositive().union(split.getTestNegative());
    AddressSet allTraining = split.getTrainingNegative().union(split.getTrainingPositive());
    assertTrue(allTest.intersect(allTraining).isEmpty());
    assertTrue(split.getTestPositive().intersect(split.getTestNegative()).isEmpty());
    assertTrue(split.getTrainingPositive().intersect(split.getTrainingNegative()).isEmpty());

    //provenance: positives come from the entries, negatives from interiors/defined data
    assertTrue(entries.contains(split.getTestPositive()));
    assertTrue(entries.contains(split.getTrainingPositive()));
    assertTrue(interiors.contains(split.getTrainingNegative()));
    assertTrue(interiors.contains(split.getTestNegative().subtract(definedData)));
    assertTrue(split.getTestNegative().contains(definedData));
}
//FUN_010059a3 is a genuine function; the deluxe test below looks it up by its entry point 0x10059a3
//
// 100641f 00
// 1006420 55 PUSH EBP <- entry
// 1006421 8b ec MOV EBP, ESP
//
// 1006423 6a ff PUSH -0x1
// 1006425 68 88 18 00 01 PUSH DAT_01001888
// 100642a 68 d0 65 00 01 PUSH DAT_010065d0
//
// 100642f 64 a1 00 00 00 00 MOV EAX, FS:[0x0]
// 1006435 50 PUSH EAX
// 1006436 64 89 25 00 00 00 00 MOV dword ptr FS:[0x0],ESP
//
/**
 * Checks {@code getTrainingAndTestData} with the maximum number of starts set to 2,
 * "include preceding and following" enabled, and the body of the function at
 * 0x10059a3 supplied as an additional selection via {@code setAdditional}.
 * <p>
 * Three candidate entries (0x1006420, 0x1006425 and 0x1006435) are offered. Whichever
 * two are selected, the addresses checked as their preceding and following neighbours
 * are expected to show up as training negatives.
 *
 * @throws CancelledException declared by the calls under test; a dummy monitor is used
 */
@Test
public void getTrainingAndTestDataDeluxeTest() throws CancelledException {
    params = new FunctionStartRFParams(program);
    params.setMaxStarts(2);
    params.setIncludePrecedingAndFollowing(true);

    Address begin = program.getSymbolTable().getSymbols("entry").next().getAddress();

    //create 3 starts, spaced out because we want to include the previous and following
    AddressSet entries = new AddressSet();
    entries.add(begin);
    entries.add(program.getAddressFactory().getDefaultAddressSpace().getAddress(0x1006425));
    entries.add(program.getAddressFactory().getDefaultAddressSpace().getAddress(0x1006435));

    //offsets 0x20..0x24 are interiors, offsets 0x30..0x34 are defined data
    AddressSet interiors = new AddressSet();
    interiors.add(begin.add(0x20), begin.add(0x24));
    AddressSet definedData = new AddressSet();
    definedData.add(begin.add(0x30), begin.add(0x34));

    RandomForestTrainingTask trainingTask = new RandomForestTrainingTask(program, params,
        null, RandomForestFunctionFinderPlugin.TEST_SET_MAX_SIZE_DEFAULT);

    //explicitly add the whole body of the function at 0x10059a3
    Address otherFuncEntry =
        program.getAddressFactory().getDefaultAddressSpace().getAddress(0x10059a3);
    Function func = program.getFunctionManager().getFunctionAt(otherFuncEntry);
    trainingTask.setAdditional(
        new ProgramSelection(func.getBody().getMinAddress(), func.getBody().getMaxAddress()));

    TrainingAndTestData split = trainingTask.getTrainingAndTestData(entries, interiors,
        definedData, 2, TaskMonitor.DUMMY);

    //2 function starts chosen from 3 possible
    //plus the entry of the function at 0x10059a3 which was added explicitly
    assertEquals(3, split.getTrainingPositive().getNumAddresses());
    assertTrue(split.getTrainingPositive().contains(otherFuncEntry));

    //2*2 interiors chosen from 5 possible
    //2*2 preceding and following for the two functions selected at random
    //plus size of interior of function at 0x10059a3 (its body minus the entry)
    long expectedTrainingNegatives = 4 + 4 + func.getBody().getNumAddresses() - 1;
    assertEquals(expectedTrainingNegatives, split.getTrainingNegative().getNumAddresses());

    //1 function start was not chosen
    assertEquals(1, split.getTestPositive().getNumAddresses());
    //1 interior was not chosen + (preceding and following for unchosen entry) + 5 defined data
    assertEquals(8, split.getTestNegative().getNumAddresses());

    //disjointness: test vs. training, then positive vs. negative within each
    AddressSet allTest = split.getTestPositive().union(split.getTestNegative());
    AddressSet allTraining = split.getTrainingNegative().union(split.getTrainingPositive());
    assertTrue(allTest.intersect(allTraining).isEmpty());
    assertTrue(split.getTestPositive().intersect(split.getTestNegative()).isEmpty());
    assertTrue(split.getTrainingPositive().intersect(split.getTrainingNegative()).isEmpty());

    //the training positives contain the explicitly added entry, so they are only
    //covered once that entry is added to the candidate entries
    assertTrue(entries.contains(split.getTestPositive()));
    assertFalse(entries.contains(split.getTrainingPositive()));
    AddressSet deluxeEntries = entries.union(new AddressSet(otherFuncEntry));
    assertTrue(deluxeEntries.contains(split.getTrainingPositive()));

    //for each candidate entry: distance back to the preceding address and forward to
    //the following address that must appear as training negatives if it was selected
    Address[] candidates = { begin, begin.add(5L), begin.add(0x15L) };
    long[] backToPreceding = { 1L, 2L, 6L };
    long[] forwardToFollowing = { 1L, 5L, 1L };
    int numContained = 0;
    for (int k = 0; k < candidates.length; k++) {
        if (!split.getTrainingPositive().contains(candidates[k])) {
            continue;
        }
        numContained++;
        assertTrue(split.getTrainingNegative()
                .contains(candidates[k].subtract(backToPreceding[k])));
        assertTrue(split.getTrainingNegative()
                .contains(candidates[k].add(forwardToFollowing[k])));
    }
    assertEquals(2, numContained);

    //whatever interiors were not used for training must be test negatives,
    //and so must all of the defined data
    assertTrue(
        split.getTestNegative().contains(interiors.subtract(split.getTrainingNegative())));
    assertTrue(split.getTestNegative().contains(definedData));
}
}

View file

@ -0,0 +1,95 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import static org.junit.Assert.*;
import java.util.List;
import org.junit.Test;
import generic.test.AbstractGenericTest;
/**
 * Tests for {@code FunctionStartRFParams.parseIntegerCSV}.
 * <p>
 * The {@code testBadParseN} methods each feed one malformed string and declare the
 * exception type that must be raised. {@link #testBasicValidParses()} covers the
 * accepted formats.
 */
public class FunctionStartRFParamsTest extends AbstractGenericTest {

    /** The empty string must be rejected. */
    @Test(expected = IllegalArgumentException.class)
    public void testBadParse1() {
        FunctionStartRFParams.parseIntegerCSV("");
    }

    /** A string consisting only of a space must be rejected. */
    @Test(expected = IllegalArgumentException.class)
    public void testBadParse2() {
        FunctionStartRFParams.parseIntegerCSV(" ");
    }

    /** An empty element between two commas must raise a NumberFormatException. */
    @Test(expected = NumberFormatException.class)
    public void testBadParse3() {
        FunctionStartRFParams.parseIntegerCSV("1,,2");
    }

    /** A negative value must be rejected. */
    @Test(expected = IllegalArgumentException.class)
    public void testBadParse4() {
        FunctionStartRFParams.parseIntegerCSV("-1");
    }

    /** A doubled minus sign must raise a NumberFormatException. */
    @Test(expected = NumberFormatException.class)
    public void testBadParse5() {
        FunctionStartRFParams.parseIntegerCSV("--1");
    }

    /** A lone comma must be rejected. */
    @Test(expected = IllegalArgumentException.class)
    public void testBadParse6() {
        FunctionStartRFParams.parseIntegerCSV(",");
    }

    /** A trailing comma after a single value must be rejected. */
    @Test(expected = IllegalArgumentException.class)
    public void testBadParse7() {
        FunctionStartRFParams.parseIntegerCSV("1,");
    }

    /** A trailing comma after several values must be rejected. */
    @Test(expected = IllegalArgumentException.class)
    public void testBadParse8() {
        FunctionStartRFParams.parseIntegerCSV("1,2,3,");
    }

    /** Both a leading and a trailing comma must be rejected. */
    @Test(expected = IllegalArgumentException.class)
    public void testBadParse9() {
        FunctionStartRFParams.parseIntegerCSV(",1,2,3,");
    }

    /** A hex-prefixed element with a non-hex digit must raise a NumberFormatException. */
    @Test(expected = NumberFormatException.class)
    public void testBadParse10() {
        FunctionStartRFParams.parseIntegerCSV("1,0xabcdv,3");
    }

    /**
     * Covers the accepted inputs: a single decimal value, a mix of hex-prefixed and
     * decimal values with surrounding spaces, and a list containing duplicates, for
     * which the expected result holds each value once, in ascending order.
     */
    @Test
    public void testBasicValidParses() {
        assertParsesTo("12345678", 12345678);
        assertParsesTo("0x1,2 , 0x3, 4", 1, 2, 3, 4);
        assertParsesTo("4,3,4,3", 3, 4);
    }

    /**
     * Parses {@code csv} and asserts that the result has exactly the expected values
     * in the expected order.
     *
     * @param csv the string to hand to the parser
     * @param expected the values the parsed list must contain, in order
     */
    private static void assertParsesTo(String csv, Integer... expected) {
        List<Integer> parsed = FunctionStartRFParams.parseIntegerCSV(csv);
        assertEquals(expected.length, parsed.size());
        for (int i = 0; i < expected.length; i++) {
            assertEquals(expected[i], parsed.get(i));
        }
    }
}

View file

@ -0,0 +1,94 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import static org.junit.Assert.*;
import java.util.*;
import org.junit.Test;
import generic.test.AbstractGenericTest;
import ghidra.program.model.address.AddressSet;
import ghidra.program.model.address.TestAddress;
import ghidra.util.exception.CancelledException;
import ghidra.util.task.TaskMonitor;
/**
 * Tests for the static helpers of {@code RandomSubsetUtils}: generating a random
 * subset of integers, choosing a random subset of an {@link AddressSet}, and the
 * map-based {@code swap} helper.
 */
public class RandomSubsetTest extends AbstractGenericTest {

    /**
     * Exercises the degenerate requests: asking for zero elements must give an empty
     * list, and asking for all elements must give every value from 0 up to one less
     * than the requested size exactly once.
     */
    @Test
    public void testGenerateTrivialSubsets() {
        //zero elements requested, from a non-empty and from an empty universe
        assertEquals(0, RandomSubsetUtils.generateRandomIntegerSubset(10, 0).size());
        assertEquals(0, RandomSubsetUtils.generateRandomIntegerSubset(0, 0).size());

        //everything requested: after sorting, element k must equal k
        List<Long> complete = RandomSubsetUtils.generateRandomIntegerSubset(1000000, 1000000);
        Collections.sort(complete);
        long expected = 0;
        for (long elem : complete) {
            assertEquals(expected, elem);
            expected++;
        }
    }

    /**
     * Requests 9998 addresses out of a set of 10000 and checks only the size of the
     * result.
     *
     * @throws CancelledException declared by the call under test; a dummy monitor is used
     */
    @Test
    public void testBasicRandomSubsetOfAddresses() throws CancelledException {
        final int requested = 9998;
        AddressSet universe = new AddressSet();
        for (long offset = 0; offset < 10000; offset++) {
            universe.add(new TestAddress(offset));
        }
        AddressSet chosen = RandomSubsetUtils.randomSubset(universe, requested, TaskMonitor.DUMMY);
        assertEquals(requested, chosen.getNumAddresses());
    }

    /**
     * Checks {@code swap} in three situations: swapping a key with itself on an empty
     * map, swapping two keys that are both present, and swapping two keys that are both
     * absent.
     */
    @Test
    public void testSwap() {
        Map<Long, Long> mapping = new HashMap<>();
        assertTrue(mapping.isEmpty());

        //should do nothing
        RandomSubsetUtils.swap(mapping, 1, 1);
        assertTrue(mapping.isEmpty());

        //both keys present: their values trade places
        mapping.put(0L, 5L);
        mapping.put(1L, 10L);
        RandomSubsetUtils.swap(mapping, 0, 1);
        assertEquals(2, mapping.size());
        assertEquals(Long.valueOf(5), mapping.get(1L));
        assertEquals(Long.valueOf(10), mapping.get(0L));

        //both keys absent: two entries are added, each key mapped to the other key
        RandomSubsetUtils.swap(mapping, 100L, 200L);
        assertEquals(4, mapping.size());
        assertEquals(Long.valueOf(100), mapping.get(200L));
        assertEquals(Long.valueOf(200), mapping.get(100L));
    }

    /*
     * Manual timing benchmark, left disabled.
     * NOTE(review): it refers to RandomSubset, whereas the live tests in this class
     * call RandomSubsetUtils; the name needs adjusting before this is re-enabled.
     *
    @Test
    public void timingTest() throws CancelledException {
    AddressSet big = new AddressSet(new TestAddress(0), new TestAddress(999999));
    long start = System.nanoTime();
    List<Long> complete = RandomSubset.generateRandomIntegerSubset(1000000, 500000);
    long end = System.nanoTime();
    Msg.info(this, "choosing random subset of integers: " +
    (end - start) / RandomForestTrainingTask.NANOSECONDS_PER_SECOND);
    start = System.nanoTime();
    AddressSet random = RandomSubset.randomSubset(big, 500000, TaskMonitor.DUMMY);
    end = System.nanoTime();
    Msg.info(this, "choosing random subset of addresses: " +
    (end - start) / RandomForestTrainingTask.NANOSECONDS_PER_SECOND);
    }*/
}

View file

@ -0,0 +1,68 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ghidra.machinelearning.functionfinding;
import static org.junit.Assert.*;
import org.junit.Before;
import org.junit.Test;
import generic.test.AbstractGenericTest;
import ghidra.program.model.address.AddressSet;
import ghidra.program.model.address.TestAddress;
import ghidra.util.exception.CancelledException;
import ghidra.util.task.TaskMonitor;
/**
 * Tests for {@code TrainingAndTestData.reduceTestSetSize}.
 * <p>
 * Every test starts from empty training sets, a positive test set covering offsets
 * 0 through 100 (101 addresses) and a negative test set covering offsets 500 through
 * 1000 (501 addresses).
 */
public class TrainingAndTestDataTest extends AbstractGenericTest {
    private TrainingAndTestData data;
    //copies of the test sets as they were before any reduction
    private AddressSet originalPos;
    private AddressSet originalNeg;

    /**
     * Builds the fixture. The sets handed to {@link TrainingAndTestData} are kept
     * separate from the reference copies so the copies stay untouched.
     */
    @Before
    public void setUp() {
        AddressSet positives = rangeOf(0, 100);
        AddressSet negatives = rangeOf(500, 1000);
        originalPos = new AddressSet(positives);
        originalNeg = new AddressSet(negatives);
        data = new TrainingAndTestData(new AddressSet(), new AddressSet(), positives,
            negatives);
    }

    /**
     * Creates a set holding the inclusive range of test addresses from {@code first}
     * to {@code last}.
     *
     * @param first offset of the first address in the range
     * @param last offset of the last address in the range
     * @return a new set containing exactly that range
     */
    private static AddressSet rangeOf(int first, int last) {
        AddressSet range = new AddressSet();
        range.add(new TestAddress(first), new TestAddress(last));
        return range;
    }

    /**
     * A limit larger than both sets must leave both unchanged.
     *
     * @throws CancelledException declared by the call under test; a dummy monitor is used
     */
    @Test
    public void reduceTest1() throws CancelledException {
        data.reduceTestSetSize(1001, TaskMonitor.DUMMY);
        assertTrue(data.getTestPositive().hasSameAddresses(originalPos));
        assertTrue(data.getTestNegative().hasSameAddresses(originalNeg));
    }

    /**
     * A limit between the two sizes must leave the smaller (positive) set unchanged and
     * cut the larger (negative) set down to the limit.
     *
     * @throws CancelledException declared by the call under test; a dummy monitor is used
     */
    @Test
    public void reduceTest2() throws CancelledException {
        final int limit = 250;
        data.reduceTestSetSize(limit, TaskMonitor.DUMMY);
        assertTrue(data.getTestPositive().hasSameAddresses(originalPos));
        assertEquals(limit, data.getTestNegative().getNumAddresses());
    }

    /**
     * A limit smaller than both sets must cut each of them down to the limit.
     *
     * @throws CancelledException declared by the call under test; a dummy monitor is used
     */
    @Test
    public void reduceTest3() throws CancelledException {
        final int limit = 10;
        data.reduceTestSetSize(limit, TaskMonitor.DUMMY);
        assertEquals(limit, data.getTestPositive().getNumAddresses());
        assertEquals(limit, data.getTestNegative().getNumAddresses());
    }
}