mirror of
https://github.com/NationalSecurityAgency/ghidra
synced 2024-09-13 21:56:19 +00:00
Merge remote-tracking branch 'origin/GP-2204_MachineLearning_extensions--SQUASHED'
This commit is contained in:
commit
926286ee6f
10
Ghidra/Extensions/MachineLearning/Module.manifest
Normal file
10
Ghidra/Extensions/MachineLearning/Module.manifest
Normal file
|
@ -0,0 +1,10 @@
|
|||
MODULE FILE LICENSE: lib/olcut-config-protobuf-5.2.0.jar BSD-2-ORACLE
|
||||
MODULE FILE LICENSE: lib/olcut-core-5.2.0.jar BSD-2-ORACLE
|
||||
MODULE FILE LICENSE: lib/protobuf-java-3.17.3.jar BSD-3-GOOGLE
|
||||
MODULE FILE LICENSE: lib/tribuo-classification-core-4.2.0.jar Apache License 2.0
|
||||
MODULE FILE LICENSE: lib/tribuo-classification-tree-4.2.0.jar Apache License 2.0
|
||||
MODULE FILE LICENSE: lib/tribuo-common-tree-4.2.0.jar Apache License 2.0
|
||||
MODULE FILE LICENSE: lib/tribuo-core-4.2.0.jar Apache License 2.0
|
||||
MODULE FILE LICENSE: lib/tribuo-data-4.2.0.jar Apache License 2.0
|
||||
MODULE FILE LICENSE: lib/tribuo-math-4.2.0.jar Apache License 2.0
|
||||
MODULE FILE LICENSE: lib/tribuo-util-onnx-4.2.0.jar Apache License 2.0
|
46
Ghidra/Extensions/MachineLearning/build.gradle
Normal file
46
Ghidra/Extensions/MachineLearning/build.gradle
Normal file
|
@ -0,0 +1,46 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
apply from: "$rootProject.projectDir/gradle/distributableGhidraExtension.gradle"
|
||||
apply from: "$rootProject.projectDir/gradle/javaProject.gradle"
|
||||
apply from: "$rootProject.projectDir/gradle/javaTestProject.gradle"
|
||||
apply from: "$rootProject.projectDir/gradle/helpProject.gradle"
|
||||
|
||||
apply plugin: 'eclipse'
|
||||
eclipse.project.name = 'Xtra MachineLearning'
|
||||
|
||||
dependencies {
|
||||
api project(':Base')
|
||||
helpPath project(path: ":Base", configuration: 'helpPath')
|
||||
|
||||
api "com.oracle.labs.olcut:olcut-config-protobuf:5.2.0" //{exclude group: "com.google.protobuf", module: "protobuf-java"}
|
||||
api ("com.oracle.labs.olcut:olcut-core:5.2.0") {exclude group: "org.jline"}
|
||||
api "com.google.protobuf:protobuf-java:3.17.3" //only needed for running junits
|
||||
api "org.tribuo:tribuo-classification-core:4.2.0"
|
||||
api "org.tribuo:tribuo-classification-tree:4.2.0"
|
||||
api "org.tribuo:tribuo-common-tree:4.2.0"
|
||||
api 'org.tribuo:tribuo-core:4.2.0'
|
||||
api ("org.tribuo:tribuo-data:4.2.0") {exclude group: "com.opencsv"}
|
||||
api "org.tribuo:tribuo-math:4.2.0"
|
||||
api ("org.tribuo:tribuo-util-onnx:4.2.0") //{exclude group: "com.google.protobuf", module: "protobuf-java"}
|
||||
|
||||
testImplementation project(path: ':SoftwareModeling', configuration: 'testArtifacts')
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
10
Ghidra/Extensions/MachineLearning/certification.manifest
Normal file
10
Ghidra/Extensions/MachineLearning/certification.manifest
Normal file
|
@ -0,0 +1,10 @@
|
|||
##VERSION: 2.0
|
||||
##MODULE IP: Apache License 2.0
|
||||
##MODULE IP: BSD-2-ORACLE
|
||||
##MODULE IP: BSD-3-GOOGLE
|
||||
Module.manifest||GHIDRA||||END|
|
||||
extension.properties||GHIDRA||||END|
|
||||
lib/README.txt||GHIDRA||||END|
|
||||
src/main/help/help/TOC_Source.xml||GHIDRA||||END|
|
||||
src/main/help/help/topics/RandomForestFunctionFinderPlugin/RandomForestFunctionFinderPlugin.htm||GHIDRA||||END|
|
||||
src/main/resources/images/README.txt||GHIDRA||||END|
|
|
@ -0,0 +1,58 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
//Writes a list of the addresses of all call sites to a file.
|
||||
//@category machineLearning
|
||||
|
||||
import java.io.*;
|
||||
|
||||
import ghidra.app.script.GhidraScript;
|
||||
import ghidra.program.model.listing.*;
|
||||
import ghidra.program.model.pcode.PcodeOp;
|
||||
|
||||
public class DumpCalls extends GhidraScript {
|
||||
|
||||
private static final String DATA_DIR = "/local/calls";
|
||||
|
||||
@Override
|
||||
protected void run() throws Exception {
|
||||
File outFile = new File(DATA_DIR + File.separator + currentProgram.getName() + "_calls");
|
||||
FileWriter fWriter = new FileWriter(outFile);
|
||||
BufferedWriter bWriter = new BufferedWriter(fWriter);
|
||||
InstructionIterator fIter = currentProgram.getListing().getInstructions(true);
|
||||
int numCalls = 0;
|
||||
int numInstructions = 0;
|
||||
while (fIter.hasNext()) {
|
||||
Instruction inst = fIter.next();
|
||||
if (inst.getPcode() == null || inst.getPcode().length == 0) {
|
||||
continue;
|
||||
}
|
||||
numInstructions++;
|
||||
for (int i = 0; i < inst.getPcode().length; i++) {
|
||||
PcodeOp pCode = inst.getPcode()[i];
|
||||
int opCode = pCode.getOpcode();
|
||||
if (opCode == PcodeOp.CALL || opCode == PcodeOp.CALLIND) {
|
||||
//printf("Inst: %s at %s\n", inst.toString(), inst.getAddress());
|
||||
numCalls++;
|
||||
bWriter.write(inst.getAddress().toString() + "\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
printf("total num calls: %d\n", numCalls);
|
||||
printf("total num instructions: %d\n", numInstructions);
|
||||
bWriter.close();
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
//Writes a list of the addresses of all function starts and their sizes to a file
|
||||
//@category machineLearning
|
||||
|
||||
import java.io.*;
|
||||
|
||||
import ghidra.app.script.GhidraScript;
|
||||
import ghidra.program.model.listing.Function;
|
||||
import ghidra.program.model.listing.FunctionIterator;
|
||||
|
||||
public class DumpFunctionStarts extends GhidraScript {
|
||||
|
||||
private static final String DATA_DIR = "/local/funcstarts/stripped";
|
||||
|
||||
@Override
|
||||
protected void run() throws Exception {
|
||||
File outFile =
|
||||
new File(DATA_DIR + File.separator + currentProgram.getName() + "_stripped_funcs");
|
||||
FileWriter fWriter = new FileWriter(outFile);
|
||||
BufferedWriter bWriter = new BufferedWriter(fWriter);
|
||||
FunctionIterator fIter = currentProgram.getFunctionManager().getFunctions(true);
|
||||
while (fIter.hasNext()) {
|
||||
Function func = fIter.next();
|
||||
if (func.isExternal()) {
|
||||
continue;
|
||||
}
|
||||
long size = func.getBody().getNumAddresses();
|
||||
bWriter.write(func.getEntryPoint().toString() + "," + size + "\n");
|
||||
}
|
||||
bWriter.close();
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,74 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Paths;
|
||||
|
||||
import org.tribuo.*;
|
||||
import org.tribuo.classification.Label;
|
||||
import org.tribuo.classification.LabelFactory;
|
||||
import org.tribuo.classification.dtree.CARTClassificationTrainer;
|
||||
import org.tribuo.classification.ensemble.VotingCombiner;
|
||||
import org.tribuo.classification.evaluation.LabelEvaluation;
|
||||
import org.tribuo.classification.evaluation.LabelEvaluator;
|
||||
import org.tribuo.common.tree.RandomForestTrainer;
|
||||
import org.tribuo.data.csv.CSVLoader;
|
||||
import org.tribuo.ensemble.EnsembleModel;
|
||||
import org.tribuo.evaluation.TrainTestSplitter;
|
||||
|
||||
public class ExampleTribuoRunner {
|
||||
|
||||
public static void main(String args[]) throws IOException {
|
||||
|
||||
var irisHeaders =
|
||||
new String[] { "sepalLength", "sepalWidth", "petalLength", "petalWidth", "species" };
|
||||
DataSource<Label> irisData = new CSVLoader<>(new LabelFactory()).loadDataSource(
|
||||
Paths.get("/home/jmworth/ml/bezdekIris.data"), /* Output column */ irisHeaders[4],
|
||||
/* Column headers */ irisHeaders);
|
||||
|
||||
// Split iris data into training set (70%) and test set (30%)
|
||||
var splitIrisData =
|
||||
new TrainTestSplitter<>(irisData, /* Train fraction */ 0.7, /* RNG seed */ 1L);
|
||||
var trainData = new MutableDataset<>(splitIrisData.getTrain());
|
||||
var testData = new MutableDataset<>(splitIrisData.getTest());
|
||||
|
||||
// We can train a decision tree
|
||||
var cartTrainer = new CARTClassificationTrainer(100, (float) 0.2, 0);
|
||||
|
||||
var decisionTree = cartTrainer.train(trainData);
|
||||
|
||||
//Model<Label> tree = cartTrainer.train(trainData);
|
||||
|
||||
var trainer = new RandomForestTrainer<>(cartTrainer, // trainer - the tree trainer
|
||||
new VotingCombiner(), // combiner - the combining function for the ensemble
|
||||
10 // numMembers - the number of ensemble members to train
|
||||
);
|
||||
|
||||
EnsembleModel<Label> tree = trainer.train(trainData);
|
||||
|
||||
// Finally we make predictions on unseen data
|
||||
// Each prediction is a map from the output names (i.e. the labels) to the scores/probabilities
|
||||
Prediction<Label> prediction = tree.predict(testData.getExample(0));
|
||||
|
||||
// Or we can evaluate the full test dataset, calculating the accuracy, F1 etc.
|
||||
LabelEvaluation evaluation = new LabelEvaluator().evaluate(tree, testData);
|
||||
// we can inspect the evaluation manually
|
||||
double acc = evaluation.accuracy();
|
||||
// which returns 0.978
|
||||
// or print a formatted evaluation string
|
||||
System.out.println(evaluation.toString());
|
||||
}
|
||||
|
||||
}
|
5
Ghidra/Extensions/MachineLearning/extension.properties
Normal file
5
Ghidra/Extensions/MachineLearning/extension.properties
Normal file
|
@ -0,0 +1,5 @@
|
|||
name=MachineLearning
|
||||
description=Finds functions using ML
|
||||
author=Ghidra Team
|
||||
createdOn=9/25/2022
|
||||
version=@extversion@
|
|
@ -0,0 +1,200 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
//Example script for training random forests to find function starts
|
||||
//@category machineLearning
|
||||
|
||||
import java.util.*;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import ghidra.app.cmd.disassemble.DisassembleCommand;
|
||||
import ghidra.app.cmd.function.CreateFunctionCmd;
|
||||
import ghidra.app.script.GhidraScript;
|
||||
import ghidra.machinelearning.functionfinding.*;
|
||||
import ghidra.program.model.address.Address;
|
||||
import ghidra.program.model.address.AddressSet;
|
||||
import ghidra.program.model.block.BasicBlockModel;
|
||||
|
||||
//NOTE: This script is referenced by name in the help for the
|
||||
//RandomForestFunctionFinderPlugin. If you change the name be
|
||||
//sure to update the help.
|
||||
public class FindFunctionsRFExampleScript extends GhidraScript {
|
||||
|
||||
@Override
|
||||
protected void run() throws Exception {
|
||||
|
||||
//get the parameters controlling how many models are trained and
|
||||
//what data is used to train/test them
|
||||
FunctionStartRFParams params = new FunctionStartRFParams(currentProgram);
|
||||
|
||||
//maximum number of function starts to use in the training set
|
||||
//a warning will be issued if there are no functions left over for the test set
|
||||
params.setMaxStarts(1000);
|
||||
|
||||
//minimum size of a function to be included in the training/test sets
|
||||
params.setMinFuncSize(16);
|
||||
|
||||
//number of bytes before a function start
|
||||
params.setPreBytes(Arrays.asList(new Integer[] { 2, 8 }));
|
||||
|
||||
//number of bytes after (and including) a function start
|
||||
params.setInitialBytes(Arrays.asList(new Integer[] { 8, 16 }));
|
||||
|
||||
//number of non-starts to sample for each start in the training set
|
||||
params.setFactors(Arrays.asList(new Integer[] { 10, 50 }));
|
||||
|
||||
//for every function start in the (test,training) set, add the code
|
||||
//units immediately before and immediately after to the (test,training) set
|
||||
//as non-starts.
|
||||
params.setIncludePrecedingAndFollowing(true);
|
||||
|
||||
//uncomment to include features for each bit rather than just byte-level features
|
||||
//params.setIncludeBitFeatures(true);
|
||||
|
||||
//bound for reducing the size of the test sets
|
||||
long testSetMax = 1000000l;
|
||||
|
||||
//this is where the trained models will go
|
||||
List<RandomForestRowObject> trainedModels = new ArrayList<>();
|
||||
|
||||
RandomForestTrainingTask trainingTask = new RandomForestTrainingTask(currentProgram, params,
|
||||
r -> trainedModels.add(r), testSetMax);
|
||||
|
||||
//launch the task to train the models in parallel
|
||||
trainingTask.run(monitor);
|
||||
|
||||
//sort the models by the number of false positives (ascending)
|
||||
//if you actually need *unsigned* comparison it's likely that something has gone
|
||||
//horribly wrong
|
||||
Collections.sort(trainedModels,
|
||||
(x, y) -> Integer.compareUnsigned(x.getNumFalsePositives(), y.getNumFalsePositives()));
|
||||
|
||||
//grab the model with the fewest false positives
|
||||
//note: there could be ties; could sort the winners by recall
|
||||
RandomForestRowObject best = trainedModels.get(0);
|
||||
|
||||
printf(
|
||||
"Best model: pre-bytes: %d, initialBytes: %d, sampling factor: %d, false positives: %d," +
|
||||
" precision: %g, recall: %g\n",
|
||||
best.getNumPreBytes(), best.getNumInitialBytes(), best.getSamplingFactor(),
|
||||
best.getNumFalsePositives(), best.getPrecision(), best.getRecall());
|
||||
|
||||
//to get more information about the test set errors, apply the model to the error set
|
||||
|
||||
FunctionStartClassifier classifier = new FunctionStartClassifier(currentProgram, best,
|
||||
RandomForestFunctionFinderPlugin.FUNC_START);
|
||||
//uncomment to see false negatives as well
|
||||
//classifier.setProbabilityThreshold(0.0);
|
||||
|
||||
Map<Address, Double> errors = classifier.classify(best.getTestErrors(), monitor);
|
||||
|
||||
List<Entry<Address, Double>> falsePositives = errors.entrySet()
|
||||
.stream()
|
||||
.filter(x -> currentProgram.getFunctionManager().getFunctionAt(x.getKey()) == null)
|
||||
.sorted((x, y) -> Double.compare(y.getValue(), x.getValue()))
|
||||
.toList();
|
||||
|
||||
//print out addresses of false positives
|
||||
printf("False positives:\n");
|
||||
falsePositives.forEach(x -> printf(" %s %g\n", x.getKey().toString(), x.getValue()));
|
||||
|
||||
//show the true function starts most similar to one of the false positives
|
||||
if (!falsePositives.isEmpty()) {
|
||||
SimilarStartsFinder finder = new SimilarStartsFinder(currentProgram, best);
|
||||
List<SimilarStartRowObject> neighbors =
|
||||
finder.getSimilarFunctionStarts(falsePositives.get(0).getKey(), 10);
|
||||
printf("\nClosest function starts to false positive at %s :\n",
|
||||
falsePositives.get(0).getKey());
|
||||
neighbors.forEach(n -> printf(" %s %d\n", n.funcStart(), n.numAgreements()));
|
||||
}
|
||||
|
||||
//grab the set of bytes in executable memory which are undefined (i.e., initialized but
|
||||
//not yet assigned to be code or data) or instructions which are not assigned to a function
|
||||
//body
|
||||
//don't bother looking in small undefined ranges
|
||||
long minUndefinedRange = 16;
|
||||
GetAddressesToClassifyTask getAddressTask =
|
||||
new GetAddressesToClassifyTask(currentProgram, minUndefinedRange);
|
||||
getAddressTask.run(monitor);
|
||||
|
||||
AddressSet toClassify = getAddressTask.getAddressesToClassify();
|
||||
|
||||
Map<Address, Double> potentialStarts = classifier.classify(toClassify, monitor);
|
||||
|
||||
//grab all of the addresses with probability of being a function start >= .7
|
||||
//and disassemble any that are currently undefined
|
||||
|
||||
//block model is needed to get the interpretation of an address
|
||||
//e.g., undefined, block start, within block,...
|
||||
BasicBlockModel blockModel = new BasicBlockModel(currentProgram);
|
||||
|
||||
//@formatter:off
|
||||
List<Address> addresses = potentialStarts.entrySet()
|
||||
.stream()
|
||||
.filter(x -> {return x.getValue() >= 0.7d;})
|
||||
.map(x -> x.getKey())
|
||||
.collect(Collectors.toList());
|
||||
//@formatter:on
|
||||
|
||||
AddressSet toDisassemble = new AddressSet();
|
||||
for (Address addr : addresses) {
|
||||
Interpretation inter =
|
||||
Interpretation.getInterpretation(currentProgram, addr, blockModel, monitor);
|
||||
if (inter.equals(Interpretation.UNDEFINED)) {
|
||||
toDisassemble.add(addr);
|
||||
}
|
||||
}
|
||||
|
||||
//see DisassemblyAndApplyContextAction#actionPerfomed if you need to
|
||||
//apply context register values
|
||||
printf("Found %d addresses to disassemble\n", toDisassemble.getNumAddresses());
|
||||
DisassembleCommand cmd = new DisassembleCommand(toDisassemble, null, true);
|
||||
cmd.applyTo(currentProgram);
|
||||
|
||||
//create functions at any address with probability of being a function start > .8,
|
||||
//where an instruction exists,
|
||||
//which is defined as a BLOCK_START (so not already a function start)
|
||||
//and with no conditional flow references to it
|
||||
//FunctionStartRowObject represents a row in a table in the gui, but you don't
|
||||
//actually need the gui and it's a convenient container for information about an address
|
||||
|
||||
//@formatter:off
|
||||
List<FunctionStartRowObject> funcRows = potentialStarts.entrySet()
|
||||
.stream()
|
||||
.filter(x -> {return x.getValue() >= 0.8d;})
|
||||
.map(x -> new FunctionStartRowObject(x.getKey(),x.getValue()))
|
||||
.collect(Collectors.toList());
|
||||
//@formatter:on
|
||||
|
||||
for (FunctionStartRowObject funcRow : funcRows) {
|
||||
funcRow.setCurrentInterpretation(Interpretation.getInterpretation(currentProgram,
|
||||
funcRow.getAddress(), blockModel, monitor));
|
||||
FunctionStartRowObject.setReferenceData(funcRow, currentProgram);
|
||||
}
|
||||
|
||||
AddressSet entries = new AddressSet();
|
||||
funcRows.stream()
|
||||
.filter(x -> x.getCurrentInterpretation().equals(Interpretation.BLOCK_START))
|
||||
.filter(x -> x.getNumConditionalFlowRefs() == 0)
|
||||
.forEach(x -> entries.add(x.getAddress()));
|
||||
|
||||
printf("Found %d addresses to create functions\n", entries.getNumAddresses());
|
||||
|
||||
CreateFunctionCmd createCmd = new CreateFunctionCmd(entries);
|
||||
createCmd.applyTo(currentProgram);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
// Turns off function start searching (intended for use with the
|
||||
// headless analyzer as a prescript)
|
||||
//@category machineLearning
|
||||
|
||||
import ghidra.app.script.GhidraScript;
|
||||
|
||||
public class TurnOffFuncStartSearch extends GhidraScript {
|
||||
|
||||
@Override
|
||||
protected void run() throws Exception {
|
||||
setAnalysisOption(currentProgram, "Function Start Search", "false");
|
||||
setAnalysisOption(currentProgram, "Function Start Search After Code", "false");
|
||||
setAnalysisOption(currentProgram, "Function Start Search After Data", "false");
|
||||
setAnalysisOption(currentProgram, "Function ID", "false");
|
||||
|
||||
}
|
||||
|
||||
}
|
3
Ghidra/Extensions/MachineLearning/lib/README.txt
Normal file
3
Ghidra/Extensions/MachineLearning/lib/README.txt
Normal file
|
@ -0,0 +1,3 @@
|
|||
The "lib" directory is intended to hold Jar files which this contrib
|
||||
is dependent upon. This directory may be eliminated from a specific
|
||||
contrib if no other Jar files are needed.
|
|
@ -0,0 +1,57 @@
|
|||
<?xml version="1.0" encoding="ISO-8859-1"?>
|
||||
<!--
|
||||
|
||||
This is an XML file intended to be parsed by the Ghidra help system. It is loosely based
|
||||
upon the JavaHelp table of contents document format. The Ghidra help system uses a
|
||||
TOC_Source.xml file to allow a module with help to define how its contents appear in the
|
||||
Ghidra help viewer's table of contents. The main document (in the Base module)
|
||||
defines a basic structure for the
|
||||
Ghidra table of contents system. Other TOC_Source.xml files may use this structure to insert
|
||||
their files directly into this structure (and optionally define a substructure).
|
||||
|
||||
|
||||
In this document, a tag can be either a <tocdef> or a <tocref>. The former is a definition
|
||||
of an XML item that may have a link and may contain other <tocdef> and <tocref> children.
|
||||
<tocdef> items may be referred to in other documents by using a <tocref> tag with the
|
||||
appropriate id attribute value. Using these two tags allows any module to define a place
|
||||
in the table of contents system (<tocdef>), which also provides a place for
|
||||
other TOC_Source.xml files to insert content (<tocref>).
|
||||
|
||||
During the help build time, all TOC_Source.xml files will be parsed and validated to ensure
|
||||
that all <tocref> tags point to valid <tocdef> tags. From these files will be generated
|
||||
<module name>_TOC.xml files, which are table of contents files written in the format
|
||||
desired by the JavaHelp system. Additionally, the genated files will be merged together
|
||||
as they are loaded by the JavaHelp system. In the end, when displaying help in the Ghidra
|
||||
help GUI, there will be on table of contents that has been created from the definitions in
|
||||
all of the modules' TOC_Source.xml files.
|
||||
|
||||
|
||||
Tags and Attributes
|
||||
|
||||
<tocdef>
|
||||
-id - the name of the definition (this must be unique across all TOC_Source.xml files)
|
||||
-text - the display text of the node, as seen in the help GUI
|
||||
-target** - the file to display when the node is clicked in the GUI
|
||||
-sortgroup - this is a string that defines where a given node should appear under a given
|
||||
parent. The string values will be sorted by the JavaHelp system using
|
||||
a javax.text.RulesBasedCollator. If this attribute is not specified, then
|
||||
the text of attribute will be used.
|
||||
|
||||
<tocref>
|
||||
-id - The id of the <tocdef> that this reference points to
|
||||
|
||||
**The URL for the target is relative and should start with 'help/topics'. This text is
|
||||
used by the Ghidra help system to provide a universal starting point for all links so that
|
||||
they can be resolved at runtime, across modules.
|
||||
|
||||
|
||||
-->
|
||||
<tocroot>
|
||||
<tocref id="Program Search">
|
||||
<tocdef id="Random Forest Function Finder"
|
||||
text="Search for Code and Functions"
|
||||
target="help/topics/RandomForestFunctionFinderPlugin/RandomForestFunctionFinderPlugin.htm"
|
||||
sortgroup="zzz">
|
||||
</tocdef>
|
||||
</tocref>
|
||||
</tocroot>
|
|
@ -0,0 +1,64 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
/*
|
||||
WARNING!
|
||||
This file is copied to all help directories. If you change this file, you must copy it
|
||||
to each src/main/help/help/shared directory.
|
||||
|
||||
|
||||
Java Help Note: JavaHelp does not accept sizes (like in 'margin-top') in anything but
|
||||
px (pixel) or with no type marking.
|
||||
|
||||
*/
|
||||
|
||||
body { margin-bottom: 50px; margin-left: 10px; margin-right: 10px; margin-top: 10px; } /* some padding to improve readability */
|
||||
li { font-family:times new roman; font-size:14pt; }
|
||||
h1 { color:#000080; font-family:times new roman; font-size:36pt; font-style:italic; font-weight:bold; text-align:center; }
|
||||
h2 { margin: 10px; margin-top: 20px; color:#984c4c; font-family:times new roman; font-size:18pt; font-weight:bold; }
|
||||
h3 { margin-left: 10px; margin-top: 20px; color:#0000ff; font-family:times new roman; `font-size:14pt; font-weight:bold; }
|
||||
h4 { margin-left: 10px; margin-top: 20px; font-family:times new roman; font-size:14pt; font-style:italic; }
|
||||
|
||||
/*
|
||||
P tag code. Most of the help files nest P tags inside of blockquote tags (the was the
|
||||
way it had been done in the beginning). The net effect is that the text is indented. In
|
||||
modern HTML we would use CSS to do this. We need to support the Ghidra P tags, nested in
|
||||
blockquote tags, as well as naked P tags. The following two lines accomplish this. Note
|
||||
that the 'blockquote p' definition will inherit from the first 'p' definition.
|
||||
*/
|
||||
p { margin-left: 40px; font-family:times new roman; font-size:14pt; }
|
||||
blockquote p { margin-left: 10px; }
|
||||
|
||||
p.providedbyplugin { color:#7f7f7f; margin-left: 10px; font-size:14pt; margin-top:100px }
|
||||
p.ProvidedByPlugin { color:#7f7f7f; margin-left: 10px; font-size:14pt; margin-top:100px }
|
||||
p.relatedtopic { color:#800080; margin-left: 10px; font-size:14pt; }
|
||||
p.RelatedTopic { color:#800080; margin-left: 10px; font-size:14pt; }
|
||||
|
||||
/*
|
||||
We wish for a tables to have space between it and the preceding element, so that text
|
||||
is not too close to the top of the table. Also, nest the table a bit so that it is clear
|
||||
the table relates to the preceding text.
|
||||
*/
|
||||
table { margin-left: 20px; margin-top: 10px; width: 80%;}
|
||||
td { font-family:times new roman; font-size:14pt; vertical-align: top; }
|
||||
th { font-family:times new roman; font-size:14pt; font-weight:bold; background-color: #EDF3FE; }
|
||||
|
||||
/*
|
||||
Code-like formatting for things such as file system paths and proper names of classes,
|
||||
methods, etc. To apply this to a file path, use this syntax:
|
||||
<CODE CLASS="path">...</CODE>
|
||||
*/
|
||||
code { color: black; font-weight: bold; font-family: courier new, monospace; font-size: 14pt; white-space: nowrap; }
|
||||
code.path { color: #4682B4; font-weight: bold; font-family: courier new, monospace; font-size: 14pt; white-space: nowrap; }
|
|
@ -0,0 +1,216 @@
|
|||
<!DOCTYPE html PUBLIC "-//W3C//DTD HTML 4.01 Transitional//EN">
|
||||
|
||||
<HTML>
|
||||
<HEAD>
|
||||
<META http-equiv="Content-Language" content="en-us">
|
||||
<META http-equiv="Content-Type" content="text/html; charset=windows-1252">
|
||||
|
||||
<TITLE>Random Forest Function Finder Plugin</TITLE>
|
||||
<LINK rel="stylesheet" type="text/css" href="../../shared/Frontpage.css">
|
||||
</HEAD>
|
||||
|
||||
<BODY>
|
||||
<H1><A name="RandomForestFunctionFinderPlugin"></A> Random Forest Function Finder Plugin</H1>
|
||||
|
||||
<P> This plugin trains models used to find function starts within a program. Essentially,
|
||||
the training set consists of addresses in a program where Ghidra's analysis was able to
|
||||
find functions. The models are then applied to the rest of the program.
|
||||
Models can also be applied to other programs. </P>
|
||||
|
||||
<P> In the motivating use case, you either don't know the toolchain which produced a program
|
||||
or do not have a large number of sample programs to train other types of models. </P>
|
||||
|
||||
<P> Note: in general, this plugin ensures that addresses used for training, testing, or
|
||||
searching for function starts are aligned relative to the processor's instruction alignment.
|
||||
Defined data within an executable block is an exception - all such bytes are added to
|
||||
the test set as examples of non-starts.</P>
|
||||
|
||||
<H2><A name="SuggestedWorkflow"></A> Basic Suggested Workflow</H2>
|
||||
<ol>
|
||||
<li> To begin, select <em>Search->For Code And Functions...</em> from the Code Browser.</li>
|
||||
<li> Click the <em>Train</em> button to train models using the default parameters.</li>
|
||||
<li> Choose the model with the fewest false positives (which will be apparent from
|
||||
the <em>Model Statistics</em> table).</li>
|
||||
<li> Right-click on that model's row and select <em>DEBUG - Show test set errors</em>.</li>
|
||||
<li> Examine the resulting table to determine if there is a good cutoff for
|
||||
the probabilities. Note that some of the "errors" might not actually be
|
||||
errors of the model: see the discussion in
|
||||
<A href="#DebugModelTable">Debug Model Table</A>.</li>
|
||||
<li> If you're satisified with the performance of the model, right-click on the
|
||||
row and select <em>Apply Model</em>. If you aren't, you can try changing the parameters
|
||||
and training again. You can also try the <em>Include Bit Features</em> training option.</li>
|
||||
<li> In the resulting table, select all addresses with an <em>Undefined</em> interpretation whose
|
||||
probability is above your threshold, right-click, and select <em>Disassemble</em>. This will
|
||||
start disassembly (and follow-on analysis) at each selected address.</li>
|
||||
<li> Now, select all addresses whose interpretation is <em>Block Start</em> and whose probability
|
||||
of being a function is above your threshold, right-click, and select <em>Create Function(s)</em>.
|
||||
It's also probably worth filtering out any addresses which are the targets of
|
||||
conditional references (which can be seen in the <em>Conditional Flow Refs</em> column). </li>
|
||||
</ol>
|
||||
<P> The script <em>FindFunctionsRFExampleScript.java</em> shows how to access the functionality of
|
||||
this plugin programmatically. </P>
|
||||
|
||||
<H2><A name="ModelTrainingTable"></A> Model Training Table</H2>
|
||||
<P> This table is the main interface for training and applying models. </P>
|
||||
|
||||
<H3><A name="DataGatheringParameters"></A> Data Gathering Parameters</H3>
|
||||
<P> The values in this panel control the number of models trained and the data used to train them.
|
||||
The first three fields: <A href="#NumberOfPreBytes">Number of Pre-Bytes (CSV)</A>,
|
||||
<A href="#NumberOfInitialBytes">Number of Initial Bytes (CSV)</A>, and
|
||||
<A href="#StartToNonStartFactors"> Start to Non-start Sampling Factors (CSV)</A> accept
|
||||
CSVs of positive integers as input (a single integer with no comma is allowed). Models
|
||||
corresponding to all possible choices of the three values will be trained and evaluated.
|
||||
That is, if you enter two values for the <em>Pre-bytes</em> field, three values for the
|
||||
<em>Initial Bytes</em> field, and four values for the <em>Sampling Factors</em> field, a total
|
||||
of 2*3*4 = 24 models will be trained and evaluated. </P>
|
||||
|
||||
<H4><A name="NumberOfPreBytes"></A> Number of Pre-bytes (CSV) </H4>
|
||||
<P> Values in this list control how many bytes before an address are used to construct its
|
||||
feature vector. </P>
|
||||
|
||||
<H4><A name="NumberOfInitialBytes"></A> Number of Initial Bytes (CSV) </H4>
|
||||
<P> Values in this list control how many bytes are used to construct the feature vector
|
||||
of an address, starting at the address. </P>
|
||||
|
||||
<H4><A name="StartToNonStartFactors"></A> Start to Non-start Sampling Factors (CSV) </H4>
|
||||
<P> Values in this list control how many non-starts (i.e., addresses in the interiors
|
||||
of functions) are added to the training set for each function start in the training set.</P>
|
||||
|
||||
<H4><A name="MaximumNumberOfStarts"></A> Maximum Number Of Starts </H4>
|
||||
<P> This field controls the maximum number of function starts that are added to the training
|
||||
set. </P>
|
||||
|
||||
<H4><A name="ContextRegsAndValues"></A> Context Registers and Values (CSV) </H4>
|
||||
<P> This field allows you to specify values of context registers. Addresses will only
|
||||
be added to the training/test sets if they agree with these values, and the disassembly
|
||||
action on the <a href="#FunctionStartTable"> Potential Functions Table</a> will apply the
|
||||
context register values first. This field accepts CSVs of the form "creg1=x,creg2=y,...".
|
||||
For example, to restrict Thumb mode in an ARM program, you would enter "TMode=1" in this field.
|
||||
</P>
|
||||
|
||||
<H4><A name="IncludePrecedingAndFollowing"></A> Include Preceding and Following </H4>
|
||||
<P> If this is selected, for every function entry in the training set, the code units immediately
|
||||
before it and after it are added to the training set as negative examples (and similarly for the
|
||||
test set). </P>
|
||||
|
||||
<H4><A name="IncludeBitFeatures"></A> Include Bit Features </H4>
|
||||
<P> If this is selected, a binary feature is added to the feature vector for each bit in the
|
||||
recorded bytes. </P>
|
||||
|
||||
<H4><A name="MinimumFunctionSize"></A> Minimum Function Size </H4>
|
||||
<P> This value is the minimum size a function must be for its entry and interior to be included
|
||||
in the training and test sets. </P>
|
||||
|
||||
<H3><A name="FunctionInformation"></A> Function Information</H3>
|
||||
<P> This panel displays information about the functions in the program. </P>
|
||||
|
||||
<H4><A name="FunctionsMeetingSizeBound"></A> Functions Meeting Size Bound</H4>
|
||||
<P> This field displays the number of functions meeting the size bound in the
|
||||
<A href="#MinimumFunctionSize"> Minimum Function Size</A> field. You can use
|
||||
this to ensure that the value in <A href="#MaximumNumberOfStarts">Maximum Number
|
||||
of Starts</A> field doesn't cause all starts to be used for training (leaving
|
||||
none for testing).</P>
|
||||
|
||||
<H4><A name="RestrictSearchToAlignedAddresses"></A> Restrict Search to Aligned Addresses </H4>
|
||||
<P> If this is checked, only addresses which are zero modulo the value in the
|
||||
<A href="#AlignmentModulus">Alignment Modulus</A> combo box are searched for function starts.
|
||||
This does not affect training or testing, but can be a useful optimization when applying
|
||||
models, for instance when the <A href="#FunctionAlignmentTable">Function Alignment Table</A>
|
||||
shows that all (known) functions in the program are aligned on 16-byte boundaries. </P>
|
||||
|
||||
<H4><A name="AlignmentModulus"></A> Alignment Modulus </H4>
|
||||
<P> The value in this combo box determines the modulus used when computing the values in
|
||||
the <A href="#FunctionAlignmentTable">Function Alignment Table</A>. </P>
|
||||
|
||||
<H4><A name="FunctionAlignmentTable"></A> Function Alignment Table </H4>
|
||||
<P> The rows in this table display the number of (known) functions in the program
|
||||
whose address has the given remainder modulo the alignment modulus.</P>
|
||||
|
||||
<H3><A name="ModelStatistics"></A> Model Statistics</H3>
|
||||
<P> This panel displays the statistics about the trained models as rows in a table.
|
||||
Actions on these rows allow you to apply the models or see the test set failures.</P>
|
||||
|
||||
<H4><A name="ApplyModel"></A> Apply Model Action </H4>
|
||||
<P> This action will apply the model to the program used to train it. The addresses
|
||||
searched consist of all addresses which are loaded, initialized, marked as executable,
|
||||
and not already in a function body (this set can be modified by the user via the
|
||||
<A href="#RestrictSearchToAlignedAddresses"> Restrict Search to Aligned Addresses</A>
|
||||
and <A href="#MinLengthUndefinedRange"> Minimum Length of Undefined Ranges to Search</A>
|
||||
options). The results are displayed in a
|
||||
<A href="#FunctionStartTable"> Function Start Table</A>. </P>
|
||||
|
||||
<H4><A name="ApplyModelTo"></A> Apply Model To... Action </H4>
|
||||
<P> This action will open a dialog to select another program in the current project and
|
||||
then apply the model to it. Note that the only check that the model is compatible with
|
||||
the selected program is that any context registers specified when training must be
|
||||
present in the selected program. </P>
|
||||
|
||||
<H4><A name="DebugModel"></A> Debug Model Action </H4>
|
||||
<P> This action will display a <A href="#DebugModelTable"> Debug Model Table</A>, which shows
|
||||
all of the errors encountered when applying the model to its test set. </P>
|
||||
|
||||
<H2><A name="FunctionStartTable"></A> Potential Functions Table</H2>
|
||||
<P> This table displays all addresses in the search set which the model thinks are function starts
|
||||
with probability at least .5. The table also shows the current "Interpretation" (e.g., undefined,
|
||||
instruction at start of basic block, etc) of the address along with the numbers of certain types
|
||||
of references to the address. </P>
|
||||
<P> The following actions are defined on this table:</P>
|
||||
|
||||
<H3><A name="DisassembleAction"></A> Disassemble Action </H3>
|
||||
<P> This action is enabled when at least one of the selected rows corresponds to an address
|
||||
with an interpretation of "Undefined". It begins disassembly at each "Undefined" address
|
||||
corresponding to a row in the selection. </P>
|
||||
|
||||
<H3><A name="DisassembleAndApplyContextAction"></A> Disassemble and Apply Context Action </H3>
|
||||
<P> This action is similar to the <A href="#DisassembleAction">Disassemble Action</A>, except
|
||||
it sets the context register values specified before training the model at the addresses
|
||||
and then disassembles. </P>
|
||||
|
||||
<H3><A name="CreateFunctionsAction"></A> Create Functions Action </H3>
|
||||
<P> This action is enabled whenever the selection contains at least one row whose corresponding
|
||||
address is the start of a basic block. This action creates functions at all such addresses.</P>
|
||||
|
||||
<H3><A name="ShowSimilarStartsAction"></A> Show Similar Function Starts Action </H3>
|
||||
<P> This action is enabled when the selection contains exactly one row. It displays
|
||||
a <A href="#SimilarStartsTable"> table</A> of the function starts in the training set
|
||||
which are most similar to the bytes at the address of the row. </P>
|
||||
|
||||
<H2><A name="SimilarStartsTable"></A> Similar Function Starts Table </H2>
|
||||
<P> This table displays the function starts in the training set which are most similar
|
||||
to a potential function start "from the model's point of view". Formally, similarity
|
||||
is measured using <b>random forest proximity</b>. Given a potential start <i>p</i> and
|
||||
a known start <i>s</i>, the similarity of <i>p</i> and <i>s</i> is the proportion of trees
|
||||
which end up in the same leaf node when processing <i>p</i> and <i>s</i>. </P>
|
||||
<P> For convenience, the potential start is also displayed as a row in the table. In
|
||||
the Address column, its address is surrounded by asterisks.</P>
|
||||
|
||||
<H2><A name="DebugModelTable"></A> Debug Model Table </H2>
|
||||
<P> This table has the same format as the <A href="#FunctionStartTable">Potential Functions Table</A>
|
||||
but does not have the disassembly or function-creating actions (it does have the action to
|
||||
display similar function starts). It displays all addresses in the test set where the classifier
|
||||
made an error. Note that some in some cases, it might be the classifier which is correct and the
|
||||
original analysis which was wrong. A common example is a tail call which
|
||||
was optimized to a jump during compilation. If there is only one jump to this address, then analysis
|
||||
may (reasonably) think that the function is just part of the function containing the jump even though
|
||||
the classifier thinks the jump target is a function start.</P>
|
||||
|
||||
<H2><A name="Options"></A> Options </H2>
|
||||
<P> This plugin has the following options. They can be set in the Tool Options menu. </P>
|
||||
|
||||
<H3><A name="MaxTestSetSize"></A> Maximum Test Set Size </H3>
|
||||
<P> This option controls the maximum size of the test sets (the test set of function
|
||||
starts and the test set of known non-starts which together form the model's "test set").
|
||||
Each set that is larger than the maximum will be replaced with a random subset of the maximum size. </P>
|
||||
|
||||
<H3><A name="MinLengthUndefinedRange"></A> Minimum Length of Undefined Ranges to Search </H3>
|
||||
<P> This option controls the minimum length a run of undefined bytes must be in order to
|
||||
be searched for function starts. This is an optimization which allows you to skip the
|
||||
(often quite numerous) small runs of undefined bytes between adjacent functions. Note
|
||||
that this option has no effect on model training or evaluation. </P>
|
||||
|
||||
|
||||
|
||||
|
||||
<P class="providedbyplugin">Provided By: <I>RandomForestFunctionFinderPlugin</I></P>
|
||||
</BODY>
|
||||
</HTML>
|
|
@ -0,0 +1,98 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import docking.ActionContext;
|
||||
import docking.action.DockingAction;
|
||||
import docking.action.MenuData;
|
||||
import ghidra.app.cmd.function.CreateFunctionCmd;
|
||||
import ghidra.framework.plugintool.Plugin;
|
||||
import ghidra.program.model.address.AddressSet;
|
||||
import ghidra.program.model.listing.Program;
|
||||
import ghidra.util.HelpLocation;
|
||||
import ghidra.util.table.GhidraTable;
|
||||
|
||||
/**
|
||||
* A {@link DockingAction} for creating functions from rows in a {@link FunctionStartTableModel}.
|
||||
* When performed on a selection, functions are created at all rows in the selection whose
|
||||
* {@link Interpretation} is {@link Interpretation#BLOCK_START}.
|
||||
*/
|
||||
public class CreateFunctionsAction extends DockingAction {
|
||||
private static final String MENU_TEXT = "Create Function(s)";
|
||||
private static final String ACTION_NAME = "CreateFunctionsAction";
|
||||
private Program program;
|
||||
private FunctionStartTableModel model;
|
||||
private GhidraTable table;
|
||||
private Plugin plugin;
|
||||
|
||||
/**
|
||||
* Constructs an action for creating functions.
|
||||
* @param plugin plugin
|
||||
* @param program source program
|
||||
* @param table table
|
||||
* @param model table model
|
||||
*/
|
||||
public CreateFunctionsAction(Plugin plugin, Program program, GhidraTable table,
|
||||
FunctionStartTableModel model) {
|
||||
super(ACTION_NAME, plugin.getName());
|
||||
this.program = program;
|
||||
this.model = model;
|
||||
this.plugin = plugin;
|
||||
this.table = table;
|
||||
init();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isAddToPopup(ActionContext context) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isEnabledForContext(ActionContext context) {
|
||||
for (FunctionStartRowObject row : model.getRowObjects(table.getSelectedRows())) {
|
||||
switch (row.getCurrentInterpretation()) {
|
||||
case BLOCK_START:
|
||||
return true;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void actionPerformed(ActionContext context) {
|
||||
AddressSet entries = new AddressSet();
|
||||
for (FunctionStartRowObject row : model.getRowObjects(table.getSelectedRows())) {
|
||||
switch (row.getCurrentInterpretation()) {
|
||||
case BLOCK_START:
|
||||
entries.add(row.getAddress());
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
CreateFunctionCmd cmd = new CreateFunctionCmd(entries);
|
||||
plugin.getTool().executeBackgroundCommand(cmd, program);
|
||||
}
|
||||
|
||||
private void init() {
|
||||
setPopupMenuData(new MenuData(new String[] { MENU_TEXT }));
|
||||
setDescription(
|
||||
String.format("Creates functions at all %s rows", Interpretation.BLOCK_START.name()));
|
||||
setHelpLocation(new HelpLocation(plugin.getName(), ACTION_NAME));
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,100 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import java.math.BigInteger;
|
||||
import java.util.List;
|
||||
|
||||
import docking.ActionContext;
|
||||
import docking.action.DockingAction;
|
||||
import docking.action.MenuData;
|
||||
import ghidra.app.cmd.disassemble.DisassembleCommand;
|
||||
import ghidra.framework.plugintool.Plugin;
|
||||
import ghidra.program.model.address.AddressSet;
|
||||
import ghidra.program.model.lang.Register;
|
||||
import ghidra.program.model.lang.RegisterValue;
|
||||
import ghidra.program.model.listing.Program;
|
||||
import ghidra.program.model.listing.ProgramContext;
|
||||
import ghidra.util.HelpLocation;
|
||||
import ghidra.util.table.GhidraTable;
|
||||
|
||||
/**
|
||||
* A {@link DockingAction} for disassembling at addresses corresponding to rows in a
|
||||
* {@link FunctionStartTableModel}. Context register values specified before training
|
||||
* will be applied before disassembly.
|
||||
*/
|
||||
public class DisassembleAndApplyContextAction extends DisassembleFunctionStartsAction {
|
||||
private static final String ACTION_NAME = "DisassembleAndApplyContextAction";
|
||||
private static final String MENU_TEXT = "Disassemble and Apply Context";
|
||||
private Program program;
|
||||
private FunctionStartTableModel model;
|
||||
private Plugin plugin;
|
||||
private GhidraTable table;
|
||||
|
||||
/**
|
||||
* Creates an action for disassembling at rows in a {@link FunctionStartTableModel} if the
|
||||
* {@link Interpretation} of the row is {@link Interpretation#UNDEFINED}. Specified
|
||||
* context register values are set before disassembly.
|
||||
* @param plugin owning plugin
|
||||
* @param program source program
|
||||
* @param table table
|
||||
* @param model table model
|
||||
*/
|
||||
public DisassembleAndApplyContextAction(Plugin plugin, Program program, GhidraTable table,
|
||||
FunctionStartTableModel model) {
|
||||
super(plugin, program, table, model);
|
||||
this.program = program;
|
||||
this.model = model;
|
||||
this.plugin = plugin;
|
||||
this.table = table;
|
||||
init();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void actionPerformed(ActionContext context) {
|
||||
AddressSet entries = new AddressSet();
|
||||
for (FunctionStartRowObject row : model.getRowObjects(table.getSelectedRows())) {
|
||||
switch (row.getCurrentInterpretation()) {
|
||||
case UNDEFINED:
|
||||
entries.add(row.getAddress());
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
DisassembleCommand cmd = new DisassembleCommand(entries, null, true);
|
||||
|
||||
RandomForestRowObject row = model.getRandomForestRowObject();
|
||||
if (row.isContextRestricted()) {
|
||||
ProgramContext programContext = program.getProgramContext();
|
||||
List<String> regNames = row.getContextRegisterList();
|
||||
List<BigInteger> regValues = row.getContextRegisterValues();
|
||||
RegisterValue newValue = new RegisterValue(programContext.getBaseContextRegister());
|
||||
for (int i = 0; i < regNames.size(); ++i) {
|
||||
Register reg = program.getRegister(regNames.get(i));
|
||||
newValue = newValue.combineValues(new RegisterValue(reg, regValues.get(i)));
|
||||
}
|
||||
cmd.setInitialContext(newValue);
|
||||
}
|
||||
plugin.getTool().executeBackgroundCommand(cmd, program);
|
||||
}
|
||||
|
||||
private void init() {
|
||||
setPopupMenuData(new MenuData(new String[] { MENU_TEXT }));
|
||||
setDescription("Apply context and disassemble at the selected rows");
|
||||
setHelpLocation(new HelpLocation(plugin.getName(), ACTION_NAME));
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,96 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import docking.ActionContext;
|
||||
import docking.action.DockingAction;
|
||||
import docking.action.MenuData;
|
||||
import ghidra.app.cmd.disassemble.DisassembleCommand;
|
||||
import ghidra.framework.plugintool.Plugin;
|
||||
import ghidra.program.model.address.AddressSet;
|
||||
import ghidra.program.model.listing.Program;
|
||||
import ghidra.util.HelpLocation;
|
||||
import ghidra.util.table.GhidraTable;
|
||||
|
||||
/**
|
||||
* A {@link DockingAction} for disassembling at addresses corresponding to rows in a
|
||||
* {@link FunctionStartTableModel}.
|
||||
*/
|
||||
public class DisassembleFunctionStartsAction extends DockingAction {
|
||||
private static final String ACTION_NAME = "DisassembleAction";
|
||||
private static final String MENU_TEXT = "Disassemble";
|
||||
private Program program;
|
||||
private FunctionStartTableModel model;
|
||||
private GhidraTable table;
|
||||
private Plugin plugin;
|
||||
|
||||
/**
|
||||
* Creates and action for disassembling at rows in a {@link FunctionStartTableModel} if the
|
||||
* {@link Interpretation} of the row is {@link Interpretation#UNDEFINED}.
|
||||
* @param plugin owning plugin
|
||||
* @param program source program
|
||||
* @param table table
|
||||
* @param model table model
|
||||
*/
|
||||
public DisassembleFunctionStartsAction(Plugin plugin, Program program, GhidraTable table,
|
||||
FunctionStartTableModel model) {
|
||||
super(ACTION_NAME, plugin.getName());
|
||||
this.program = program;
|
||||
this.model = model;
|
||||
this.plugin = plugin;
|
||||
this.table = table;
|
||||
init();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isAddToPopup(ActionContext context) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isEnabledForContext(ActionContext context) {
|
||||
for (FunctionStartRowObject row : model.getRowObjects(table.getSelectedRows())) {
|
||||
switch (row.getCurrentInterpretation()) {
|
||||
case UNDEFINED:
|
||||
return true;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void actionPerformed(ActionContext context) {
|
||||
AddressSet entries = new AddressSet();
|
||||
for (FunctionStartRowObject row : model.getRowObjects(table.getSelectedRows())) {
|
||||
switch (row.getCurrentInterpretation()) {
|
||||
case UNDEFINED:
|
||||
entries.add(row.getAddress());
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
DisassembleCommand cmd = new DisassembleCommand(entries, null, true);
|
||||
plugin.getTool().executeBackgroundCommand(cmd, program);
|
||||
}
|
||||
|
||||
private void init() {
|
||||
setPopupMenuData(new MenuData(new String[] { MENU_TEXT }));
|
||||
setDescription("Disassemble at the selected rows");
|
||||
setHelpLocation(new HelpLocation(plugin.getName(), ACTION_NAME));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,97 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import org.tribuo.*;
|
||||
import org.tribuo.classification.Label;
|
||||
import org.tribuo.ensemble.EnsembleModel;
|
||||
import org.tribuo.impl.ArrayExample;
|
||||
|
||||
import generic.concurrent.QCallback;
|
||||
import ghidra.program.model.address.Address;
|
||||
import ghidra.program.model.listing.Program;
|
||||
import ghidra.util.task.TaskMonitor;
|
||||
|
||||
/**
|
||||
* This class is used as a callback for parallelized ensemble evaluation. Rather than
|
||||
* computing the precise probability the ensemble would assign to a given address and label,
|
||||
* it only computes whether the probability is >= .5. The computation is short-circuited:
|
||||
* as soon as enough members of the ensemble have been checked to determine the return value
|
||||
* the computation stops.
|
||||
*/
|
||||
public class EnsembleEvaluatorCallback implements QCallback<Address, Boolean> {
|
||||
|
||||
private EnsembleModel<Label> ensemble;
|
||||
private int numModels;
|
||||
private int numPreBytes;
|
||||
private int numInitialBytes;
|
||||
private Label label;
|
||||
private Program program;
|
||||
private boolean includeBitFeatures;
|
||||
|
||||
/**
|
||||
* Create a new evaluator
|
||||
* @param ensemble ensemble to evaluate
|
||||
* @param p program containing addresses to test
|
||||
* @param numPreBytes number of bytes before address
|
||||
* @param numInitialBytes number of bytes after and including address
|
||||
* @param includeBitFeatures whether to include bit-level features
|
||||
* @param label target label
|
||||
*/
|
||||
public EnsembleEvaluatorCallback(EnsembleModel<Label> ensemble, Program p, int numPreBytes,
|
||||
int numInitialBytes, boolean includeBitFeatures, Label label) {
|
||||
this.ensemble = ensemble;
|
||||
numModels = ensemble.getNumModels();
|
||||
this.numPreBytes = numPreBytes;
|
||||
this.numInitialBytes = numInitialBytes;
|
||||
this.label = label;
|
||||
program = p;
|
||||
this.includeBitFeatures = includeBitFeatures;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Boolean process(Address item, TaskMonitor monitor) throws Exception {
|
||||
List<Feature> trainingVector = ModelTrainingUtils.getFeatureVector(program, item,
|
||||
numPreBytes, numInitialBytes, includeBitFeatures);
|
||||
if (trainingVector.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
ArrayExample<Label> vec = new ArrayExample<>(label, trainingVector);
|
||||
|
||||
int numAgree = 0;
|
||||
int numDisagree = 0;
|
||||
for (int i = 0; i < numModels; i++) {
|
||||
Model<Label> model = ensemble.getModels().get(i);
|
||||
Prediction<Label> pred = model.predict(vec);
|
||||
if (pred.getOutput().equals(label)) {
|
||||
numAgree += 1;
|
||||
}
|
||||
else {
|
||||
numDisagree += 1;
|
||||
}
|
||||
if (numAgree == (numModels + 1) / 2) {
|
||||
return true;
|
||||
}
|
||||
if (numDisagree == (numModels / 2) + 1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
throw new AssertionError("did not return value");
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
/**
|
||||
* A member of this class is a row in a {@link FunctionStartAlignmentTableModel}
|
||||
*/
|
||||
public class FunctionStartAlignmentRowObject {
|
||||
|
||||
private long remainder;
|
||||
private long numFuncs;
|
||||
|
||||
/**
|
||||
* Creates a row showing how many functions whose entry points have
|
||||
* a remainder of {@code remainder} when divided by the alignment modulus.
|
||||
* @param remainder remainder after division by alignment modulus
|
||||
* @param numFuncs number of functions
|
||||
*/
|
||||
public FunctionStartAlignmentRowObject(long remainder, long numFuncs) {
|
||||
this.remainder = remainder;
|
||||
this.numFuncs = numFuncs;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the remainder
|
||||
* @return remainder
|
||||
*/
|
||||
public long getRemainder() {
|
||||
return remainder;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the number of functions
|
||||
* @return num funcs
|
||||
*/
|
||||
public long getNumFuncs() {
|
||||
return numFuncs;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,82 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import docking.widgets.table.AbstractSortedTableModel;
|
||||
|
||||
/**
|
||||
* A table used to show how many functions have addresses aligned (or not aligned) relative to a
|
||||
* given modulus.
|
||||
*/
|
||||
public class FunctionStartAlignmentTableModel
|
||||
extends AbstractSortedTableModel<FunctionStartAlignmentRowObject> {
|
||||
|
||||
private List<FunctionStartAlignmentRowObject> rows;
|
||||
|
||||
/**
|
||||
* Creates a table with the supplies rows
|
||||
* @param rows rows
|
||||
*/
|
||||
public FunctionStartAlignmentTableModel(List<FunctionStartAlignmentRowObject> rows) {
|
||||
this.rows = rows;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isSortable(int columnIndex) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getColumnCount() {
|
||||
return 2;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getColumnName(int columnIndex) {
|
||||
switch (columnIndex) {
|
||||
case 0:
|
||||
return "Remainder";
|
||||
case 1:
|
||||
return "Number of Functions";
|
||||
default:
|
||||
throw new IllegalArgumentException("Invalid column index");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return "Alignment";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<FunctionStartAlignmentRowObject> getModelData() {
|
||||
return rows;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getColumnValueForRow(FunctionStartAlignmentRowObject t, int columnIndex) {
|
||||
switch (columnIndex) {
|
||||
case 0:
|
||||
return t.getRemainder();
|
||||
case 1:
|
||||
return t.getNumFuncs();
|
||||
default:
|
||||
throw new IllegalArgumentException("Invalid column index");
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,84 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import org.tribuo.Feature;
|
||||
import org.tribuo.Prediction;
|
||||
import org.tribuo.classification.Label;
|
||||
import org.tribuo.classification.LabelFactory;
|
||||
import org.tribuo.ensemble.EnsembleModel;
|
||||
import org.tribuo.impl.ArrayExample;
|
||||
|
||||
import generic.concurrent.QCallback;
|
||||
import ghidra.program.model.address.Address;
|
||||
import ghidra.program.model.listing.Program;
|
||||
import ghidra.util.task.TaskMonitor;
|
||||
|
||||
/**
|
||||
* A {@link QCallback} which is used to apply a model at a given address to determine the
|
||||
* probability that the address represents a function start.
|
||||
*/
|
||||
class FunctionStartCallback implements QCallback<Address, Double> {
|
||||
|
||||
private EnsembleModel<Label> randomForest;
|
||||
private int numPreBytes;
|
||||
private int numInitialBytes;
|
||||
private Program program;
|
||||
private int alignment;
|
||||
private Label target;
|
||||
private boolean includeBitLevelFeatures;
|
||||
|
||||
/**
|
||||
* Creates a callback which applies {@code model} to addresses in {@code program} using
|
||||
* the data-gathering parameters {@code numPreBytes} and {@code numInitialBytes} to
|
||||
* determine the probability the address has label {@label}.
|
||||
* @param model model to apply
|
||||
* @param numPreBytes bytes before address to gather
|
||||
* @param numInitialBytes bytes after address to gather
|
||||
* @param includeBitLevelFeatures whether to include bit-level features
|
||||
* @param program source program
|
||||
* @param target target label
|
||||
*/
|
||||
public FunctionStartCallback(EnsembleModel<Label> model, int numPreBytes, int numInitialBytes,
|
||||
boolean includeBitLevelFeatures, Program program, Label target) {
|
||||
this.randomForest = model;
|
||||
this.numPreBytes = numPreBytes;
|
||||
this.numInitialBytes = numInitialBytes;
|
||||
this.program = program;
|
||||
alignment = program.getLanguage().getInstructionAlignment();
|
||||
this.target = target;
|
||||
this.includeBitLevelFeatures = includeBitLevelFeatures;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Double process(Address item, TaskMonitor monitor) throws Exception {
|
||||
if (Long.remainderUnsigned(item.getOffset(), alignment) != 0) {
|
||||
return null;
|
||||
}
|
||||
List<Feature> vecToClassify = ModelTrainingUtils.getFeatureVector(program, item,
|
||||
numPreBytes, numInitialBytes, includeBitLevelFeatures);
|
||||
ArrayExample<Label> vec =
|
||||
new ArrayExample<>(LabelFactory.UNKNOWN_LABEL, vecToClassify);
|
||||
|
||||
if (vec.size() == 0) {
|
||||
return null;
|
||||
}
|
||||
Prediction<Label> pred = randomForest.predict(vec);
|
||||
return pred.getOutputScores().get(target.getLabel()).getScore();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,120 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
import org.tribuo.classification.Label;
|
||||
import org.tribuo.ensemble.EnsembleModel;
|
||||
|
||||
import generic.concurrent.*;
|
||||
import ghidra.program.model.address.Address;
|
||||
import ghidra.program.model.address.AddressSetView;
|
||||
import ghidra.program.model.listing.Program;
|
||||
import ghidra.util.Msg;
|
||||
import ghidra.util.exception.CancelledException;
|
||||
import ghidra.util.task.TaskMonitor;
|
||||
|
||||
/**
|
||||
* This class uses a {@link GThreadPool} to execute a {@link FunctionStartCallback} in
|
||||
* parallel in order to classify addresses in a program as function starts or non-starts.
|
||||
*/
|
||||
public class FunctionStartClassifier {
|
||||
|
||||
private static final double DEFAULT_PROB_THRESHOLD = 0.5d;
|
||||
|
||||
private Program program;
|
||||
private RandomForestRowObject modelRow;
|
||||
private double probabilityThreshold;
|
||||
private Label target;
|
||||
|
||||
/**
|
||||
* Creates an object used to apply the model in {@code modelRow} to addresses in {@code program}
|
||||
* to the probability of having label {@code target}.
|
||||
* @param program program to check
|
||||
* @param modelRow row containing model and data gathering parameters
|
||||
* @param target target addresses
|
||||
*/
|
||||
public FunctionStartClassifier(Program program, RandomForestRowObject modelRow, Label target) {
|
||||
this.program = program;
|
||||
this.modelRow = modelRow;
|
||||
probabilityThreshold = DEFAULT_PROB_THRESHOLD;
|
||||
this.target = target;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the probability threshold. Address where the probability of having the target
|
||||
* label is less than the threshold are not included in the map returned by
|
||||
* {@link FunctionStartClassifier#classify(AddressSetView, TaskMonitor)}
|
||||
* @param thresh new threshold
|
||||
*/
|
||||
public void setProbabilityThreshold(double thresh) {
|
||||
probabilityThreshold = thresh;
|
||||
}
|
||||
|
||||
/**
|
||||
* Classifies the addresses in {@code addresses} in parallel.
|
||||
* @param addresses addresses to classify
|
||||
* @param monitor monitor
|
||||
* @return map from addresses to probabilities
|
||||
* @throws CancelledException if monitor is canceled
|
||||
*/
|
||||
public Map<Address, Double> classify(AddressSetView addresses, TaskMonitor monitor)
|
||||
throws CancelledException {
|
||||
monitor.initialize(addresses.getNumAddresses());
|
||||
int preBytes = modelRow.getNumPreBytes();
|
||||
int initialBytes = modelRow.getNumInitialBytes();
|
||||
EnsembleModel<Label> randomForest = modelRow.getRandomForest();
|
||||
Msg.info(this, "Number of addresses to classify: " + addresses.getNumAddresses());
|
||||
GThreadPool threadPool = GThreadPool.getSharedThreadPool("FunctionStartClassifier");
|
||||
|
||||
ConcurrentQBuilder<Address, Double> classifyBuilder = new ConcurrentQBuilder<>();
|
||||
ConcurrentQ<Address, Double> classifyQ = classifyBuilder.setThreadPool(threadPool)
|
||||
.setCollectResults(true)
|
||||
.setMonitor(monitor)
|
||||
.build(new FunctionStartCallback(randomForest, preBytes, initialBytes,
|
||||
modelRow.getIncludeBitLevelFeatures(), program, target));
|
||||
classifyQ.add(addresses.getAddresses(true));
|
||||
Collection<QResult<Address, Double>> results = Collections.emptyList();
|
||||
long start = System.nanoTime();
|
||||
try {
|
||||
results = classifyQ.waitForResults();
|
||||
}
|
||||
catch (InterruptedException e) {
|
||||
monitor.checkCanceled();
|
||||
Msg.error(this, "Exception while classifying functions: " + e.getMessage());
|
||||
}
|
||||
long end = System.nanoTime();
|
||||
Msg.info(this, String.format("Classification time: %g seconds",
|
||||
(end - start) / RandomForestTrainingTask.NANOSECONDS_PER_SECOND));
|
||||
Map<Address, Double> addrToProb = new HashMap<>();
|
||||
for (QResult<Address, Double> result : results) {
|
||||
Double score = null;
|
||||
try {
|
||||
score = result.getResult();
|
||||
}
|
||||
catch (Exception e) {
|
||||
Msg.error(this, "Problem getting score of Address " + result.getItem() + ": " +
|
||||
e.getMessage());
|
||||
continue;
|
||||
}
|
||||
if (score != null && score >= probabilityThreshold) {
|
||||
addrToProb.put(result.getItem(), score);
|
||||
}
|
||||
}
|
||||
return addrToProb;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,347 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import java.math.BigInteger;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import ghidra.program.model.address.*;
|
||||
import ghidra.program.model.lang.Register;
|
||||
import ghidra.program.model.listing.*;
|
||||
import ghidra.util.exception.CancelledException;
|
||||
import ghidra.util.task.TaskMonitor;
|
||||
|
||||
/**
|
||||
* This is a container class for the parameters that determine what data is collected for
|
||||
* training random forests to recognize function starts.
|
||||
*/
|
||||
|
||||
public class FunctionStartRFParams {
|
||||
|
||||
private List<Integer> preBytes; //number of bytes before a function start
|
||||
private List<Integer> initialBytes; //number of bytes after (and including) the first byte
|
||||
private List<Integer> samplingFactors; //how many non-starts to gather for each start gathered
|
||||
private int minFuncSize; //minimum size of a function to mine
|
||||
private int maxStarts; //maximum number of function starts to gather
|
||||
private Program trainingSource;
|
||||
private AddressSet funcEntries;
|
||||
private AddressSet funcInteriors;
|
||||
private int instructionAlignment;
|
||||
private boolean includePrecedingAndFollowing;
|
||||
private boolean includeBitFeatures;
|
||||
|
||||
//the following two lists are related: they must have the same size, and the ith
|
||||
//entry in contextRegisterVals is the value of the ith element of contextRegisterNames
|
||||
private List<String> contextRegisterNames; //names of context registers we care about
|
||||
private List<BigInteger> contextRegisterVals; //values of context registers we care about
|
||||
|
||||
/**
|
||||
* Constructs a FunctionStartRFParams object, a container object for data gathering
|
||||
* parameters to use when training a random forest to recognize function starts.
|
||||
*
|
||||
* <p> Note that {@code randomSeed} is initialized to a random value.
|
||||
*
|
||||
* <p>Use setter methods to set the fields.
|
||||
* @param trainingSource source program
|
||||
*/
|
||||
public FunctionStartRFParams(Program trainingSource) {
|
||||
this.trainingSource = trainingSource;
|
||||
preBytes = Collections.emptyList();
|
||||
initialBytes = Collections.emptyList();
|
||||
samplingFactors = Collections.emptyList();
|
||||
contextRegisterNames = Collections.emptyList();
|
||||
contextRegisterVals = Collections.emptyList();
|
||||
instructionAlignment = trainingSource.getLanguage().getInstructionAlignment();
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return the number of bytes to gather before an address
|
||||
*/
|
||||
public List<Integer> getPreBytes() {
|
||||
return preBytes;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param preBytes the number of bytes to gather before an address
|
||||
*/
|
||||
public void setPreBytes(List<Integer> preBytes) {
|
||||
this.preBytes = preBytes;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return the number of bytes to gather after (and including) an address
|
||||
*/
|
||||
public List<Integer> getInitialBytes() {
|
||||
return initialBytes;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param initialBytes the number of bytes to gather after (and including) an address
|
||||
*/
|
||||
public void setInitialBytes(List<Integer> initialBytes) {
|
||||
this.initialBytes = initialBytes;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return the minimum size a function must be to have its data gathered
|
||||
*/
|
||||
public int getMinFuncSize() {
|
||||
return minFuncSize;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param minFuncSize the minimum size a function must be to have its data gathered
|
||||
*/
|
||||
public void setMinFuncSize(int minFuncSize) {
|
||||
this.minFuncSize = minFuncSize;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return the maximum number of function starts to gather
|
||||
*/
|
||||
public int getMaxStarts() {
|
||||
return maxStarts;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param max the maximum number of function starts to gather
|
||||
*/
|
||||
public void setMaxStarts(int max) {
|
||||
maxStarts = max;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return the number of non-starts to gather per function start
|
||||
*/
|
||||
public List<Integer> getSamplingFactors() {
|
||||
return samplingFactors;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param factors the number of non-starts to gather per function start
|
||||
*/
|
||||
public void setFactors(List<Integer> factors) {
|
||||
samplingFactors = factors;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true precisely when there is a least one context register value set.
|
||||
* @return true if there are any context register values set
|
||||
*/
|
||||
public boolean isRestrictedByContext() {
|
||||
return !contextRegisterNames.isEmpty();
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return the list of names of context registers to set before disassembly
|
||||
*/
|
||||
public List<String> getContextRegisterNames() {
|
||||
return contextRegisterNames;
|
||||
}
|
||||
|
||||
/**
|
||||
* The values to assign to the context registers.
|
||||
* @return context register values
|
||||
*/
|
||||
public List<BigInteger> getContextRegisterVals() {
|
||||
return contextRegisterVals;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses register,value pairs if the form creg1=x,creg2=from csv and stores them. Any
|
||||
* existing register,value pairs are discarded.
|
||||
*
|
||||
* @param csv the list to parse
|
||||
* @throws IllegalArgumentException if there are any parsing errors
|
||||
*/
|
||||
public void setRegistersAndValues(String csv) {
|
||||
contextRegisterNames = new ArrayList<>();
|
||||
contextRegisterVals = new ArrayList<>();
|
||||
String[] parts = csv.split(",");
|
||||
for (String part : parts) {
|
||||
String[] regValPair = part.split("=");
|
||||
if (regValPair.length != 2) {
|
||||
contextRegisterNames.clear();
|
||||
contextRegisterVals.clear();
|
||||
throw new IllegalArgumentException("Error parsing register=value string " + part);
|
||||
}
|
||||
String regName = regValPair[0].trim();
|
||||
if (trainingSource.getRegister(regName) == null) {
|
||||
contextRegisterNames.clear();
|
||||
contextRegisterVals.clear();
|
||||
throw new IllegalArgumentException(
|
||||
"Register " + regName + " not found for program " + trainingSource.getName());
|
||||
}
|
||||
contextRegisterNames.add(regName);
|
||||
BigInteger bigInt = new BigInteger(regValPair[1].trim());
|
||||
contextRegisterVals.add(bigInt);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses a CSV into a sorted list of distinct integer values (duplicates are ignored). Returns
|
||||
* an empty list of a parse error is encountered.
|
||||
* @param csv csv string to parse
|
||||
* @return sorted list
|
||||
*/
|
||||
public static List<Integer> parseIntegerCSV(String csv) {
|
||||
if (StringUtils.isBlank(csv)) {
|
||||
throw new IllegalArgumentException("Entry cannot be blank");
|
||||
}
|
||||
String trimmed = csv.trim();
|
||||
if (trimmed.startsWith(",") || trimmed.endsWith(",")) {
|
||||
throw new IllegalArgumentException("String must not begin or end with a comma");
|
||||
}
|
||||
Set<Integer> results = new HashSet<>();
|
||||
String[] parts = trimmed.split(",");
|
||||
for (String part : parts) {
|
||||
Integer i = Integer.decode(part.trim());
|
||||
if (i < 0) {
|
||||
throw new IllegalArgumentException(
|
||||
"Invalid element " + part + " - must be non-negative");
|
||||
}
|
||||
results.add(i);
|
||||
}
|
||||
return results.stream().sorted().collect(Collectors.toList());
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the {@link AddressSet} of function entries in the source program.
|
||||
* <P>
|
||||
* NB: Invoke {@link FunctionStartRFParams#computeFuncEntriesAndInteriors} before
|
||||
* invoking this method.
|
||||
* @return set of entries
|
||||
*/
|
||||
public AddressSet getFuncEntries() {
|
||||
return funcEntries;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the {@link AddressSet} of function interiors in the source program.
|
||||
* <P>
|
||||
* NB: Invoke {@link FunctionStartRFParams#computeFuncEntriesAndInteriors} before
|
||||
* invoking this method.
|
||||
* @return set of interiors
|
||||
*/
|
||||
public AddressSet getFuncInteriors() {
|
||||
return funcInteriors;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns boolean indicating whether code units immediately preceding and
|
||||
* following a function start should be included in the training set.
|
||||
* @return include preceding and following
|
||||
*/
|
||||
public boolean getIncludePrecedingAndFollowing() {
|
||||
return includePrecedingAndFollowing;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets boolean indicating whether code units immediately preceding and
|
||||
* following a function start should be included in the training set.
|
||||
* @param b new value
|
||||
*/
|
||||
public void setIncludePrecedingAndFollowing(boolean b) {
|
||||
includePrecedingAndFollowing = b;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns boolean indicating whether to include bit-level features in the feature vectors.
|
||||
* @return include bit level features
|
||||
*/
|
||||
public boolean getIncludeBitFeatures() {
|
||||
return includeBitFeatures;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets boolean indicating whether to include bit-level features in the feature vectors.
|
||||
* @param b new value
|
||||
*/
|
||||
public void setIncludeBitFeatures(boolean b) {
|
||||
includeBitFeatures = b;
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the {@link AddressSet}s of function entries and bodies in the source program.
|
||||
* Retrieve these sets via {@link FunctionStartRFParams#getFuncEntries()} and
|
||||
* {@link FunctionStartRFParams#getFuncInteriors()}
|
||||
* <p> Note: the interior of a function only contains addresses which are aligned relative
|
||||
* to the instruction alignment of the processor
|
||||
* @param monitor task monitor
|
||||
* @throws CancelledException if monitor is canceled
|
||||
*/
|
||||
public void computeFuncEntriesAndInteriors(TaskMonitor monitor) throws CancelledException {
|
||||
FunctionIterator fIter = trainingSource.getFunctionManager().getFunctions(true);
|
||||
monitor.initialize(trainingSource.getFunctionManager().getFunctionCount());
|
||||
funcEntries = new AddressSet();
|
||||
funcInteriors = new AddressSet();
|
||||
while (fIter.hasNext()) {
|
||||
monitor.checkCanceled();
|
||||
Function func = fIter.next();
|
||||
monitor.incrementProgress(1);
|
||||
if (func.getBody().getNumAddresses() < minFuncSize) {
|
||||
continue;
|
||||
}
|
||||
if (isRestrictedByContext()) {
|
||||
if (!isContextCompatible(func.getEntryPoint())) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
funcEntries.add(func.getEntryPoint());
|
||||
AddressSet body = func.getBody().subtract(new AddressSet(func.getEntryPoint()));
|
||||
AddressIterator addrIter = body.getAddresses(true);
|
||||
while (addrIter.hasNext()) {
|
||||
Address addr = addrIter.next();
|
||||
if (addr.getOffset() % instructionAlignment == 0) {
|
||||
funcInteriors.add(addr);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks whether {@code addr} is consistent with context register values
|
||||
* supplied via {@link FunctionStartRFParams#setRegistersAndValues(String)}
|
||||
* @param addr address to check
|
||||
* @return is consistent with context regs
|
||||
*/
|
||||
public boolean isContextCompatible(Address addr) {
|
||||
ProgramContext context = trainingSource.getProgramContext();
|
||||
for (int i = 0; i < contextRegisterNames.size(); i++) {
|
||||
Register reg = context.getRegister(contextRegisterNames.get(i));
|
||||
BigInteger val = context.getValue(reg, addr, false);
|
||||
if (!val.equals(contextRegisterVals.get(i))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,654 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import java.awt.BorderLayout;
|
||||
import java.io.IOException;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.LongStream;
|
||||
|
||||
import javax.swing.*;
|
||||
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import docking.DialogComponentProvider;
|
||||
import docking.action.DockingAction;
|
||||
import docking.action.builder.ActionBuilder;
|
||||
import docking.widgets.combobox.GComboBox;
|
||||
import docking.widgets.label.GDLabel;
|
||||
import docking.widgets.table.GTable;
|
||||
import docking.widgets.table.threaded.GThreadedTablePanel;
|
||||
import docking.widgets.textfield.IntegerTextField;
|
||||
import ghidra.app.services.ProgramManager;
|
||||
import ghidra.framework.main.DataTreeDialog;
|
||||
import ghidra.framework.preferences.Preferences;
|
||||
import ghidra.program.model.address.AddressSet;
|
||||
import ghidra.program.model.listing.*;
|
||||
import ghidra.util.HelpLocation;
|
||||
import ghidra.util.Msg;
|
||||
import ghidra.util.exception.CancelledException;
|
||||
import ghidra.util.exception.VersionException;
|
||||
import ghidra.util.layout.PairLayout;
|
||||
import ghidra.util.table.SelectionNavigationAction;
|
||||
import ghidra.util.table.actions.MakeProgramSelectionAction;
|
||||
import ghidra.util.task.*;
|
||||
|
||||
/**
|
||||
* This class creates a dialog window for the user to enter data mining parameters
|
||||
* for learning function starts, train models, see performance statistics, and
|
||||
* apply the models.
|
||||
*/
|
||||
public class FunctionStartRFParamsDialog extends DialogComponentProvider {
|
||||
|
||||
private static final String INITIAL_BYTES_TEXT = "Number of Initial Bytes (CSV)";
|
||||
private static final String INITIAL_BYTES_TIP =
|
||||
"Number of initial bytes of a function to record";
|
||||
|
||||
private static final String PRE_BYTES_TEXT = "Number of Pre-bytes (CSV)";
|
||||
private static final String PRE_BYTES_TIP =
|
||||
"Number of bytes immediately before a function start to record";
|
||||
|
||||
private static final String MIN_FUNC_SIZE_TEXT = "Minimum Function Size";
|
||||
private static final String MIN_FUNC_SIZE_TIP =
|
||||
"Functions whose size in bytes are below this number are skipped";
|
||||
|
||||
private static final String CONTEXT_REGISTER_TEXT = "Context Registers and Values (CSV)";
|
||||
private static final String CONTEXT_REGISTER_TIP =
|
||||
"Restrict gathering to functions where context registers have been set." +
|
||||
" Form: cReg1=x,cReg2=y,...";
|
||||
|
||||
private static final String FACTOR_TEXT = "Start to Non-start Sampling Factors (CSV)";
|
||||
private static final String FACTOR_TIP =
|
||||
"Number of non-starts to gather for each function start";
|
||||
|
||||
private static final String MAX_STARTS_TEXT = "Maximum Number of Starts";
|
||||
private static final String MAX_STARTS_TIP = "Maximum number of function starts to gather";
|
||||
|
||||
private static final String INCLUDE_PRECEDING_FOLLOWING_TEXT =
|
||||
"Include Preceding and Following";
|
||||
private static final String INCLUDE_PRECEDING_FOLLOWING_TIP =
|
||||
"Include code units immediately before and after a function start when testing and training";
|
||||
|
||||
private static final String INCLUDE_BIT_FEATURES_TEXT = "Include Bit Features";
|
||||
private static final String INCLUDE_BIT_FEATURES_TIP =
|
||||
"Include bit-level features. May improve models; will increase computation time.";
|
||||
|
||||
private static final String FUNCTIONS_MEETING_SIZE_TEXT = "Functions Meeting Size Bound";
|
||||
private static final String FUNCTIONS_MEETING_SIZE_TIP =
|
||||
"Number of functions meeting the size " + "bound";
|
||||
|
||||
private static final String RESTRICT_SEARCH_TEXT = "Restrict Search To Aligned Addresses ";
|
||||
private static final String RESTRICT_SEARCH_TIP =
|
||||
"Only apply model to aligned addresses. NOTE:" + "Does not affect training or test sets!";
|
||||
|
||||
private static final String ALIGNMENT_MODULUS_TEXT = "Alignment Modulus";
|
||||
private static final String ALIGNMENT_MODULUS_TIP =
|
||||
"Use to define the alignment for restricted search";
|
||||
|
||||
private static final String DEFAULT_INITIAL_BYTES = "8,16";
|
||||
private static final String INITIAL_BYTES_PROPERTY = "functionStartRFParams_initialBytes";
|
||||
private static final String DEFAULT_PRE_BYTES = "2,8";
|
||||
private static final String PRE_BYTES_PROPERTY = "functionStartRFParams_preBytes";
|
||||
private static final String DEFAULT_MINIMUM_FUNCTION_SIZE = "16";
|
||||
private static final String MIN_FUNC_SIZE_PROPERTY = "functionStartRFParams_minFuncSize";
|
||||
private static final String DEFAULT_CONTEXT_REGISTERS = "";
|
||||
private static final String CONTEXT_REGISTER_PROPERTY = "functionStartRFParams_cRegs";
|
||||
private static final String DEFAULT_FACTOR = "10,50";
|
||||
private static final String FACTOR_PROPERTY = "functionStartRFParams_factor";
|
||||
private static final String DEFAULT_MAX_STARTS = "1000";
|
||||
private static final String MAX_STARTS_PROPERTY = "functionStartRFParams_maxStarts";
|
||||
private static final String INCLUDE_PRECEDING_AND_FOLLOWING_PROPERTY =
|
||||
"functionStartRFParams_use_pf";
|
||||
private static final String DEFAULT_INCLUDE_PF = "True";
|
||||
private static final String INCLUDE_BIT_FEATURES_PROPERTY =
|
||||
"funcstionStartRFParams_includeBitFeatures";
|
||||
private static final String DEFAULT_INCLUDE_BIT_FEATURES = "False";
|
||||
|
||||
private static final String DATA_GATHERING_PARAMETERS = "Data Gathering Parameters";
|
||||
private static final String MODEL_STATISTICS = "Model Statistics";
|
||||
private static final String TITLE = "Random Forest Function Finder";
|
||||
private static final String FUNCTION_INFO = "Function Information";
|
||||
|
||||
private static final String APPLY_MODEL_ACTION_NAME = "ApplyModel";
|
||||
private static final String APPLY_MODEL_MENU_TEXT = "Apply Model";
|
||||
private static final String APPLY_MODEL_TO_ACTION_NAME = "ApplyModelTo";
|
||||
private static final String APPLY_MODEL_TO_MENU_TEXT = "Apply Model To...";
|
||||
private static final String DEBUG_MODEL_ACTION_NAME = "DebugModel";
|
||||
private static final String DEBUG_MODEL_MENU_TEXT = "DEBUG - Show test set errors";
|
||||
|
||||
private JTextField initialBytesField;
|
||||
private JTextField preBytesField;
|
||||
private JTextField factorField;
|
||||
private IntegerTextField minimumSizeField;
|
||||
private IntegerTextField maxStartsField;
|
||||
private JTextField contextRegistersField;
|
||||
private JLabel numFuncsField;
|
||||
private JScrollPane tableScrollPane;
|
||||
private JPanel funcInfoPanel;
|
||||
|
||||
private RandomForestFunctionFinderPlugin plugin;
|
||||
private List<RandomForestRowObject> rowObjects;
|
||||
private RandomForestTableModel tableModel;
|
||||
private Program trainingSource;
|
||||
private FunctionStartRFParams params;
|
||||
private Set<Program> openPrograms;
|
||||
private Vector<Long> moduli = new Vector<>(Arrays.asList(new Long[] { 4l, 8l, 16l, 32l }));
|
||||
private GComboBox<Long> modBox;
|
||||
private JButton trainButton;
|
||||
private JCheckBox includeBeforeAndAfterBox;
|
||||
private JCheckBox includeBitFeaturesBox;
|
||||
private JCheckBox restrictBox;
|
||||
private JButton restoreDefaultsButton;
|
||||
|
||||
/**
|
||||
* Creates a dialog for training models to find function starts using the
|
||||
* current program of {@code plugin}.
|
||||
* @param plugin plugin owning this dialog
|
||||
*/
|
||||
public FunctionStartRFParamsDialog(RandomForestFunctionFinderPlugin plugin) {
|
||||
super(TITLE + ": " + plugin.getCurrentProgram().getDomainFile().getPathname(), false, true,
|
||||
true, true);
|
||||
this.plugin = plugin;
|
||||
rowObjects = new ArrayList<>();
|
||||
openPrograms = new HashSet<>();
|
||||
trainingSource = plugin.getCurrentProgram();
|
||||
JPanel panel = createPanel();
|
||||
addWorkPanel(panel);
|
||||
trainButton = addTrainModelsButton();
|
||||
addHideDialogButton();
|
||||
addRestoreDefaultsButton();
|
||||
setHelpLocation(new HelpLocation(plugin.getName(), plugin.getName()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void taskCompleted(Task task) {
|
||||
super.taskCompleted(task);
|
||||
setStatusText("Training Completed");
|
||||
setEnabled(true);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void taskCancelled(Task task) {
|
||||
super.taskCancelled(task);
|
||||
setStatusText("Training Canceled");
|
||||
setEnabled(true);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the program used to train the models
|
||||
* @return source program
|
||||
*/
|
||||
Program getTrainingSource() {
|
||||
return trainingSource;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void dismissCallback() {
|
||||
TaskMonitorComponent monitorComp = getTaskMonitorComponent();
|
||||
if (monitorComp != null) {
|
||||
monitorComp.cancel();
|
||||
}
|
||||
setStatusText("");
|
||||
//rows in the table can have large objects as fields
|
||||
//make sure that memory is reclaimed
|
||||
tableModel.dispose();
|
||||
rowObjects.clear();
|
||||
plugin.resetDialog();
|
||||
dispose();
|
||||
}
|
||||
|
||||
private FunctionStartRFParams getMachineLearningParams() {
|
||||
FunctionStartRFParams rfParams = new FunctionStartRFParams(trainingSource);
|
||||
List<Integer> preBytes = FunctionStartRFParams.parseIntegerCSV(preBytesField.getText());
|
||||
rfParams.setPreBytes(preBytes);
|
||||
List<Integer> initialBytes =
|
||||
FunctionStartRFParams.parseIntegerCSV(initialBytesField.getText());
|
||||
rfParams.setInitialBytes(initialBytes);
|
||||
List<Integer> factors = FunctionStartRFParams.parseIntegerCSV(factorField.getText());
|
||||
rfParams.setFactors(factors);
|
||||
int minSize = minimumSizeField.getIntValue();
|
||||
if (minSize <= 0) {
|
||||
Msg.showWarn(this, null, "Invalid Minimum Size", "Minimum size must be positive!");
|
||||
return null;
|
||||
}
|
||||
rfParams.setMinFuncSize(minSize);
|
||||
|
||||
int maxStarts = maxStartsField.getIntValue();
|
||||
if (maxStarts <= 0) {
|
||||
Msg.showWarn(this, null, "Invalid Max Starts", "Max Starts must be positive!");
|
||||
return null;
|
||||
}
|
||||
rfParams.setMaxStarts(maxStarts);
|
||||
|
||||
String csv = contextRegistersField.getText();
|
||||
if (!StringUtils.isBlank(csv)) {
|
||||
try {
|
||||
rfParams.setRegistersAndValues(csv);
|
||||
}
|
||||
catch (IllegalArgumentException e) {
|
||||
Msg.showError(factors, null, "Context Register/Value Error", e);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
rfParams.setIncludePrecedingAndFollowing(includeBeforeAndAfterBox.isSelected());
|
||||
rfParams.setIncludeBitFeatures(includeBitFeaturesBox.isSelected());
|
||||
setProperties();
|
||||
return rfParams;
|
||||
}
|
||||
|
||||
private void trainModelsCallback() {
|
||||
rowObjects.clear();
|
||||
tableModel.reload();
|
||||
params = getMachineLearningParams();
|
||||
if (params == null) {
|
||||
return;
|
||||
}
|
||||
RandomForestTrainingTask trainingTask = new RandomForestTrainingTask(trainingSource, params,
|
||||
r -> tableModel.addObject(r), plugin.getTestMaxSize());
|
||||
trainingTask.addTaskListener(this);
|
||||
setEnabled(false);
|
||||
executeProgressTask(trainingTask, 500);
|
||||
}
|
||||
|
||||
private JButton addTrainModelsButton() {
|
||||
JButton trainModelsButton = new JButton("Train");
|
||||
trainModelsButton.setToolTipText("Train models using the specified parameters");
|
||||
trainModelsButton.addActionListener(e -> trainModelsCallback());
|
||||
addButton(trainModelsButton);
|
||||
return trainModelsButton;
|
||||
}
|
||||
|
||||
private JPanel createPanel() {
|
||||
JPanel mainPanel = new JPanel(new BorderLayout());
|
||||
|
||||
tableModel = new RandomForestTableModel(plugin.getTool(), rowObjects);
|
||||
GThreadedTablePanel<RandomForestRowObject> evalPanel =
|
||||
new GThreadedTablePanel<>(tableModel);
|
||||
GTable modelStatsTable = evalPanel.getTable();
|
||||
modelStatsTable.setSelectionMode(ListSelectionModel.SINGLE_SELECTION);
|
||||
evalPanel.setBorder(BorderFactory.createTitledBorder(MODEL_STATISTICS));
|
||||
mainPanel.add(evalPanel, BorderLayout.EAST);
|
||||
|
||||
DockingAction applyAction = new ActionBuilder(APPLY_MODEL_ACTION_NAME, plugin.getName())
|
||||
.description("Apply Model to Source Program")
|
||||
.popupWhen(c -> trainingSource != null)
|
||||
.enabledWhen(c -> tableModel.getLastSelectedObjects().size() == 1)
|
||||
.popupMenuPath(APPLY_MODEL_MENU_TEXT)
|
||||
.inWindow(ActionBuilder.When.ALWAYS)
|
||||
.onAction(c -> {
|
||||
searchTrainingProgram(tableModel.getLastSelectedObjects().get(0));
|
||||
})
|
||||
.build();
|
||||
addAction(applyAction);
|
||||
|
||||
DockingAction applyToAction =
|
||||
new ActionBuilder(APPLY_MODEL_TO_ACTION_NAME, plugin.getName())
|
||||
.description("Choose Program and Apply Model to it")
|
||||
.popupWhen(c -> trainingSource != null)
|
||||
.enabledWhen(c -> tableModel.getLastSelectedObjects().size() == 1)
|
||||
.popupMenuPath(APPLY_MODEL_TO_MENU_TEXT)
|
||||
.inWindow(ActionBuilder.When.ALWAYS)
|
||||
.onAction(c -> {
|
||||
searchOtherProgram(tableModel.getLastSelectedObjects().get(0));
|
||||
})
|
||||
.build();
|
||||
addAction(applyToAction);
|
||||
|
||||
DockingAction checkAction = new ActionBuilder(DEBUG_MODEL_ACTION_NAME, plugin.getName())
|
||||
.description("Show Test Set Errors")
|
||||
.popupWhen(c -> trainingSource != null)
|
||||
.enabledWhen(c -> tableModel.getLastSelectedObjects().size() == 1)
|
||||
.popupMenuPath(DEBUG_MODEL_MENU_TEXT)
|
||||
.inWindow(ActionBuilder.When.ALWAYS)
|
||||
.onAction(c -> {
|
||||
showTestErrors(tableModel.getLastSelectedObjects().get(0));
|
||||
})
|
||||
.build();
|
||||
addAction(checkAction);
|
||||
|
||||
JPanel paramsPanel = new JPanel();
|
||||
paramsPanel.setBorder(BorderFactory.createTitledBorder(DATA_GATHERING_PARAMETERS));
|
||||
PairLayout pairLayout = new PairLayout();
|
||||
paramsPanel.setLayout(pairLayout);
|
||||
|
||||
JLabel preLabel = new GDLabel(PRE_BYTES_TEXT);
|
||||
preLabel.setToolTipText(PRE_BYTES_TIP);
|
||||
paramsPanel.add(preLabel);
|
||||
preBytesField = new JTextField();
|
||||
String preBytes = Preferences.getProperty(PRE_BYTES_PROPERTY, DEFAULT_PRE_BYTES);
|
||||
preBytesField.setText(preBytes);
|
||||
paramsPanel.add(preBytesField);
|
||||
|
||||
JLabel initialLabel = new GDLabel(INITIAL_BYTES_TEXT);
|
||||
initialLabel.setToolTipText(INITIAL_BYTES_TIP);
|
||||
paramsPanel.add(initialLabel);
|
||||
initialBytesField = new JTextField();
|
||||
String initialBytes =
|
||||
Preferences.getProperty(INITIAL_BYTES_PROPERTY, DEFAULT_INITIAL_BYTES);
|
||||
initialBytesField.setText(initialBytes);
|
||||
paramsPanel.add(initialBytesField);
|
||||
|
||||
JLabel factorLabel = new GDLabel(FACTOR_TEXT);
|
||||
factorLabel.setToolTipText(FACTOR_TIP);
|
||||
paramsPanel.add(factorLabel);
|
||||
factorField = new JTextField();
|
||||
String factor = Preferences.getProperty(FACTOR_PROPERTY, DEFAULT_FACTOR);
|
||||
factorField.setText(factor);
|
||||
paramsPanel.add(factorField);
|
||||
|
||||
JLabel maxStartsLabel = new GDLabel(MAX_STARTS_TEXT);
|
||||
maxStartsLabel.setToolTipText(MAX_STARTS_TIP);
|
||||
paramsPanel.add(maxStartsLabel);
|
||||
maxStartsField = new IntegerTextField();
|
||||
String maxStarts = Preferences.getProperty(MAX_STARTS_PROPERTY, DEFAULT_MAX_STARTS);
|
||||
maxStartsField.setValue(Integer.parseInt(maxStarts));
|
||||
paramsPanel.add(maxStartsField.getComponent());
|
||||
|
||||
JLabel contextLabel = new GDLabel(CONTEXT_REGISTER_TEXT);
|
||||
contextLabel.setToolTipText(CONTEXT_REGISTER_TIP);
|
||||
paramsPanel.add(contextLabel);
|
||||
contextRegistersField = new JTextField(DEFAULT_CONTEXT_REGISTERS);
|
||||
String cRegs = Preferences.getProperty(CONTEXT_REGISTER_PROPERTY, "");
|
||||
if (!StringUtils.isEmpty(cRegs)) {
|
||||
contextRegistersField.setText(cRegs);
|
||||
}
|
||||
paramsPanel.add(contextRegistersField);
|
||||
|
||||
JLabel includeBeforeAndAfterLabel = new JLabel(INCLUDE_PRECEDING_FOLLOWING_TEXT);
|
||||
includeBeforeAndAfterLabel.setToolTipText(INCLUDE_PRECEDING_FOLLOWING_TIP);
|
||||
includeBeforeAndAfterBox = new JCheckBox();
|
||||
String defaultUseBeforeAfter =
|
||||
Preferences.getProperty(INCLUDE_PRECEDING_AND_FOLLOWING_PROPERTY, DEFAULT_INCLUDE_PF);
|
||||
includeBeforeAndAfterBox.setSelected(Boolean.valueOf(defaultUseBeforeAfter));
|
||||
paramsPanel.add(includeBeforeAndAfterLabel);
|
||||
paramsPanel.add(includeBeforeAndAfterBox);
|
||||
|
||||
JLabel includeSelectionLabel = new JLabel(INCLUDE_BIT_FEATURES_TEXT);
|
||||
includeSelectionLabel.setToolTipText(INCLUDE_BIT_FEATURES_TIP);
|
||||
includeBitFeaturesBox = new JCheckBox();
|
||||
String defaultIncludeBitFeatures =
|
||||
Preferences.getProperty(INCLUDE_BIT_FEATURES_PROPERTY, DEFAULT_INCLUDE_BIT_FEATURES);
|
||||
includeBitFeaturesBox.setSelected(Boolean.valueOf(defaultIncludeBitFeatures));
|
||||
paramsPanel.add(includeSelectionLabel);
|
||||
paramsPanel.add(includeBitFeaturesBox);
|
||||
|
||||
JLabel minFuncLabel = new GDLabel(MIN_FUNC_SIZE_TEXT);
|
||||
minFuncLabel.setToolTipText(MIN_FUNC_SIZE_TIP);
|
||||
paramsPanel.add(minFuncLabel);
|
||||
minimumSizeField = new IntegerTextField();
|
||||
String minSize =
|
||||
Preferences.getProperty(MIN_FUNC_SIZE_PROPERTY, DEFAULT_MINIMUM_FUNCTION_SIZE);
|
||||
minimumSizeField.setValue(Integer.parseInt(minSize));
|
||||
minimumSizeField.addChangeListener(e -> {
|
||||
updateNumFuncsField();
|
||||
updateModulusTable();
|
||||
});
|
||||
paramsPanel.add(minimumSizeField.getComponent());
|
||||
|
||||
JPanel funcDataPanel = new JPanel();
|
||||
pairLayout = new PairLayout();
|
||||
funcDataPanel.setLayout(pairLayout);
|
||||
|
||||
JLabel numFuncsLabel = new GDLabel(FUNCTIONS_MEETING_SIZE_TEXT);
|
||||
numFuncsLabel.setToolTipText(FUNCTIONS_MEETING_SIZE_TIP);
|
||||
funcDataPanel.add(numFuncsLabel);
|
||||
numFuncsField = new GDLabel();
|
||||
updateNumFuncsField();
|
||||
funcDataPanel.add(numFuncsField);
|
||||
|
||||
JLabel restrictLabel = new GDLabel(RESTRICT_SEARCH_TEXT);
|
||||
restrictLabel.setToolTipText(RESTRICT_SEARCH_TIP);
|
||||
funcDataPanel.add(restrictLabel);
|
||||
restrictBox = new JCheckBox();
|
||||
funcDataPanel.add(restrictBox);
|
||||
|
||||
JLabel modulusLabel = new GDLabel(ALIGNMENT_MODULUS_TEXT);
|
||||
modulusLabel.setToolTipText(ALIGNMENT_MODULUS_TIP);
|
||||
funcDataPanel.add(modulusLabel);
|
||||
modBox = new GComboBox<>(moduli);
|
||||
modBox.setSelectedItem(Long.valueOf(16));
|
||||
modBox.addActionListener(e -> updateModulusTable());
|
||||
funcDataPanel.add(modBox);
|
||||
|
||||
tableScrollPane = getFuncAlignmentScrollPane();
|
||||
|
||||
funcInfoPanel = new JPanel();
|
||||
funcInfoPanel.setLayout(new BorderLayout());
|
||||
funcInfoPanel.setBorder(BorderFactory.createTitledBorder(FUNCTION_INFO));
|
||||
funcInfoPanel.add(funcDataPanel, BorderLayout.NORTH);
|
||||
funcInfoPanel.add(tableScrollPane, BorderLayout.CENTER);
|
||||
|
||||
JPanel infoPanel = new JPanel(new BorderLayout());
|
||||
infoPanel.add(paramsPanel, BorderLayout.NORTH);
|
||||
infoPanel.add(funcInfoPanel, BorderLayout.CENTER);
|
||||
mainPanel.add(infoPanel, BorderLayout.WEST);
|
||||
return mainPanel;
|
||||
}
|
||||
|
||||
private JScrollPane getFuncAlignmentScrollPane() {
|
||||
Long modulus = (Long) modBox.getSelectedItem();
|
||||
int minSize = minimumSizeField.getIntValue();
|
||||
//initialize map
|
||||
Map<Long, Long> countMap =
|
||||
LongStream.range(0, modulus).boxed().collect(Collectors.toMap(i -> i, i -> 0l));
|
||||
FunctionIterator fIter = trainingSource.getFunctionManager().getFunctionsNoStubs(true);
|
||||
//needs some thought to determine how to display number of functions compatible
|
||||
//with context register restrictions
|
||||
for (Function f : fIter) {
|
||||
if ((f.getBody().getNumAddresses() >= minSize) &&
|
||||
(params == null || params.isContextCompatible(f.getEntryPoint()))) {
|
||||
countMap.merge(f.getEntryPoint().getOffset() % modulus, 1l, Long::sum);
|
||||
}
|
||||
}
|
||||
List<FunctionStartAlignmentRowObject> rows = countMap.entrySet()
|
||||
.stream()
|
||||
.map(e -> new FunctionStartAlignmentRowObject(e.getKey(), e.getValue()))
|
||||
.collect(Collectors.toList());
|
||||
FunctionStartAlignmentTableModel alignModel = new FunctionStartAlignmentTableModel(rows);
|
||||
GTable alignTable = new GTable(alignModel);
|
||||
return new JScrollPane(alignTable);
|
||||
}
|
||||
|
||||
private void updateModulusTable() {
|
||||
funcInfoPanel.remove(tableScrollPane);
|
||||
tableScrollPane = getFuncAlignmentScrollPane();
|
||||
|
||||
funcInfoPanel.add(tableScrollPane);
|
||||
getComponent().updateUI();
|
||||
}
|
||||
|
||||
private void updateNumFuncsField() {
|
||||
int numFuncs = 0;
|
||||
long bound = minimumSizeField.getLongValue();
|
||||
for (Function func : trainingSource.getFunctionManager().getFunctionsNoStubs(true)) {
|
||||
if (func.getBody().getNumAddresses() >= bound) {
|
||||
numFuncs += 1;
|
||||
}
|
||||
}
|
||||
numFuncsField.setText(Integer.toString(numFuncs));
|
||||
}
|
||||
|
||||
private void searchTrainingProgram(RandomForestRowObject modelRow) {
|
||||
searchProgram(trainingSource, modelRow);
|
||||
}
|
||||
|
||||
private void searchOtherProgram(RandomForestRowObject modelRow) {
|
||||
Program p = selectProgram();
|
||||
if (p == null) {
|
||||
return;
|
||||
}
|
||||
ProgramManager pm = plugin.getTool().getService(ProgramManager.class);
|
||||
pm.openProgram(p, ProgramManager.OPEN_VISIBLE);
|
||||
searchProgram(p, modelRow);
|
||||
}
|
||||
|
||||
private void showTestErrors(RandomForestRowObject modelRow) {
|
||||
FunctionStartTableProvider provider = new FunctionStartTableProvider(plugin, trainingSource,
|
||||
modelRow.getTestErrors(), modelRow, true);
|
||||
addGeneralActions(provider);
|
||||
}
|
||||
|
||||
private void searchProgram(Program prog, RandomForestRowObject modelRow) {
|
||||
GetAddressesToClassifyTask getTask =
|
||||
new GetAddressesToClassifyTask(prog, plugin.getMinUndefinedRangeSize());
|
||||
//don't want to use the dialog's progress bar
|
||||
TaskLauncher.launchModal("Gathering Addresses To Classify", getTask);
|
||||
if (getTask.isCancelled()) {
|
||||
return;
|
||||
}
|
||||
AddressSet execNonFunc = null;
|
||||
if (restrictBox.isSelected()) {
|
||||
execNonFunc = getTask.getAddressesToClassify((long) modBox.getSelectedItem());
|
||||
}
|
||||
else {
|
||||
execNonFunc = getTask.getAddressesToClassify();
|
||||
}
|
||||
FunctionStartTableProvider provider =
|
||||
new FunctionStartTableProvider(plugin, prog, execNonFunc, modelRow, false);
|
||||
addGeneralActions(provider);
|
||||
DisassembleFunctionStartsAction disassembleAction = null;
|
||||
if (params.isRestrictedByContext()) {
|
||||
disassembleAction = new DisassembleAndApplyContextAction(plugin, prog,
|
||||
provider.getTable(), provider.getTableModel());
|
||||
}
|
||||
else {
|
||||
disassembleAction = new DisassembleFunctionStartsAction(plugin, prog,
|
||||
provider.getTable(), provider.getTableModel());
|
||||
}
|
||||
plugin.getTool().addLocalAction(provider, disassembleAction);
|
||||
CreateFunctionsAction createActions =
|
||||
new CreateFunctionsAction(plugin, prog, provider.getTable(), provider.getTableModel());
|
||||
plugin.getTool().addLocalAction(provider, createActions);
|
||||
|
||||
}
|
||||
|
||||
private void addGeneralActions(FunctionStartTableProvider provider) {
|
||||
plugin.addProvider(provider);
|
||||
DockingAction programSelectAction =
|
||||
new MakeProgramSelectionAction(plugin, provider.getTable());
|
||||
programSelectAction.setEnabled(true);
|
||||
plugin.getTool().addLocalAction(provider, programSelectAction);
|
||||
DockingAction selectNavigationAction =
|
||||
new SelectionNavigationAction(plugin, provider.getTable());
|
||||
plugin.getTool().addLocalAction(provider, selectNavigationAction);
|
||||
ShowSimilarStartsAction similarStarts = new ShowSimilarStartsAction(plugin,
|
||||
plugin.getCurrentProgram(), provider.getTable(), provider.getTableModel());
|
||||
plugin.getTool().addLocalAction(provider, similarStarts);
|
||||
}
|
||||
|
||||
private Program selectProgram() {
|
||||
DataTreeDialog dtd = new DataTreeDialog(null, "Select Program", DataTreeDialog.OPEN, f -> {
|
||||
Class<?> c = f.getDomainObjectClass();
|
||||
return Program.class.isAssignableFrom(c);
|
||||
});
|
||||
dtd.show();
|
||||
if (dtd.wasCancelled()) {
|
||||
return null;
|
||||
}
|
||||
Program otherProgram = null;
|
||||
try {
|
||||
otherProgram = (Program) dtd.getDomainFile()
|
||||
.getDomainObject(plugin, true, true, getTaskMonitorComponent());
|
||||
openPrograms.add(otherProgram);
|
||||
}
|
||||
catch (VersionException | CancelledException | IOException e) {
|
||||
return null;
|
||||
}
|
||||
if (isProgramCompatible(otherProgram)) {
|
||||
return otherProgram;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
//checks whether otherProgram contains any specified context registers
|
||||
//at some point might be worth adding more restrictions
|
||||
private boolean isProgramCompatible(Program otherProgram) {
|
||||
if (params == null) {
|
||||
//shouldn't happen
|
||||
throw new IllegalStateException("null params");
|
||||
}
|
||||
if (!params.isRestrictedByContext()) {
|
||||
return true;
|
||||
}
|
||||
for (String regName : params.getContextRegisterNames()) {
|
||||
if (otherProgram.getRegister(regName) == null) {
|
||||
Msg.showError(this, null, "Error Applying Model", "Program " +
|
||||
otherProgram.getName() + " does not have a context register named " + regName);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private void setEnabled(boolean b) {
|
||||
minimumSizeField.setEnabled(b);
|
||||
initialBytesField.setEnabled(b);
|
||||
preBytesField.setEnabled(b);
|
||||
factorField.setEnabled(b);
|
||||
minimumSizeField.setEnabled(b);
|
||||
maxStartsField.setEnabled(b);
|
||||
trainButton.setEnabled(b);
|
||||
contextRegistersField.setEnabled(b);
|
||||
includeBeforeAndAfterBox.setEnabled(b);
|
||||
includeBitFeaturesBox.setEnabled(b);
|
||||
restoreDefaultsButton.setEnabled(b);
|
||||
}
|
||||
|
||||
private void addRestoreDefaultsButton() {
|
||||
restoreDefaultsButton = new JButton("Restore Defaults");
|
||||
restoreDefaultsButton.setToolTipText("Restore training parameters to the default values");
|
||||
restoreDefaultsButton.addActionListener(e -> restoreDefaults());
|
||||
addButton(restoreDefaultsButton);
|
||||
}
|
||||
|
||||
private void addHideDialogButton() {
|
||||
JButton hideDialogButton = new JButton("Hide Dialog");
|
||||
hideDialogButton.setToolTipText("Hide Dialog (does not cancel training or destory models)");
|
||||
hideDialogButton.addActionListener(e -> close());
|
||||
addButton(hideDialogButton);
|
||||
}
|
||||
|
||||
private void restoreDefaults() {
|
||||
initialBytesField.setText(DEFAULT_INITIAL_BYTES);
|
||||
preBytesField.setText(DEFAULT_PRE_BYTES);
|
||||
minimumSizeField.setValue((Integer.parseInt(DEFAULT_MINIMUM_FUNCTION_SIZE)));
|
||||
maxStartsField.setValue(Integer.parseInt(DEFAULT_MAX_STARTS));
|
||||
factorField.setText(DEFAULT_FACTOR);
|
||||
includeBeforeAndAfterBox.setSelected(Boolean.valueOf(DEFAULT_INCLUDE_PF));
|
||||
includeBitFeaturesBox.setSelected(Boolean.valueOf(DEFAULT_INCLUDE_BIT_FEATURES));
|
||||
contextRegistersField.setText(null);
|
||||
setProperties();
|
||||
}
|
||||
|
||||
private void setProperties() {
|
||||
Preferences.setProperty(INITIAL_BYTES_PROPERTY, initialBytesField.getText());
|
||||
Preferences.setProperty(PRE_BYTES_PROPERTY, preBytesField.getText());
|
||||
Preferences.setProperty(MIN_FUNC_SIZE_PROPERTY, minimumSizeField.getText());
|
||||
Preferences.setProperty(MAX_STARTS_PROPERTY, maxStartsField.getText());
|
||||
Preferences.setProperty(FACTOR_PROPERTY, factorField.getText());
|
||||
if (StringUtils.isBlank(contextRegistersField.getText())) {
|
||||
Preferences.removeProperty(CONTEXT_REGISTER_PROPERTY);
|
||||
}
|
||||
else {
|
||||
Preferences.setProperty(CONTEXT_REGISTER_PROPERTY, contextRegistersField.getText());
|
||||
}
|
||||
Preferences.setProperty(INCLUDE_PRECEDING_AND_FOLLOWING_PROPERTY,
|
||||
Boolean.toString(includeBeforeAndAfterBox.isSelected()));
|
||||
Preferences.setProperty(DEFAULT_INCLUDE_BIT_FEATURES,
|
||||
Boolean.toString(includeBitFeaturesBox.isSelected()));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,168 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import ghidra.program.model.address.Address;
|
||||
import ghidra.program.model.listing.Program;
|
||||
import ghidra.program.model.symbol.*;
|
||||
|
||||
/**
|
||||
* Represents a row in table showing the probabilities that addresses are function starts
|
||||
*/
|
||||
public class FunctionStartRowObject {
|
||||
private Address address;
|
||||
private double probability;
|
||||
private Interpretation currentInter;
|
||||
private int numDataRefs;
|
||||
private int numUnconditionalFlowRefs;
|
||||
private int numConditionalFlowRefs;
|
||||
|
||||
/**
|
||||
* Creates a row showing that {@link Address} {@code address} has probability
|
||||
* {@code probability} of being a function start. Use setter methods to set the other
|
||||
* entries in the row.
|
||||
* @param address address
|
||||
* @param probability prob of being function start
|
||||
*/
|
||||
public FunctionStartRowObject(Address address, double probability) {
|
||||
this.address = address;
|
||||
this.probability = probability;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the address.
|
||||
* @return address
|
||||
*/
|
||||
public Address getAddress() {
|
||||
return address;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the probability.
|
||||
* @return probability
|
||||
*/
|
||||
public double getProbability() {
|
||||
return probability;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the {@link Interpretation}
|
||||
* @return interpretation
|
||||
*/
|
||||
public Interpretation getCurrentInterpretation() {
|
||||
return currentInter;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the {@link Interpretation}
|
||||
* @param inter interpretation
|
||||
*/
|
||||
public void setCurrentInterpretation(Interpretation inter) {
|
||||
currentInter = inter;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the number of data references
|
||||
* @return num data refs
|
||||
*/
|
||||
public int getNumDataRefs() {
|
||||
return numDataRefs;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the number of data references
|
||||
* @param numRefs num data refs
|
||||
*/
|
||||
public void setNumDataRefs(int numRefs) {
|
||||
numDataRefs = numRefs;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the number of unconditional flow references
|
||||
* @return num unconditional flow refs
|
||||
*/
|
||||
public int getNumUnconditionalFlowRefs() {
|
||||
return numUnconditionalFlowRefs;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the number of unconditional flow references
|
||||
* @param numRefs num unconditional flow refs
|
||||
*/
|
||||
public void setNumUnconditionalFlowRefs(int numRefs) {
|
||||
numUnconditionalFlowRefs = numRefs;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the number of conditional flow references
|
||||
* @param numConditionalFlowRefs num conditional refs
|
||||
*/
|
||||
public void setNumConditionalFlowRefs(int numConditionalFlowRefs) {
|
||||
this.numConditionalFlowRefs = numConditionalFlowRefs;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the number of conditional flow references
|
||||
* @return num conditional flow refs
|
||||
*/
|
||||
public int getNumConditionalFlowRefs() {
|
||||
return numConditionalFlowRefs;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return address.hashCode();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
return address.equals(o);
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines and sets the number of data, conditional flow, and unconditional flow references
|
||||
* to the addresses corresponding to {@code rowObject}
|
||||
* @param rowObject row
|
||||
* @param program source program
|
||||
*/
|
||||
public static void setReferenceData(FunctionStartRowObject rowObject, Program program) {
|
||||
int numUnconditionalFlowRefs = 0;
|
||||
int numConditionalFlowRefs = 0;
|
||||
int numDataRefs = 0;
|
||||
ReferenceIterator refIter =
|
||||
program.getReferenceManager().getReferencesTo(rowObject.getAddress());
|
||||
while (refIter.hasNext()) {
|
||||
Reference ref = refIter.next();
|
||||
RefType type = ref.getReferenceType();
|
||||
if (type instanceof DataRefType) {
|
||||
numDataRefs++;
|
||||
continue;
|
||||
}
|
||||
if (type instanceof FlowType) {
|
||||
if (type.isConditional()) {
|
||||
numConditionalFlowRefs++;
|
||||
}
|
||||
else {
|
||||
numUnconditionalFlowRefs++;
|
||||
}
|
||||
}
|
||||
}
|
||||
rowObject.setNumDataRefs(numDataRefs);
|
||||
rowObject.setNumUnconditionalFlowRefs(numUnconditionalFlowRefs);
|
||||
rowObject.setNumConditionalFlowRefs(numConditionalFlowRefs);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
/* ###
|
||||
* IP: LICENSE
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import ghidra.framework.plugintool.ServiceProvider;
|
||||
import ghidra.program.model.address.Address;
|
||||
import ghidra.program.model.listing.Program;
|
||||
import ghidra.util.table.ProgramLocationTableRowMapper;
|
||||
|
||||
public class FunctionStartRowObjectToAddressTableRowMapper
|
||||
extends ProgramLocationTableRowMapper<FunctionStartRowObject, Address> {
|
||||
|
||||
@Override
|
||||
public Address map(FunctionStartRowObject rowObject, Program data,
|
||||
ServiceProvider serviceProvider) {
|
||||
return rowObject.getAddress();
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import ghidra.framework.plugintool.ServiceProvider;
|
||||
import ghidra.program.model.listing.Function;
|
||||
import ghidra.program.model.listing.Program;
|
||||
import ghidra.util.table.ProgramLocationTableRowMapper;
|
||||
|
||||
public class FunctionStartRowObjectToFunctionTableRowMapper
|
||||
extends ProgramLocationTableRowMapper<FunctionStartRowObject, Function> {
|
||||
|
||||
@Override
|
||||
public Function map(FunctionStartRowObject rowObject, Program data,
|
||||
ServiceProvider serviceProvider) {
|
||||
return data.getFunctionManager().getFunctionAt(rowObject.getAddress());
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import ghidra.framework.plugintool.ServiceProvider;
|
||||
import ghidra.program.model.listing.Program;
|
||||
import ghidra.program.util.ProgramLocation;
|
||||
import ghidra.util.table.ProgramLocationTableRowMapper;
|
||||
|
||||
public class FunctionStartRowObjectToProgramLocationTableRowMapper
|
||||
extends ProgramLocationTableRowMapper<FunctionStartRowObject, ProgramLocation> {
|
||||
|
||||
@Override
|
||||
public ProgramLocation map(FunctionStartRowObject rowObject, Program program,
|
||||
ServiceProvider serviceProvider) {
|
||||
return new ProgramLocation(program, rowObject.getAddress());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,221 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
|
||||
import docking.widgets.table.AbstractDynamicTableColumn;
|
||||
import docking.widgets.table.TableColumnDescriptor;
|
||||
import ghidra.docking.settings.Settings;
|
||||
import ghidra.framework.plugintool.PluginTool;
|
||||
import ghidra.framework.plugintool.ServiceProvider;
|
||||
import ghidra.program.model.address.Address;
|
||||
import ghidra.program.model.address.AddressSet;
|
||||
import ghidra.program.model.block.BasicBlockModel;
|
||||
import ghidra.program.model.listing.Program;
|
||||
import ghidra.util.datastruct.Accumulator;
|
||||
import ghidra.util.exception.CancelledException;
|
||||
import ghidra.util.table.AddressBasedTableModel;
|
||||
import ghidra.util.task.TaskMonitor;
|
||||
|
||||
/**
|
||||
* A {@link AddressBasedTableModel} used to display information about addresses which are
|
||||
* likely to be function starts.
|
||||
*/
|
||||
public class FunctionStartTableModel extends AddressBasedTableModel<FunctionStartRowObject> {
|
||||
private RandomForestRowObject modelRow;
|
||||
private AddressSet addressesToClassify;
|
||||
private boolean debug;
|
||||
private BasicBlockModel blockModel;
|
||||
private Map<Address, Double> addressToProbability;
|
||||
|
||||
/**
|
||||
* Creates a table to display address likely to be function starts. If {@code debug}
|
||||
* is {@code true}, the table will contain a row for each address in
|
||||
* {@code addressesToClassify}. Otherwise it will only contain rows for addresses whose
|
||||
* associated probability is >= 0.5
|
||||
* @param plugin owning plugin
|
||||
* @param program source program
|
||||
* @param toClassify addresses to search
|
||||
* @param modelRow trained model info
|
||||
* @param debug is table displaying debug data
|
||||
*/
|
||||
public FunctionStartTableModel(PluginTool plugin, Program program, AddressSet toClassify,
|
||||
RandomForestRowObject modelRow, boolean debug) {
|
||||
super(program.getName(), plugin, program, null, false);
|
||||
this.modelRow = modelRow;
|
||||
this.addressesToClassify = toClassify;
|
||||
this.debug = debug;
|
||||
blockModel = new BasicBlockModel(program);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Address getAddress(int row) {
|
||||
return getRowObject(row).getAddress();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doLoad(Accumulator<FunctionStartRowObject> accumulator, TaskMonitor monitor)
|
||||
throws CancelledException {
|
||||
//table might change as users create functions/disassemble code, so this method
|
||||
//may be called repeatedly. However, the probabilities don't change, so only
|
||||
//compute the probability that an address is a function start once.
|
||||
if (addressToProbability == null) {
|
||||
FunctionStartClassifier classifier = new FunctionStartClassifier(program, modelRow,
|
||||
RandomForestFunctionFinderPlugin.FUNC_START);
|
||||
//if debug, we want to display all errors in the test set
|
||||
//if we left the prob threshold at .5 we would not see true function starts
|
||||
//that the models thinks are not function starts
|
||||
if (debug) {
|
||||
classifier.setProbabilityThreshold(0.0);
|
||||
}
|
||||
addressToProbability = classifier.classify(addressesToClassify, monitor);
|
||||
}
|
||||
for (Entry<Address, Double> entry : addressToProbability.entrySet()) {
|
||||
Address addr = entry.getKey();
|
||||
FunctionStartRowObject rowObject = new FunctionStartRowObject(addr, entry.getValue());
|
||||
setInterpretation(rowObject, monitor);
|
||||
FunctionStartRowObject.setReferenceData(rowObject, program);
|
||||
accumulator.add(rowObject);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TableColumnDescriptor<FunctionStartRowObject> createTableColumnDescriptor() {
|
||||
TableColumnDescriptor<FunctionStartRowObject> descriptor = new TableColumnDescriptor<>();
|
||||
descriptor.addVisibleColumn(new AddressTableColumn());
|
||||
descriptor.addVisibleColumn(new ProbabilityTableColumn(), 0, false);
|
||||
descriptor.addVisibleColumn(new InterpretationTableColumn());
|
||||
descriptor.addVisibleColumn(new DataReferencesTableColumn());
|
||||
descriptor.addVisibleColumn(new UnconditionalFlowReferencesTableColumn());
|
||||
descriptor.addVisibleColumn(new ConditionalFlowReferencesTableColumn());
|
||||
return descriptor;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the {@link RandomForestRowObject} corresponding to this row
|
||||
* @return row object
|
||||
*/
|
||||
RandomForestRowObject getRandomForestRowObject() {
|
||||
return modelRow;
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines and sets he {@link Interpretation} of the address corresponding to
|
||||
* {@code rowObject}.
|
||||
* @param rowObject row of table
|
||||
* @param monitor monitor
|
||||
* @throws CancelledException if user cancels
|
||||
*/
|
||||
void setInterpretation(FunctionStartRowObject rowObject, TaskMonitor monitor)
|
||||
throws CancelledException {
|
||||
Interpretation inter =
|
||||
Interpretation.getInterpretation(program, rowObject.getAddress(), blockModel, monitor);
|
||||
rowObject.setCurrentInterpretation(inter);
|
||||
}
|
||||
|
||||
private class AddressTableColumn
|
||||
extends AbstractDynamicTableColumn<FunctionStartRowObject, Address, Object> {
|
||||
|
||||
@Override
|
||||
public String getColumnName() {
|
||||
return "Address";
|
||||
}
|
||||
|
||||
@Override
|
||||
public Address getValue(FunctionStartRowObject rowObject, Settings settings, Object data,
|
||||
ServiceProvider services) throws IllegalArgumentException {
|
||||
return rowObject.getAddress();
|
||||
}
|
||||
}
|
||||
|
||||
private class ProbabilityTableColumn
|
||||
extends AbstractDynamicTableColumn<FunctionStartRowObject, Double, Object> {
|
||||
|
||||
@Override
|
||||
public String getColumnName() {
|
||||
return "Probability";
|
||||
}
|
||||
|
||||
@Override
|
||||
public Double getValue(FunctionStartRowObject rowObject, Settings settings, Object data,
|
||||
ServiceProvider services) throws IllegalArgumentException {
|
||||
return rowObject.getProbability();
|
||||
}
|
||||
}
|
||||
|
||||
private class InterpretationTableColumn
|
||||
extends AbstractDynamicTableColumn<FunctionStartRowObject, Interpretation, Object> {
|
||||
|
||||
@Override
|
||||
public String getColumnName() {
|
||||
return "Interpretation";
|
||||
}
|
||||
|
||||
@Override
|
||||
public Interpretation getValue(FunctionStartRowObject rowObject, Settings settings,
|
||||
Object data, ServiceProvider services) throws IllegalArgumentException {
|
||||
return rowObject.getCurrentInterpretation();
|
||||
}
|
||||
}
|
||||
|
||||
private class DataReferencesTableColumn
|
||||
extends AbstractDynamicTableColumn<FunctionStartRowObject, Integer, Object> {
|
||||
|
||||
@Override
|
||||
public String getColumnName() {
|
||||
return "Data Refs";
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getValue(FunctionStartRowObject rowObject, Settings settings, Object data,
|
||||
ServiceProvider services) throws IllegalArgumentException {
|
||||
return rowObject.getNumDataRefs();
|
||||
}
|
||||
}
|
||||
|
||||
private class UnconditionalFlowReferencesTableColumn
|
||||
extends AbstractDynamicTableColumn<FunctionStartRowObject, Integer, Object> {
|
||||
|
||||
@Override
|
||||
public String getColumnName() {
|
||||
return "Unconditional Flow Refs";
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getValue(FunctionStartRowObject rowObject, Settings settings, Object data,
|
||||
ServiceProvider services) throws IllegalArgumentException {
|
||||
return rowObject.getNumUnconditionalFlowRefs();
|
||||
}
|
||||
}
|
||||
|
||||
private class ConditionalFlowReferencesTableColumn
|
||||
extends AbstractDynamicTableColumn<FunctionStartRowObject, Integer, Object> {
|
||||
|
||||
@Override
|
||||
public String getColumnName() {
|
||||
return "Conditional Flow Refs";
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getValue(FunctionStartRowObject rowObject, Settings settings, Object data,
|
||||
ServiceProvider services) throws IllegalArgumentException {
|
||||
return rowObject.getNumConditionalFlowRefs();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,172 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import java.awt.*;
|
||||
|
||||
import javax.swing.*;
|
||||
|
||||
import ghidra.app.services.GoToService;
|
||||
import ghidra.framework.model.*;
|
||||
import ghidra.framework.plugintool.ComponentProviderAdapter;
|
||||
import ghidra.program.model.address.AddressSet;
|
||||
import ghidra.program.model.listing.Program;
|
||||
import ghidra.program.util.ChangeManager;
|
||||
import ghidra.util.HelpLocation;
|
||||
import ghidra.util.table.*;
|
||||
|
||||
/**
|
||||
* A {@link ComponentProviderAdapter} for displaying tables of addresses that are likely
|
||||
* function starts
|
||||
*/
|
||||
public class FunctionStartTableProvider extends ProgramAssociatedComponentProviderAdapter
|
||||
implements DomainObjectListener {
|
||||
|
||||
private JComponent component;
|
||||
private FunctionStartTableModel model;
|
||||
private RandomForestFunctionFinderPlugin plugin;
|
||||
private Program program;
|
||||
private RandomForestRowObject modelRow;
|
||||
private AddressSet toClassify;
|
||||
private boolean debug;
|
||||
private String subTitle;
|
||||
private GhidraTable startTable;
|
||||
private GhidraThreadedTablePanel<FunctionStartRowObject> tablePanel;
|
||||
|
||||
/**
|
||||
* Constructs a table provider for the table of addresses to classify. If {@code debug}
|
||||
* is true, a debug version of the table will be created.
|
||||
*
|
||||
* @param plugin owning plugin
|
||||
* @param program program containing addresses to classify
|
||||
* @param toClassify addresses to classify
|
||||
* @param modelRow model to apply
|
||||
* @param debug whether to display debug version of table
|
||||
*/
|
||||
public FunctionStartTableProvider(RandomForestFunctionFinderPlugin plugin, Program program,
|
||||
AddressSet toClassify, RandomForestRowObject modelRow, boolean debug) {
|
||||
super(
|
||||
debug ? "Debug: Test Set Errors in " + program.getDomainFile().getPathname()
|
||||
: "Potential Functions in " + program.getDomainFile().getPathname(),
|
||||
plugin.getName(), program, plugin);
|
||||
this.program = program;
|
||||
this.plugin = plugin;
|
||||
this.modelRow = modelRow;
|
||||
this.toClassify = toClassify;
|
||||
this.debug = debug;
|
||||
subTitle = "Pre-bytes:" + modelRow.getNumPreBytes() + " Initial bytes:" +
|
||||
modelRow.getNumInitialBytes() + " Sampling Factor:" + modelRow.getSamplingFactor();
|
||||
component = build();
|
||||
program.addListener(this);
|
||||
String anchor = debug ? "DebugModelTable" : "FunctionStartTable";
|
||||
setHelpLocation(new HelpLocation(plugin.getName(), anchor));
|
||||
}
|
||||
|
||||
@Override
|
||||
public JComponent getComponent() {
|
||||
return component;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void domainObjectChanged(DomainObjectChangedEvent ev) {
|
||||
if (!isVisible()) {
|
||||
return;
|
||||
}
|
||||
if (ev.containsEvent(DomainObject.DO_OBJECT_RESTORED)) {
|
||||
model.reload();
|
||||
contextChanged();
|
||||
}
|
||||
for (int i = 0; i < ev.numRecords(); ++i) {
|
||||
DomainObjectChangeRecord doRecord = ev.getChangeRecord(i);
|
||||
int eventType = doRecord.getEventType();
|
||||
switch (eventType) {
|
||||
case ChangeManager.DOCR_FUNCTION_ADDED:
|
||||
case ChangeManager.DOCR_FUNCTION_REMOVED:
|
||||
case ChangeManager.DOCR_CODE_ADDED:
|
||||
case ChangeManager.DOCR_CODE_REMOVED:
|
||||
case ChangeManager.DOCR_CODE_REPLACED:
|
||||
case ChangeManager.DOCR_MEM_REF_TYPE_CHANGED:
|
||||
case ChangeManager.DOCR_MEM_REFERENCE_ADDED:
|
||||
case ChangeManager.DOCR_MEM_REFERENCE_REMOVED:
|
||||
model.reload();
|
||||
contextChanged();
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the underlying {@link GhidraTable}
|
||||
* @return table
|
||||
*/
|
||||
GhidraTable getTable() {
|
||||
return startTable;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the table model of this provider.
|
||||
* @return table model
|
||||
*/
|
||||
FunctionStartTableModel getTableModel() {
|
||||
return model;
|
||||
}
|
||||
|
||||
private JComponent build() {
|
||||
JPanel panel = new JPanel(new BorderLayout());
|
||||
Component table = buildTablePanel();
|
||||
panel.add(table, BorderLayout.CENTER);
|
||||
return panel;
|
||||
}
|
||||
|
||||
private Component buildTablePanel() {
|
||||
model = new FunctionStartTableModel(plugin.getTool(), program, toClassify, modelRow, debug);
|
||||
tablePanel = new GhidraThreadedTablePanel<>(model, 1000);
|
||||
startTable = tablePanel.getTable();
|
||||
startTable.setName("Potential Functions in " + model.getProgram().getName());
|
||||
|
||||
GoToService goToService = tool.getService(GoToService.class);
|
||||
if (goToService != null) {
|
||||
startTable.installNavigation(goToService, goToService.getDefaultNavigatable());
|
||||
}
|
||||
startTable.setNavigateOnSelectionEnabled(true);
|
||||
startTable.setAutoResizeMode(JTable.AUTO_RESIZE_SUBSEQUENT_COLUMNS);
|
||||
startTable.setPreferredScrollableViewportSize(new Dimension(900, 300));
|
||||
startTable.setRowSelectionAllowed(true);
|
||||
startTable.getSelectionModel().addListSelectionListener(e -> tool.contextChanged(this));
|
||||
|
||||
model.addTableModelListener(e -> {
|
||||
int rowCount = model.getRowCount();
|
||||
int unfilteredCount = model.getUnfilteredRowCount();
|
||||
|
||||
StringBuilder buffy = new StringBuilder();
|
||||
|
||||
buffy.append(" ").append(rowCount).append(" items");
|
||||
if (rowCount != unfilteredCount) {
|
||||
buffy.append(" (of ").append(unfilteredCount).append(" )");
|
||||
}
|
||||
|
||||
setSubTitle(subTitle + buffy.toString());
|
||||
});
|
||||
|
||||
JPanel container = new JPanel(new BorderLayout());
|
||||
container.add(tablePanel, BorderLayout.CENTER);
|
||||
var tableFilterPanel = new GhidraTableFilterPanel<>(startTable, model);
|
||||
container.add(tableFilterPanel, BorderLayout.SOUTH);
|
||||
|
||||
return container;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,100 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import ghidra.program.model.address.*;
|
||||
import ghidra.program.model.listing.*;
|
||||
import ghidra.util.exception.CancelledException;
|
||||
import ghidra.util.task.Task;
|
||||
import ghidra.util.task.TaskMonitor;
|
||||
|
||||
/**
|
||||
* A {@link Task} for gathering addresses to feed to a function start classifier
|
||||
*/
|
||||
|
||||
public class GetAddressesToClassifyTask extends Task {
|
||||
private Program prog;
|
||||
private AddressSet execNonFunc;
|
||||
private long minUndefinedRangeSize;
|
||||
|
||||
/**
|
||||
* Creates a {@link Task} that creates a set of addresses to check for function starts. The
|
||||
* {code minUndefinedRangeSize} parameter determines how large a run of undefined bytes must be
|
||||
* to be checked for function starts
|
||||
* @param prog source program
|
||||
* @param minUndefinedRangeSize minimum size of undefined range
|
||||
*/
|
||||
public GetAddressesToClassifyTask(Program prog, long minUndefinedRangeSize) {
|
||||
super("Gathering Addresses to Classify", true, true, false, false);
|
||||
this.prog = prog;
|
||||
this.minUndefinedRangeSize = minUndefinedRangeSize;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run(TaskMonitor monitor) throws CancelledException {
|
||||
execNonFunc = new AddressSet();
|
||||
AddressSetView executable = prog.getMemory().getExecuteSet();
|
||||
AddressSetView initialized = prog.getMemory().getLoadedAndInitializedAddressSet();
|
||||
execNonFunc = executable.intersect(initialized);
|
||||
monitor.initialize(prog.getFunctionManager().getFunctionCount());
|
||||
FunctionIterator fIter = prog.getFunctionManager().getFunctions(true);
|
||||
while (fIter.hasNext()) {
|
||||
monitor.checkCanceled();
|
||||
monitor.incrementProgress(1);
|
||||
Function func = fIter.next();
|
||||
execNonFunc = execNonFunc.subtract(func.getBody());
|
||||
}
|
||||
//remove small undefined ranges to avoid (for example) searching for
|
||||
//function starts in an address range of length 3 between two known
|
||||
//functions. "small" is controlled by a plugin option.
|
||||
AddressSetView undefinedRanges =
|
||||
prog.getListing().getUndefinedRanges(execNonFunc, true, monitor);
|
||||
AddressSet toRemove = new AddressSet();
|
||||
AddressRangeIterator iter = undefinedRanges.getAddressRanges(true);
|
||||
while (iter.hasNext()) {
|
||||
AddressRange range = iter.next();
|
||||
if (range.getLength() <= minUndefinedRangeSize) {
|
||||
toRemove.add(range);
|
||||
}
|
||||
}
|
||||
execNonFunc = execNonFunc.subtract(toRemove);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the set of addresses to classify
|
||||
* @return addresses
|
||||
*/
|
||||
public AddressSet getAddressesToClassify() {
|
||||
return execNonFunc;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the subsets of the addresses to classify consisting of all addresses
|
||||
* which are aligned relative to the given modulus.
|
||||
* @param modulus alignment modulus
|
||||
* @return aligned addresses
|
||||
*/
|
||||
public AddressSet getAddressesToClassify(long modulus) {
|
||||
AddressSet aligned = new AddressSet();
|
||||
for (Address a : execNonFunc.getAddresses(true)) {
|
||||
if (a.getOffset() % modulus == 0) {
|
||||
aligned.add(a);
|
||||
}
|
||||
}
|
||||
return aligned;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,99 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import ghidra.program.model.address.Address;
|
||||
import ghidra.program.model.block.BasicBlockModel;
|
||||
import ghidra.program.model.listing.*;
|
||||
import ghidra.util.exception.CancelledException;
|
||||
import ghidra.util.task.TaskMonitor;
|
||||
|
||||
/**
|
||||
* An enum representing possible interpretations of addresses
|
||||
* (e.g. data, undefined, function start ...)
|
||||
*/
|
||||
public enum Interpretation {
|
||||
UNDEFINED("Undefined"),
|
||||
DATA("Data"),
|
||||
OFFCUT("Offcut"),
|
||||
BLOCK_START("Block Start"),
|
||||
WITHIN_BLOCK("Within Block"),
|
||||
FUNCTION_START("Function Start"),
|
||||
//possibly want to refine this to block starts within functions
|
||||
//and within block within functions
|
||||
FUNCTION_INTERIOR("Function Interior");
|
||||
|
||||
private String display;
|
||||
|
||||
Interpretation(String display) {
|
||||
this.display = display;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return display;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the {@link Interpretation} for the given address in the given program.
|
||||
* @param program source program
|
||||
* @param addr address
|
||||
* @param monitor monitor
|
||||
* @return interpretation of addr
|
||||
* @throws CancelledException if user cancels monitor
|
||||
*/
|
||||
public static Interpretation getInterpretation(Program program, Address addr,
|
||||
TaskMonitor monitor) throws CancelledException {
|
||||
BasicBlockModel model = new BasicBlockModel(program);
|
||||
return getInterpretation(program, addr, model, monitor);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the {@link Interpretation} for the given address in the given program. This
|
||||
* method is intended to be called repeatedly in a loop, so it takes a {@link BasicBlockModel}
|
||||
* as a parameter (which only need be created once).
|
||||
* @param program source program
|
||||
* @param addr address in question
|
||||
* @param model block model
|
||||
* @param monitor task model
|
||||
* @return interpretation of addr
|
||||
* @throws CancelledException if user cancels monitor
|
||||
*/
|
||||
public static Interpretation getInterpretation(Program program, Address addr,
|
||||
BasicBlockModel model, TaskMonitor monitor) throws CancelledException {
|
||||
CodeUnit cu = program.getListing().getCodeUnitContaining(addr);
|
||||
if (cu instanceof Data) {
|
||||
if (((Data) cu).isDefined()) {
|
||||
return DATA;
|
||||
}
|
||||
return UNDEFINED;
|
||||
}
|
||||
if (program.getListing().getInstructionAt(addr) == null &&
|
||||
program.getListing().getInstructionContaining(addr) != null) {
|
||||
return OFFCUT;
|
||||
}
|
||||
if (program.getFunctionManager().getFunctionAt(addr) != null) {
|
||||
return FUNCTION_START;
|
||||
}
|
||||
if (program.getFunctionManager().getFunctionContaining(addr) != null) {
|
||||
return FUNCTION_INTERIOR;
|
||||
}
|
||||
if (model.getCodeBlockAt(addr, monitor) == null) {
|
||||
return WITHIN_BLOCK;
|
||||
}
|
||||
return Interpretation.BLOCK_START;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,227 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
import org.tribuo.Example;
|
||||
import org.tribuo.Feature;
|
||||
import org.tribuo.classification.Label;
|
||||
import org.tribuo.impl.ArrayExample;
|
||||
|
||||
import ghidra.program.model.address.*;
|
||||
import ghidra.program.model.listing.*;
|
||||
import ghidra.program.model.mem.MemoryAccessException;
|
||||
import ghidra.program.model.mem.MemoryBlock;
|
||||
import ghidra.util.Msg;
|
||||
import ghidra.util.exception.CancelledException;
|
||||
import ghidra.util.task.TaskMonitor;
|
||||
|
||||
/**
|
||||
* This is a utility class containing static methods used when creating training/test
|
||||
* sets to train models to recognize function starts
|
||||
*/
|
||||
public class ModelTrainingUtils {
|
||||
|
||||
public static final int MAX_PRECEDING_CODE_UNIT_SIZE = 8;
|
||||
public static final double ZERO = 0.0d;
|
||||
public static final double ONE = 1.0d;
|
||||
|
||||
//utility class
|
||||
private ModelTrainingUtils() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a feature vector consisting of byte-level and optionally bit-level features around
|
||||
* {@code address}
|
||||
* @param program source program
|
||||
* @param address address
|
||||
* @param numPreBytes number of bytes to use preceding {@code address}
|
||||
* @param numInitialBytes number of bytes to use including and after address
|
||||
* @param includeBitFeatures whether to include bit-level features
|
||||
* @return feature vector
|
||||
*/
|
||||
public static List<Feature> getFeatureVector(Program program, Address address, int numPreBytes,
|
||||
int numInitialBytes, boolean includeBitFeatures) {
|
||||
MemoryBlock block = program.getMemory().getBlock(address);
|
||||
byte[] preBytesArray = new byte[numPreBytes];
|
||||
byte[] initialBytesArray = new byte[numInitialBytes];
|
||||
List<Feature> trainingVector = new ArrayList<>();
|
||||
try {
|
||||
Address preStart = address.add(-numPreBytes);
|
||||
block.getBytes(preStart, preBytesArray);
|
||||
block.getBytes(address, initialBytesArray);
|
||||
}
|
||||
catch (MemoryAccessException | AddressOutOfBoundsException e) {
|
||||
//most likely an exception means that you are trying to read beyond a block
|
||||
//boundary. This will happen occasionally when the sliding window is near the
|
||||
//begining or end of a block.
|
||||
Msg.warn(RandomForestTrainingTask.class,
|
||||
"MemoryAccessException at " + address.toString());
|
||||
return trainingVector;
|
||||
}
|
||||
|
||||
for (int i = 0; i < numPreBytes; i++) {
|
||||
int currentByte = Byte.toUnsignedInt(preBytesArray[i]);
|
||||
trainingVector.add(new Feature("pbyte_" + i, currentByte));
|
||||
if (!includeBitFeatures) {
|
||||
continue;
|
||||
}
|
||||
for (int bit = 7; bit >= 0; bit--) {
|
||||
String featureName = "pbit_" + i + "_" + bit;
|
||||
double val = ((currentByte & (1 << bit)) > 0) ? ONE : ZERO;
|
||||
trainingVector.add(new Feature(featureName, val));
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < numInitialBytes; i++) {
|
||||
int currentByte = Byte.toUnsignedInt(initialBytesArray[i]);
|
||||
trainingVector.add(new Feature("ibyte_" + i, currentByte));
|
||||
if (!includeBitFeatures) {
|
||||
continue;
|
||||
}
|
||||
for (int bit = 7; bit >= 0; bit--) {
|
||||
String featureName = "ibit_" + i + "_" + bit;
|
||||
double val = ((currentByte & (1 << bit)) > 0) ? ONE : ZERO;
|
||||
trainingVector.add(new Feature(featureName, val));
|
||||
}
|
||||
}
|
||||
return trainingVector;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an {@link AddressSet} constructed as follows: for each {@link Address} {@code addr}
|
||||
* in {@code addresses}, add the {@link Address} of the {@link CodeUnit} returned by
|
||||
* {@link Listing#getCodeUnitAfter(Address)}
|
||||
* <p> Addresses which correspond to function starts are not added to the returned set.
|
||||
* @param program source program
|
||||
* @param addresses addresses to follow
|
||||
* @param monitor monitor
|
||||
* @return following addresses
|
||||
* @throws CancelledException if the monitor is canceled
|
||||
*/
|
||||
public static AddressSet getFollowingAddresses(Program program, AddressSetView addresses,
|
||||
TaskMonitor monitor) throws CancelledException {
|
||||
AddressSet following = new AddressSet();
|
||||
for (Address addr : addresses.getAddresses(true)) {
|
||||
monitor.checkCanceled();
|
||||
CodeUnit cu = program.getListing().getCodeUnitAfter(addr);
|
||||
if (cu == null) {
|
||||
continue;
|
||||
}
|
||||
if (program.getFunctionManager().getFunctionAt(cu.getAddress()) != null) {
|
||||
Msg.warn(ModelTrainingUtils.class,
|
||||
"Function start following " + addr.toString() + ", skipping...");
|
||||
continue;
|
||||
}
|
||||
following.add(cu.getAddress());
|
||||
}
|
||||
return following;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an {@link AddressSet} constructed as follows: for each {@link Address} {@code addr}
|
||||
* in {@code addresses}, add the {@link Address} of the {@link CodeUnit} returned by
|
||||
* {@link Listing#getCodeUnitBefore(Address)}
|
||||
* <p> Addresses which correspond to function starts are not added to the returned set.
|
||||
* Addresses of {@link CodeUnit}s which are more than
|
||||
* {@link ModelTrainingUtils#MAX_PRECEDING_CODE_UNIT_SIZE} bytes away from the addresses they
|
||||
* precede are also not added to the returned set.
|
||||
* @param program source program
|
||||
* @param addresses addresses to precede
|
||||
* @param monitor monitor
|
||||
* @return preceding addresses
|
||||
* @throws CancelledException if the monitor is canceled
|
||||
*/
|
||||
public static AddressSet getPrecedingAddresses(Program program, AddressSetView addresses,
|
||||
TaskMonitor monitor) throws CancelledException {
|
||||
AddressSet preceding = new AddressSet();
|
||||
for (Address addr : addresses.getAddresses(true)) {
|
||||
monitor.checkCanceled();
|
||||
CodeUnit cu = program.getListing().getCodeUnitBefore(addr);
|
||||
if (cu == null) {
|
||||
continue;
|
||||
}
|
||||
if (program.getFunctionManager().getFunctionAt(cu.getAddress()) != null) {
|
||||
Msg.warn(ModelTrainingUtils.class,
|
||||
"Function start preceding " + addr.toString() + ", skipping...");
|
||||
continue;
|
||||
}
|
||||
if (addr.getOffset() - cu.getAddress().getOffset() > MAX_PRECEDING_CODE_UNIT_SIZE) {
|
||||
continue;
|
||||
}
|
||||
preceding.add(cu.getAddress());
|
||||
}
|
||||
return preceding;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an {@link AddressSet} consisting of all {@link Address}es where data is defined
|
||||
* in {@code program}. Note that this includes addresses within defined data and not just
|
||||
* addresses where defined data starts.
|
||||
* @param program source program
|
||||
* @param monitor task monitor
|
||||
* @return addresses where data is defined
|
||||
* @throws CancelledException if monitor is canceled
|
||||
*/
|
||||
public static AddressSet getDefinedData(Program program, TaskMonitor monitor)
|
||||
throws CancelledException {
|
||||
DataIterator dataIter =
|
||||
program.getListing().getDefinedData(program.getMemory().getExecuteSet(), true);
|
||||
AddressSet definedData = new AddressSet();
|
||||
for (Data d : dataIter) {
|
||||
monitor.checkCanceled();
|
||||
definedData.add(
|
||||
new AddressRangeImpl(d.getAddress(), d.getAddress().add(d.getLength() - 1)));
|
||||
}
|
||||
return definedData;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a feature vector for every address in {@code source} and applies the {@Label}
|
||||
* {@code label}
|
||||
*
|
||||
* @param program program
|
||||
* @param source input addresses
|
||||
* @param label label to apply
|
||||
* @param numPreBytes bytes before address to include
|
||||
* @param numInitialBytes bytes after and including addresses
|
||||
* @param includeBitFeatures whether to include bit-level features
|
||||
* @param monitor monitor
|
||||
* @return list of vectors
|
||||
* @throws CancelledException if monitor is canceled
|
||||
*/
|
||||
public static List<Example<Label>> getVectorsFromAddresses(Program program,
|
||||
AddressSetView source, Label label, int numPreBytes, int numInitialBytes,
|
||||
boolean includeBitFeatures, TaskMonitor monitor) throws CancelledException {
|
||||
List<Example<Label>> examples = new ArrayList<>();
|
||||
monitor.initialize(source.getNumAddresses());
|
||||
Iterator<Address> addressIter = source.getAddresses(true);
|
||||
while (addressIter.hasNext()) {
|
||||
monitor.checkCanceled();
|
||||
Address addr = addressIter.next();
|
||||
monitor.incrementProgress(1L);
|
||||
List<Feature> trainingVector =
|
||||
getFeatureVector(program, addr, numPreBytes, numInitialBytes, includeBitFeatures);
|
||||
if (trainingVector.isEmpty()) {
|
||||
continue;
|
||||
}
|
||||
ArrayExample<Label> vec = new ArrayExample<>(label, trainingVector);
|
||||
examples.add(vec);
|
||||
}
|
||||
return examples;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import ghidra.framework.plugintool.ComponentProviderAdapter;
|
||||
import ghidra.program.model.listing.Program;
|
||||
|
||||
/**
|
||||
* This class is used by {@link RandomForestFunctionFinderPlugin}, which can create various
|
||||
* components for displaying potential function starts in a given program (and related info).
|
||||
* The purpose of this class is to facilitate closing all components associated with a program
|
||||
* when that program is closed.
|
||||
*/
|
||||
public abstract class ProgramAssociatedComponentProviderAdapter extends ComponentProviderAdapter {
|
||||
|
||||
private Program program;
|
||||
private RandomForestFunctionFinderPlugin plugin;
|
||||
|
||||
/**
|
||||
* Creates a {@link ComponentProviderAdapter} with an associated {@link Program} and {
|
||||
* {@link RandomForestFunctionFinderPlugin}
|
||||
* @param name name
|
||||
* @param owner owner
|
||||
* @param program associated program
|
||||
* @param plugin plugin
|
||||
*/
|
||||
public ProgramAssociatedComponentProviderAdapter(String name, String owner, Program program,
|
||||
RandomForestFunctionFinderPlugin plugin) {
|
||||
super(plugin.getTool(), name, owner);
|
||||
this.program = program;
|
||||
this.plugin = plugin;
|
||||
setTransient();
|
||||
setWindowMenuGroup("Search for Code and Functions");
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the associated program
|
||||
* @return the program
|
||||
*/
|
||||
Program getProgram() {
|
||||
return program;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void closeComponent() {
|
||||
plugin.removeProvider(this);
|
||||
super.closeComponent();
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,235 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
import org.tribuo.classification.Label;
|
||||
|
||||
import aQute.bnd.unmodifiable.Lists;
|
||||
import docking.action.builder.ActionBuilder;
|
||||
import docking.tool.ToolConstants;
|
||||
import ghidra.MiscellaneousPluginPackage;
|
||||
import ghidra.app.context.NavigatableActionContext;
|
||||
import ghidra.app.context.RestrictedAddressSetContext;
|
||||
import ghidra.app.events.ProgramClosedPluginEvent;
|
||||
import ghidra.app.events.ProgramLocationPluginEvent;
|
||||
import ghidra.app.plugin.PluginCategoryNames;
|
||||
import ghidra.app.plugin.ProgramPlugin;
|
||||
import ghidra.app.services.GoToService;
|
||||
import ghidra.framework.options.OptionsChangeListener;
|
||||
import ghidra.framework.options.ToolOptions;
|
||||
import ghidra.framework.plugintool.PluginInfo;
|
||||
import ghidra.framework.plugintool.PluginTool;
|
||||
import ghidra.framework.plugintool.util.PluginStatus;
|
||||
import ghidra.program.model.listing.Program;
|
||||
import ghidra.program.util.ProgramSelection;
|
||||
import ghidra.util.HelpLocation;
|
||||
import ghidra.util.Msg;
|
||||
import ghidra.util.bean.opteditor.OptionsVetoException;
|
||||
|
||||
//@formatter:off
|
||||
@PluginInfo(
|
||||
status = PluginStatus.UNSTABLE,
|
||||
packageName = MiscellaneousPluginPackage.NAME,
|
||||
category = PluginCategoryNames.ANALYSIS,
|
||||
shortDescription = "Function Finder",
|
||||
description = "Trains a random forest model to find function starts.",
|
||||
servicesRequired = { GoToService.class},
|
||||
eventsProduced = { ProgramLocationPluginEvent.class },
|
||||
eventsConsumed = { ProgramClosedPluginEvent.class}
|
||||
)
|
||||
//@formatter:on
|
||||
|
||||
/**
|
||||
* A {@link ProgramPlugin} for training a model on the starts of known functions in a
|
||||
* program and then using that model to look for more functions (in the source program or
|
||||
* another program selected by the user).
|
||||
*/
|
||||
public class RandomForestFunctionFinderPlugin extends ProgramPlugin
|
||||
implements OptionsChangeListener {
|
||||
|
||||
public static final Label FUNC_START = new Label("S");
|
||||
public static final Label NON_START = new Label("N");
|
||||
private static final String ACTION_NAME = "Search for Code and Functions";
|
||||
private static final String MENU_PATH_ENTRY = "For Code and Functions...";
|
||||
private static final String TEST_SET_MAX_SIZE_OPTION_NAME = "Maximum size of test sets";
|
||||
static final Long TEST_SET_MAX_SIZE_DEFAULT = 1000000l;
|
||||
private Long testSetMax;
|
||||
private static final String MIN_UNDEFINED_RANGE_SIZE_OPTION_NAME =
|
||||
"Minimum Length of Undefined Range to Search";
|
||||
static final Long MIN_UNDEFINED_RANGE_SIZE_DEFAULT = 16l;
|
||||
private Long minUndefinedRangeSize;
|
||||
|
||||
private FunctionStartRFParamsDialog paramsDialog;
|
||||
|
||||
//this map is used to close all providers associated with a program p
|
||||
//when p is closed
|
||||
private Map<Program, List<ProgramAssociatedComponentProviderAdapter>> programsToProviders;
|
||||
|
||||
/**
|
||||
* Creates the plugin for the given tool.
|
||||
* @param tool tool for plugin
|
||||
*/
|
||||
public RandomForestFunctionFinderPlugin(PluginTool tool) {
|
||||
super(tool, false, true);
|
||||
programsToProviders = new HashMap<>();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void init() {
|
||||
createActions();
|
||||
initOptions(getTool().getOptions("Random Forest Function Finder"));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void optionsChanged(ToolOptions options, String optionName, Object oldValue,
|
||||
Object newValue) throws OptionsVetoException {
|
||||
switch (optionName) {
|
||||
case TEST_SET_MAX_SIZE_OPTION_NAME:
|
||||
Long newMax = (Long) newValue;
|
||||
if (newMax <= 0) {
|
||||
//does this actually inform the user of the problem?
|
||||
throw new OptionsVetoException(
|
||||
TEST_SET_MAX_SIZE_OPTION_NAME + " must be positive!");
|
||||
}
|
||||
testSetMax = newMax;
|
||||
break;
|
||||
case MIN_UNDEFINED_RANGE_SIZE_OPTION_NAME:
|
||||
Long newMin = (Long) newValue;
|
||||
if (newMin <= 0) {
|
||||
throw new OptionsVetoException(
|
||||
MIN_UNDEFINED_RANGE_SIZE_OPTION_NAME + " must be positive!");
|
||||
}
|
||||
minUndefinedRangeSize = newMin;
|
||||
break;
|
||||
default:
|
||||
Msg.showError(this, null, "Unknown option", "Unknown option: " + optionName);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Record the existence of a {@link ProgramAssociatedComponentProviderAdapter} so that it can
|
||||
* be closed if its associated program is closed
|
||||
* @param provider provider
|
||||
*/
|
||||
void addProvider(ProgramAssociatedComponentProviderAdapter provider) {
|
||||
List<ProgramAssociatedComponentProviderAdapter> providers =
|
||||
programsToProviders.computeIfAbsent(provider.getProgram(), p -> new ArrayList<>());
|
||||
providers.add(provider);
|
||||
tool.addComponentProvider(provider, true);
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove the provider from the list of tracked providers
|
||||
* @param provider provider
|
||||
*/
|
||||
void removeProvider(ProgramAssociatedComponentProviderAdapter provider) {
|
||||
programsToProviders.get(provider.getProgram()).remove(provider);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the current selection. Cf. {@link FunctionStartRFParamsDialog#addGeneralActions}
|
||||
* @param selection new selection
|
||||
*/
|
||||
void setSelection(ProgramSelection selection) {
|
||||
currentSelection = selection;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the maximum size of a test set. Users can set this value via a plugin option.
|
||||
* @return max size
|
||||
*/
|
||||
Long getTestMaxSize() {
|
||||
return testSetMax;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the minimum size of an undefined range to search for function starts. Users
|
||||
* can set this value via a plugion option.
|
||||
* @return min undefined range size
|
||||
*/
|
||||
Long getMinUndefinedRangeSize() {
|
||||
return minUndefinedRangeSize;
|
||||
}
|
||||
|
||||
/**
|
||||
* Null out the dialog
|
||||
*/
|
||||
void resetDialog() {
|
||||
paramsDialog = null;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void programClosed(Program p) {
|
||||
//ProgramAssociatedComponentProviderAdapter.closeComponent modifies values of
|
||||
//programsToProviders, so make a copy to avoid a ConcurrentModificationException
|
||||
List<ProgramAssociatedComponentProviderAdapter> providersToClose =
|
||||
Lists.copyOf(programsToProviders.getOrDefault(p, Collections.emptyList()));
|
||||
for (ProgramAssociatedComponentProviderAdapter provider : providersToClose) {
|
||||
provider.closeComponent();
|
||||
}
|
||||
programsToProviders.remove(p);
|
||||
if (paramsDialog == null) {
|
||||
return;
|
||||
}
|
||||
if (!paramsDialog.getTrainingSource().equals(p)) {
|
||||
return;
|
||||
}
|
||||
paramsDialog.dismissCallback();
|
||||
}
|
||||
|
||||
private void createActions() {
|
||||
|
||||
new ActionBuilder(ACTION_NAME, getName())
|
||||
.menuPath(ToolConstants.MENU_SEARCH, MENU_PATH_ENTRY)
|
||||
.menuGroup("search for", null)
|
||||
.description("Train models to search for function starts")
|
||||
.helpLocation(new HelpLocation(getName(), getName()))
|
||||
.withContext(NavigatableActionContext.class)
|
||||
.validContextWhen(c -> !(c instanceof RestrictedAddressSetContext))
|
||||
.supportsDefaultToolContext(true)
|
||||
.onAction(c -> {
|
||||
displayDialog(c);
|
||||
})
|
||||
.buildAndInstall(tool);
|
||||
}
|
||||
|
||||
private void displayDialog(NavigatableActionContext c) {
|
||||
if (paramsDialog == null) {
|
||||
paramsDialog = new FunctionStartRFParamsDialog(this);
|
||||
}
|
||||
if (!paramsDialog.getTrainingSource().equals(this.getCurrentProgram())) {
|
||||
paramsDialog.dismissCallback();
|
||||
paramsDialog = new FunctionStartRFParamsDialog(this);
|
||||
}
|
||||
tool.showDialog(paramsDialog, c.getComponentProvider());
|
||||
}
|
||||
|
||||
private void initOptions(ToolOptions options) {
|
||||
options.registerOption(TEST_SET_MAX_SIZE_OPTION_NAME, TEST_SET_MAX_SIZE_DEFAULT,
|
||||
new HelpLocation(getName(), "MaxTestSetSize"),
|
||||
"Maximum sizes for test sets (must be positive).");
|
||||
testSetMax = options.getLong(TEST_SET_MAX_SIZE_OPTION_NAME, TEST_SET_MAX_SIZE_DEFAULT);
|
||||
options.registerOption(MIN_UNDEFINED_RANGE_SIZE_OPTION_NAME,
|
||||
MIN_UNDEFINED_RANGE_SIZE_DEFAULT,
|
||||
new HelpLocation(getName(), "MinLengthUndefinedRange"),
|
||||
"Minimum Size of an Undefined AddressRange to search (must be positive).");
|
||||
minUndefinedRangeSize =
|
||||
options.getLong(MIN_UNDEFINED_RANGE_SIZE_OPTION_NAME, MIN_UNDEFINED_RANGE_SIZE_DEFAULT);
|
||||
options.addOptionsChangeListener(this);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,210 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import java.math.*;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import org.tribuo.classification.Label;
|
||||
import org.tribuo.ensemble.EnsembleModel;
|
||||
|
||||
import ghidra.program.model.address.AddressSet;
|
||||
|
||||
/**
|
||||
* A class for row objects in a table whose rows are associated with models trained to
|
||||
* find function starts. Some of the fields of this class are used to populate these
|
||||
* rows with data about how accurate the model was on the test set. Other fields
|
||||
* (such as {@code numPreBytes}, {@code numInitialBytes}, and {@code includeBitLevelFeatures})
|
||||
* are needed by actions defined on this table apply the model (and thus must know how
|
||||
* to generate feature vectors the model consumes).
|
||||
*/
|
||||
public class RandomForestRowObject {
|
||||
|
||||
private BigDecimal precision;
|
||||
private BigDecimal recall;
|
||||
private int numPreBytes;
|
||||
private int numInitialBytes;
|
||||
private int samplingFactor;
|
||||
private boolean includeBitLevelFeatures;
|
||||
private List<String> contextRegisters;
|
||||
private List<BigInteger> contextRegisterValues;
|
||||
private EnsembleModel<Label> randomForest;
|
||||
private AddressSet testErrors;
|
||||
private AddressSet trainingPositive;
|
||||
private int[] confusionMatrix;
|
||||
|
||||
/**
|
||||
* Constructs a row
|
||||
* @param numPreBytes number of prebytes in vectors consumed by model
|
||||
* @param numInitialBytes number of initialBytes in vectors consumed by model
|
||||
* @param samplingFactor non-start to start sampling factor
|
||||
* @param confusionMatrix confusion matrix of model on test set
|
||||
* @param randomForest model
|
||||
* @param testErrors set of addresses in test set with errors
|
||||
* @param trainingPositive set of positive training examples (i.e. function starts)
|
||||
* @param includeBitLevelFeatures whether bit-level features were included in model
|
||||
*/
|
||||
public RandomForestRowObject(int numPreBytes, int numInitialBytes, int samplingFactor,
|
||||
int[] confusionMatrix, EnsembleModel<Label> randomForest, AddressSet testErrors,
|
||||
AddressSet trainingPositive, boolean includeBitLevelFeatures) {
|
||||
this.numPreBytes = numPreBytes;
|
||||
this.numInitialBytes = numInitialBytes;
|
||||
this.samplingFactor = samplingFactor;
|
||||
this.randomForest = randomForest;
|
||||
this.testErrors = testErrors;
|
||||
this.contextRegisters = Collections.emptyList();
|
||||
this.contextRegisterValues = Collections.emptyList();
|
||||
this.confusionMatrix = confusionMatrix;
|
||||
this.includeBitLevelFeatures = includeBitLevelFeatures;
|
||||
BigDecimal numerator = new BigDecimal(confusionMatrix[RandomForestTrainingTask.TP]);
|
||||
BigDecimal denominator = new BigDecimal(confusionMatrix[RandomForestTrainingTask.TP] +
|
||||
confusionMatrix[RandomForestTrainingTask.FP]);
|
||||
if (denominator.equals(BigDecimal.ZERO)) {
|
||||
precision = null;
|
||||
}
|
||||
else {
|
||||
precision = numerator.divide(denominator, 2, RoundingMode.HALF_EVEN);
|
||||
}
|
||||
denominator = new BigDecimal(confusionMatrix[RandomForestTrainingTask.TP] +
|
||||
confusionMatrix[RandomForestTrainingTask.FN]);
|
||||
if (denominator.equals(BigDecimal.ZERO)) {
|
||||
recall = null;
|
||||
}
|
||||
else {
|
||||
recall = numerator.divide(denominator, 2, RoundingMode.HALF_EVEN);
|
||||
}
|
||||
this.trainingPositive = trainingPositive;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the values for context register the model is aware of
|
||||
* @param regList register names
|
||||
* @param valueList register values
|
||||
*/
|
||||
public void setContextRegistersAndValues(List<String> regList, List<BigInteger> valueList) {
|
||||
if (regList.size() != valueList.size()) {
|
||||
throw new IllegalArgumentException(
|
||||
"Register list and value list must have the same size!");
|
||||
}
|
||||
contextRegisters = List.copyOf(regList);
|
||||
contextRegisterValues = List.copyOf(valueList);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a boolean indicating whether the model is aware of any context registers
|
||||
* @return aware of context
|
||||
*/
|
||||
public boolean isContextRestricted() {
|
||||
return !contextRegisters.isEmpty();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the names of the context registers the model is aware of
|
||||
* @return context reg names
|
||||
*/
|
||||
public List<String> getContextRegisterList() {
|
||||
return contextRegisters;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the list of values of context registers the model is aware of
|
||||
* @return context reg values
|
||||
*/
|
||||
public List<BigInteger> getContextRegisterValues() {
|
||||
return contextRegisterValues;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the precision of the model on the test set
|
||||
* @return precision
|
||||
*/
|
||||
public BigDecimal getPrecision() {
|
||||
return precision;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the recall of the model on the test set
|
||||
* @return recall
|
||||
*/
|
||||
public BigDecimal getRecall() {
|
||||
return recall;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a boolean indicating whether bit-level features were included when training the model
|
||||
* @return bit-level features used
|
||||
*/
|
||||
public boolean getIncludeBitLevelFeatures() {
|
||||
return includeBitLevelFeatures;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the number of pre-bytes used when training the model
|
||||
* @return pre-bytes
|
||||
*/
|
||||
public int getNumPreBytes() {
|
||||
return numPreBytes;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the sampling factor used when training the model
|
||||
* @return sampling factor
|
||||
*/
|
||||
public int getSamplingFactor() {
|
||||
return samplingFactor;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the number of initial bytes used when training the model
|
||||
* @return num initial bytes
|
||||
*/
|
||||
public int getNumInitialBytes() {
|
||||
return numInitialBytes;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the model
|
||||
* @return model
|
||||
*/
|
||||
public EnsembleModel<Label> getRandomForest() {
|
||||
return randomForest;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the addresses in the test set where the model made an error
|
||||
* @return error set
|
||||
*/
|
||||
public AddressSet getTestErrors() {
|
||||
return testErrors;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the number of false positives the model produces when classifying the test set.
|
||||
* @return num false positives
|
||||
*/
|
||||
public int getNumFalsePositives() {
|
||||
return confusionMatrix[RandomForestTrainingTask.FP];
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the set of function starts in the training set.
|
||||
* @return known starts
|
||||
*/
|
||||
public AddressSet getTrainingPositives() {
|
||||
return trainingPositive;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,162 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.util.List;
|
||||
|
||||
import docking.widgets.table.AbstractDynamicTableColumn;
|
||||
import docking.widgets.table.TableColumnDescriptor;
|
||||
import docking.widgets.table.threaded.ThreadedTableModelStub;
|
||||
import ghidra.docking.settings.Settings;
|
||||
import ghidra.framework.plugintool.ServiceProvider;
|
||||
import ghidra.util.datastruct.Accumulator;
|
||||
import ghidra.util.exception.CancelledException;
|
||||
import ghidra.util.task.TaskMonitor;
|
||||
|
||||
/**
|
||||
* A table model for tables that display information about random forests trained to find
|
||||
* function starts.
|
||||
*/
|
||||
public class RandomForestTableModel extends ThreadedTableModelStub<RandomForestRowObject> {
|
||||
|
||||
private static final String MODEL_NAME = "Random Forest Evaluations";
|
||||
private List<RandomForestRowObject> rowObjects;
|
||||
|
||||
/**
|
||||
* Creates a table model
|
||||
* @param serviceProvider service provider
|
||||
* @param rowObjects rows of table
|
||||
*/
|
||||
public RandomForestTableModel(ServiceProvider serviceProvider,
|
||||
List<RandomForestRowObject> rowObjects) {
|
||||
super(MODEL_NAME, serviceProvider);
|
||||
this.rowObjects = rowObjects;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doLoad(Accumulator<RandomForestRowObject> accumulator, TaskMonitor monitor)
|
||||
throws CancelledException {
|
||||
accumulator.addAll(rowObjects);
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TableColumnDescriptor<RandomForestRowObject> createTableColumnDescriptor() {
|
||||
TableColumnDescriptor<RandomForestRowObject> descriptor = new TableColumnDescriptor<>();
|
||||
descriptor.addVisibleColumn(new NumPreBytesColumn());
|
||||
descriptor.addVisibleColumn(new NumInitialBytesColumn());
|
||||
descriptor.addVisibleColumn(new SamplingFactorColumn());
|
||||
descriptor.addVisibleColumn(new FalsePositivesColumn(), 1, true);
|
||||
descriptor.addVisibleColumn(new PrecisionTableColumn());
|
||||
descriptor.addVisibleColumn(new RecallTableColumn(), 2, false);
|
||||
return descriptor;
|
||||
}
|
||||
|
||||
/**
|
||||
* Table column classes
|
||||
*/
|
||||
|
||||
class NumPreBytesColumn
|
||||
extends AbstractDynamicTableColumn<RandomForestRowObject, Integer, Object> {
|
||||
|
||||
@Override
|
||||
public String getColumnName() {
|
||||
return "Pre-Bytes";
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getValue(RandomForestRowObject rowObject, Settings settings, Object data,
|
||||
ServiceProvider sProvider) throws IllegalArgumentException {
|
||||
return rowObject.getNumPreBytes();
|
||||
}
|
||||
}
|
||||
|
||||
class NumInitialBytesColumn
|
||||
extends AbstractDynamicTableColumn<RandomForestRowObject, Integer, Object> {
|
||||
|
||||
@Override
|
||||
public String getColumnName() {
|
||||
return "Initial Bytes";
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getValue(RandomForestRowObject rowObject, Settings settings, Object data,
|
||||
ServiceProvider sProvider) throws IllegalArgumentException {
|
||||
return rowObject.getNumInitialBytes();
|
||||
}
|
||||
}
|
||||
|
||||
class SamplingFactorColumn
|
||||
extends AbstractDynamicTableColumn<RandomForestRowObject, Integer, Object> {
|
||||
|
||||
@Override
|
||||
public String getColumnName() {
|
||||
return "Factor";
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getValue(RandomForestRowObject rowObject, Settings settings, Object data,
|
||||
ServiceProvider sProvider) throws IllegalArgumentException {
|
||||
return rowObject.getSamplingFactor();
|
||||
}
|
||||
}
|
||||
|
||||
class PrecisionTableColumn
|
||||
extends AbstractDynamicTableColumn<RandomForestRowObject, BigDecimal, Object> {
|
||||
|
||||
@Override
|
||||
public String getColumnName() {
|
||||
return "Precision";
|
||||
}
|
||||
|
||||
@Override
|
||||
public BigDecimal getValue(RandomForestRowObject rowObject, Settings settings, Object data,
|
||||
ServiceProvider sProvider) throws IllegalArgumentException {
|
||||
return rowObject.getPrecision();
|
||||
}
|
||||
}
|
||||
|
||||
class RecallTableColumn
|
||||
extends AbstractDynamicTableColumn<RandomForestRowObject, BigDecimal, Object> {
|
||||
|
||||
@Override
|
||||
public String getColumnName() {
|
||||
return "Recall";
|
||||
}
|
||||
|
||||
@Override
|
||||
public BigDecimal getValue(RandomForestRowObject rowObject, Settings settings, Object data,
|
||||
ServiceProvider sProvider) throws IllegalArgumentException {
|
||||
return rowObject.getRecall();
|
||||
}
|
||||
}
|
||||
|
||||
class FalsePositivesColumn
|
||||
extends AbstractDynamicTableColumn<RandomForestRowObject, Integer, Object> {
|
||||
|
||||
@Override
|
||||
public String getColumnName() {
|
||||
return "False Positives";
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getValue(RandomForestRowObject rowObject, Settings settings, Object data,
|
||||
ServiceProvider sProvider) throws IllegalArgumentException {
|
||||
return rowObject.getNumFalsePositives();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,485 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import org.tribuo.*;
|
||||
import org.tribuo.classification.Label;
|
||||
import org.tribuo.classification.LabelFactory;
|
||||
import org.tribuo.classification.dtree.CARTClassificationTrainer;
|
||||
import org.tribuo.classification.ensemble.VotingCombiner;
|
||||
import org.tribuo.common.tree.TreeModel;
|
||||
import org.tribuo.dataset.DatasetView;
|
||||
import org.tribuo.datasource.ListDataSource;
|
||||
import org.tribuo.ensemble.EnsembleModel;
|
||||
import org.tribuo.ensemble.WeightedEnsembleModel;
|
||||
import org.tribuo.provenance.SimpleDataSourceProvenance;
|
||||
|
||||
import generic.concurrent.*;
|
||||
import ghidra.program.model.address.*;
|
||||
import ghidra.program.model.listing.Program;
|
||||
import ghidra.program.util.ProgramSelection;
|
||||
import ghidra.util.Msg;
|
||||
import ghidra.util.exception.CancelledException;
|
||||
import ghidra.util.task.Task;
|
||||
import ghidra.util.task.TaskMonitor;
|
||||
|
||||
/**
|
||||
* This {@link Task} is used to train and evaluate random forests.
|
||||
*/
|
||||
public class RandomForestTrainingTask extends Task {
|
||||
|
||||
//NUM_TREES should be odd to avoid nuances involving tiebreaking
|
||||
public static final int NUM_TREES = 99;
|
||||
public static final String TITLE = "Training Model for ";
|
||||
public static final int MAX_EVAL_SET_SIZE = 50000;
|
||||
public static final String RANDOM_FOREST_TRAINING_THREADPOOL = "RandomForestTrainer";
|
||||
public static final double NANOSECONDS_PER_SECOND = 1000000000.0d;
|
||||
|
||||
//for the confusion matrix
|
||||
public static final int TP = 0;
|
||||
public static final int FP = 1;
|
||||
public static final int TN = 2;
|
||||
public static final int FN = 3;
|
||||
public static final int CONFUSION_MATRIX_SIZE = 4;
|
||||
|
||||
private FunctionStartRFParams params;
|
||||
private Program program;
|
||||
//trainingSet is accessed from different threads during training
|
||||
private Dataset<Label> trainingSet;
|
||||
private Consumer<RandomForestRowObject> rowObjectConsumer;
|
||||
private AddressSet additionalStarts;
|
||||
private AddressSet additionalNonStarts;
|
||||
private Long testSetMax;
|
||||
|
||||
/**
|
||||
* Creates a {@link Task} for training (and evaluating) random forests for function
|
||||
* start identification.
|
||||
*
|
||||
* @param program source of training
|
||||
* @param params parameters controlling training
|
||||
* @param rowObjectConsumer consumes data about the trained models
|
||||
* @param testSetMax maximum size of test sets
|
||||
*/
|
||||
public RandomForestTrainingTask(Program program, FunctionStartRFParams params,
|
||||
Consumer<RandomForestRowObject> rowObjectConsumer, long testSetMax) {
|
||||
super(TITLE + program.getName(), true, true, false, false);
|
||||
this.program = program;
|
||||
this.params = params;
|
||||
this.rowObjectConsumer = rowObjectConsumer;
|
||||
this.testSetMax = testSetMax;
|
||||
additionalStarts = new AddressSet();
|
||||
additionalNonStarts = new AddressSet();
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a {@ProgramSelection} to the training set. Function starts within the selection are
|
||||
* added as positive examples and everything else is added as a negative example.
|
||||
*
|
||||
*<p>
|
||||
* Addresses which are not aligned or which do not agree with the context register values
|
||||
* in the {@code params} variable of the constructor are ignored.
|
||||
* @param selection selection to add
|
||||
* @return number of aligned addresses conflicting with the context register data specified in
|
||||
* {@code params}
|
||||
*/
|
||||
public int setAdditional(ProgramSelection selection) {
|
||||
if (selection == null) {
|
||||
return 0;
|
||||
}
|
||||
int numConflicts = 0;
|
||||
int instructionAlignment = program.getLanguage().getInstructionAlignment();
|
||||
for (Address addr : selection.getAddresses(true)) {
|
||||
if (addr.getOffset() % instructionAlignment != 0) {
|
||||
continue;
|
||||
}
|
||||
if (params.isRestrictedByContext()) {
|
||||
if (!params.isContextCompatible(addr)) {
|
||||
numConflicts += 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (program.getFunctionManager().getFunctionAt(addr) == null) {
|
||||
additionalNonStarts.add(addr);
|
||||
}
|
||||
else {
|
||||
additionalStarts.add(addr);
|
||||
}
|
||||
}
|
||||
return numConflicts;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run(TaskMonitor monitor) throws CancelledException {
|
||||
monitor.setIndeterminate(true);
|
||||
monitor.setMessage("Gathering function entries and interiors");
|
||||
|
||||
//get all the function entries and interiors that are consistent
|
||||
//with the requirements in params
|
||||
params.computeFuncEntriesAndInteriors(monitor);
|
||||
AddressSet allEntries = params.getFuncEntries();
|
||||
AddressSet allInteriors = params.getFuncInteriors();
|
||||
|
||||
//defined data in executable sections will be added to the test set
|
||||
AddressSet definedData = ModelTrainingUtils.getDefinedData(program, monitor);
|
||||
|
||||
for (Integer factor : params.getSamplingFactors()) {
|
||||
//get the addresses used for training and testing
|
||||
TrainingAndTestData data =
|
||||
getTrainingAndTestData(allEntries, allInteriors, definedData, factor, monitor);
|
||||
if (data == null) {
|
||||
continue;
|
||||
}
|
||||
data.reduceTestSetSize(testSetMax, monitor);
|
||||
monitor.setIndeterminate(false);
|
||||
for (Integer preBytes : params.getPreBytes()) {
|
||||
for (Integer initialBytes : params.getInitialBytes()) {
|
||||
Msg.info(this,
|
||||
String.format(
|
||||
"Data Gathering Parameters: factor: %d preBytes: %s initialBytes: %d",
|
||||
factor, preBytes, initialBytes));
|
||||
|
||||
//create the vectors for the training addresses
|
||||
List<Example<Label>> trainingData = new ArrayList<>();
|
||||
monitor.setMessage("Generating vectors for function entries");
|
||||
trainingData.addAll(ModelTrainingUtils.getVectorsFromAddresses(program,
|
||||
data.getTrainingPositive(), RandomForestFunctionFinderPlugin.FUNC_START,
|
||||
preBytes, initialBytes, params.getIncludeBitFeatures(), monitor));
|
||||
monitor.setMessage("Generating vectors for function interiors");
|
||||
trainingData.addAll(ModelTrainingUtils.getVectorsFromAddresses(program,
|
||||
data.getTrainingNegative(), RandomForestFunctionFinderPlugin.NON_START,
|
||||
preBytes, initialBytes, params.getIncludeBitFeatures(), monitor));
|
||||
|
||||
//should only happen with very small training sets where MemoryAccessExceptions
|
||||
//were thrown during vector generation
|
||||
if (trainingData.isEmpty()) {
|
||||
Msg.showWarn(this, null, "Empty Training Set", String.format(
|
||||
"No vectors were generated for supplied addresses. preBytes = %d, " +
|
||||
"initialBytes = %d",
|
||||
preBytes, initialBytes));
|
||||
continue;
|
||||
}
|
||||
|
||||
//train the model
|
||||
EnsembleModel<Label> randomForest = trainModel(trainingData, monitor);
|
||||
|
||||
//evaluate the model and create a RandomForestRowObject
|
||||
AddressSet errors = new AddressSet();
|
||||
int[] confusionMatrix = evaluateModel(randomForest, data.getTestPositive(),
|
||||
data.getTestNegative(), errors, preBytes, initialBytes, monitor);
|
||||
if (!monitor.isCancelled()) {
|
||||
RandomForestRowObject row = new RandomForestRowObject(preBytes,
|
||||
initialBytes, factor, confusionMatrix, randomForest, errors,
|
||||
data.getTrainingPositive(), params.getIncludeBitFeatures());
|
||||
row.setContextRegistersAndValues(params.getContextRegisterNames(),
|
||||
params.getContextRegisterVals());
|
||||
rowObjectConsumer.accept(row);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates the training and test sets
|
||||
* @param allEntries function entries
|
||||
* @param allInteriors function interiors
|
||||
* @param definedData defined data
|
||||
* @param factor sampling factor
|
||||
* @param monitor task monitor
|
||||
* @return training and test sets
|
||||
* @throws CancelledException if monitor is canceled
|
||||
*/
|
||||
TrainingAndTestData getTrainingAndTestData(AddressSet allEntries, AddressSet allInteriors,
|
||||
AddressSet definedData, int factor, TaskMonitor monitor) throws CancelledException {
|
||||
//if the user has specified addresses to use in training,
|
||||
//don't allow those addresses to also be selected at random
|
||||
AddressSet selectableEntries = allEntries.subtract(additionalStarts);
|
||||
AddressSet selectableInteriors = allInteriors.subtract(additionalNonStarts);
|
||||
AddressSet trainingPositive = new AddressSet(additionalStarts);
|
||||
AddressSet trainingNegative = new AddressSet(additionalNonStarts);
|
||||
|
||||
//select function entries at random and add them to the training set
|
||||
long numEntries =
|
||||
(int) Math.min(params.getMaxStarts(), selectableEntries.getNumAddresses());
|
||||
monitor.setIndeterminate(true);
|
||||
monitor.setMessage("Selecting " + numEntries + " random function entries");
|
||||
long start = System.nanoTime();
|
||||
AddressSetView randomFuncEntries =
|
||||
RandomSubsetUtils.randomSubset(selectableEntries, numEntries, monitor);
|
||||
long end = System.nanoTime();
|
||||
Msg.info(this, String.format("factor: %d elapsed selecting random entries: %g", factor,
|
||||
(end - start) / NANOSECONDS_PER_SECOND));
|
||||
trainingPositive = trainingPositive.union(randomFuncEntries);
|
||||
if (trainingPositive.isEmpty()) {
|
||||
Msg.showError(this, null, "Data Gathering Error", "No functions in training set");
|
||||
return null;
|
||||
}
|
||||
//function entries that weren't selected are used for testing
|
||||
AddressSet testPositive = selectableEntries.subtract(randomFuncEntries);
|
||||
if (testPositive.isEmpty()) {
|
||||
Msg.showWarn(this, null, "Test Set Warning",
|
||||
"No function entries in test set for models with sampling factor " + factor);
|
||||
}
|
||||
//for the randomly-selected function entries, optionally add the immediately
|
||||
//preceding and following code units to the training and test sets as negative examples
|
||||
AddressSet immediatelyPrecedingTraining = new AddressSet();
|
||||
AddressSet immediatelyFollowingTraining = new AddressSet();
|
||||
AddressSet immediatelyPrecedingTest = new AddressSet();
|
||||
AddressSet immediatelyFollowingTest = new AddressSet();
|
||||
if (params.getIncludePrecedingAndFollowing()) {
|
||||
immediatelyPrecedingTraining =
|
||||
ModelTrainingUtils.getPrecedingAddresses(program, randomFuncEntries, monitor);
|
||||
immediatelyFollowingTraining =
|
||||
ModelTrainingUtils.getFollowingAddresses(program, randomFuncEntries, monitor);
|
||||
immediatelyPrecedingTest =
|
||||
ModelTrainingUtils.getPrecedingAddresses(program, testPositive, monitor);
|
||||
immediatelyFollowingTest =
|
||||
ModelTrainingUtils.getFollowingAddresses(program, testPositive, monitor);
|
||||
//immediatelyPreceding and immediately following can intersect in
|
||||
//certain cases; subtract the overlap
|
||||
AddressSet before = immediatelyPrecedingTraining.union(immediatelyPrecedingTest);
|
||||
AddressSet after = immediatelyFollowingTraining.union(immediatelyFollowingTest);
|
||||
immediatelyPrecedingTraining = immediatelyPrecedingTraining.subtract(after);
|
||||
immediatelyPrecedingTest = immediatelyPrecedingTest.subtract(after);
|
||||
immediatelyFollowingTraining = immediatelyFollowingTraining.subtract(before);
|
||||
immediatelyFollowingTest = immediatelyFollowingTest.subtract(before);
|
||||
//Since immediatelyFollowingXand immediatelyPrecedingX are explicitly
|
||||
//added to the training/test sets, remove them from the set of interiors which could be
|
||||
//randomly selected for inclusion in the training set
|
||||
selectableInteriors = selectableInteriors
|
||||
.subtract(immediatelyFollowingTraining.union(immediatelyPrecedingTraining));
|
||||
selectableInteriors = selectableInteriors
|
||||
.subtract(immediatelyFollowingTest.union(immediatelyPrecedingTest));
|
||||
//remove before and after from definedData
|
||||
definedData = definedData.subtract(before.union(after));
|
||||
}
|
||||
trainingNegative = trainingNegative
|
||||
.union(immediatelyPrecedingTraining.union(immediatelyFollowingTraining));
|
||||
|
||||
//now select random function interiors and add them to training set
|
||||
monitor.setMessage(
|
||||
"Selecting " + numEntries * factor + " random addresses within function interiors");
|
||||
start = System.nanoTime();
|
||||
AddressSetView randomFuncInteriors =
|
||||
RandomSubsetUtils.randomSubset(selectableInteriors, numEntries * factor, monitor);
|
||||
end = System.nanoTime();
|
||||
Msg.info(this, String.format("factor: %d elapsed selecting random interiors: %g seconds",
|
||||
factor, (end - start) / NANOSECONDS_PER_SECOND));
|
||||
trainingNegative = trainingNegative.union(randomFuncInteriors);
|
||||
if (trainingNegative.isEmpty()) {
|
||||
Msg.showError(this, null, "Data Gathering Error",
|
||||
"No function interiors in training set");
|
||||
return null;
|
||||
}
|
||||
if (trainingPositive.intersects(trainingNegative)) {
|
||||
Address first = trainingPositive.findFirstAddressInCommon(trainingNegative);
|
||||
Msg.showWarn(this, null, "Overlap between Training Positive and Training Negative Sets",
|
||||
"Example: " + first.toString());
|
||||
}
|
||||
AddressSet unusedInteriors = selectableInteriors.subtract(randomFuncInteriors);
|
||||
AddressSet testNegative = unusedInteriors.union(definedData);
|
||||
testNegative = testNegative.union(immediatelyPrecedingTest).union(immediatelyFollowingTest);
|
||||
if (testNegative.isEmpty()) {
|
||||
Msg.showWarn(this, null, "Test Set Warning",
|
||||
"No function interiors in test set for models with sampling factor " + factor);
|
||||
}
|
||||
if (testPositive.intersects(testNegative)) {
|
||||
Address first = testPositive.findFirstAddressInCommon(testNegative);
|
||||
Msg.showWarn(this, null, "Overlapping Test Positive and Negative sets",
|
||||
"Example: " + first.toString());
|
||||
}
|
||||
if ((trainingPositive.union(trainingNegative)
|
||||
.intersects(testPositive.union(testNegative)))) {
|
||||
Address first = trainingPositive.union(trainingNegative)
|
||||
.findFirstAddressInCommon(testPositive.union(testNegative));
|
||||
Msg.showWarn(this, null, "Overlapping Training and Test Sets",
|
||||
"Example: " + first.toString());
|
||||
}
|
||||
return new TrainingAndTestData(trainingPositive, trainingNegative, testPositive,
|
||||
testNegative);
|
||||
}
|
||||
|
||||
/**
|
||||
* Trains a model to recognize function entries. Training is performed in parallel.
|
||||
* @param trainingData training vectors
|
||||
* @param monitor task monitor
|
||||
* @return model
|
||||
* @throws CancelledException if monitor is canceled
|
||||
*/
|
||||
EnsembleModel<Label> trainModel(List<Example<Label>> trainingData, TaskMonitor monitor)
|
||||
throws CancelledException {
|
||||
LabelFactory lf = new LabelFactory();
|
||||
ListDataSource<Label> trainingSource = new ListDataSource<>(trainingData, lf,
|
||||
new SimpleDataSourceProvenance(program.getDomainFile().getPathname(), lf));
|
||||
trainingSet = new MutableDataset<>(trainingSource);
|
||||
|
||||
//want to select from sqrt(num features) features at each split
|
||||
float featureFraction = (float) (1.0f / Math.sqrt(trainingSet.getFeatureMap().size()));
|
||||
|
||||
List<CARTClassificationTrainer> trainers = new ArrayList<>();
|
||||
for (int i = 0; i < NUM_TREES; ++i) {
|
||||
monitor.checkCanceled();
|
||||
//Integer.MAX_VALUE: unlimited depth
|
||||
trainers.add(new CARTClassificationTrainer(Integer.MAX_VALUE, featureFraction,
|
||||
ThreadLocalRandom.current().nextLong()));
|
||||
}
|
||||
|
||||
GThreadPool threadPool = GThreadPool.getSharedThreadPool(RANDOM_FOREST_TRAINING_THREADPOOL);
|
||||
monitor.initialize(NUM_TREES);
|
||||
monitor.setMessage("Training random forest");
|
||||
|
||||
ConcurrentQBuilder<CARTClassificationTrainer, TreeModel<Label>> builder =
|
||||
new ConcurrentQBuilder<>();
|
||||
ConcurrentQ<CARTClassificationTrainer, TreeModel<Label>> q =
|
||||
builder.setThreadPool(threadPool)
|
||||
.setCollectResults(true)
|
||||
.setMonitor(monitor)
|
||||
.build(new SingleTreeTrainer());
|
||||
q.add(trainers);
|
||||
EnsembleModel<Label> randomForest = null;
|
||||
try {
|
||||
long start = System.nanoTime();
|
||||
var results = q.waitForResults();
|
||||
long end = System.nanoTime();
|
||||
Msg.info(this,
|
||||
String.format("Training time: %g seconds", (end - start) / NANOSECONDS_PER_SECOND));
|
||||
List<Model<Label>> trees = new ArrayList<>();
|
||||
for (var r : results) {
|
||||
trees.add(r.getResult());
|
||||
}
|
||||
randomForest = WeightedEnsembleModel.createEnsembleFromExistingModels("rf", trees,
|
||||
new VotingCombiner());
|
||||
}
|
||||
catch (Exception e) {
|
||||
monitor.checkCanceled();
|
||||
Msg.error(this, "Exception while training model: " + e.getMessage());
|
||||
}
|
||||
return randomForest;
|
||||
}
|
||||
|
||||
/**
|
||||
* Evaluates a model
|
||||
*
|
||||
* @param randomForest model to evaluate
|
||||
* @param testPositive test set of function entries
|
||||
* @param testNegative test set of function interiors
|
||||
* @param errors set to place addresses with classifier errors
|
||||
* @param preBytes number of bytes before entries
|
||||
* @param initialBytes number of bytes before interiors
|
||||
* @param monitor task monitor
|
||||
* @return confusion matrix
|
||||
* @throws CancelledException if monitor is canceled
|
||||
*/
|
||||
int[] evaluateModel(EnsembleModel<Label> randomForest, AddressSet testPositive,
|
||||
AddressSet testNegative, AddressSet errors, int preBytes, int initialBytes,
|
||||
TaskMonitor monitor) throws CancelledException {
|
||||
GThreadPool threadPool = GThreadPool.getSharedThreadPool(RANDOM_FOREST_TRAINING_THREADPOOL);
|
||||
long start = System.nanoTime();
|
||||
monitor.setMessage(
|
||||
"Evaluating model (step 1 of 2; " + testPositive.getNumAddresses() + " addresses)");
|
||||
ConcurrentQBuilder<Address, Boolean> evalBuilder = new ConcurrentQBuilder<>();
|
||||
ConcurrentQ<Address, Boolean> evalQ = evalBuilder.setThreadPool(threadPool)
|
||||
.setCollectResults(true)
|
||||
.setMonitor(monitor)
|
||||
.build(new EnsembleEvaluatorCallback(randomForest, program, preBytes, initialBytes,
|
||||
params.getIncludeBitFeatures(), RandomForestFunctionFinderPlugin.FUNC_START));
|
||||
evalQ.add(testPositive.getAddresses(true));
|
||||
int[] confusionMatrix = new int[4];
|
||||
try {
|
||||
Collection<QResult<Address, Boolean>> results = evalQ.waitForResults();
|
||||
updateConfusionMatrix(results, confusionMatrix,
|
||||
RandomForestFunctionFinderPlugin.FUNC_START, errors);
|
||||
}
|
||||
catch (Exception e) {
|
||||
monitor.checkCanceled();
|
||||
Msg.error(this,
|
||||
"Exception while evaluating model on known function starts: " + e.getMessage());
|
||||
}
|
||||
monitor.setMessage(
|
||||
"Evaluating model (step 2 of 2; " + testNegative.getNumAddresses() + " addresses)");
|
||||
evalQ = evalBuilder.setThreadPool(threadPool)
|
||||
.setCollectResults(true)
|
||||
.setMonitor(monitor)
|
||||
.build(new EnsembleEvaluatorCallback(randomForest, program, preBytes, initialBytes,
|
||||
params.getIncludeBitFeatures(), RandomForestFunctionFinderPlugin.NON_START));
|
||||
|
||||
evalQ.add(testNegative.getAddresses(true));
|
||||
try {
|
||||
Collection<QResult<Address, Boolean>> results = evalQ.waitForResults();
|
||||
updateConfusionMatrix(results, confusionMatrix,
|
||||
RandomForestFunctionFinderPlugin.NON_START, errors);
|
||||
}
|
||||
catch (Exception e) {
|
||||
monitor.checkCanceled();
|
||||
Msg.error(this,
|
||||
"Exception while evaluating model on known function interiors: " + e.getMessage());
|
||||
}
|
||||
long end = System.nanoTime();
|
||||
Msg.info(this,
|
||||
String.format("Evaluation time: %g seconds", (end - start) / NANOSECONDS_PER_SECOND));
|
||||
return confusionMatrix;
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates the confusion matrix
|
||||
*
|
||||
* @param results results of classifier
|
||||
* @param confusion confusion matrix
|
||||
* @param target correct answer
|
||||
* @param errors set to place addresses with classifier errors
|
||||
* @throws Exception if exception encountered during processing
|
||||
*/
|
||||
void updateConfusionMatrix(Collection<QResult<Address, Boolean>> results, int[] confusion,
|
||||
Label target, AddressSet errors) throws Exception {
|
||||
int trueIndex = target.equals(RandomForestFunctionFinderPlugin.FUNC_START) ? TP : TN;
|
||||
int falseIndex = target.equals(RandomForestFunctionFinderPlugin.FUNC_START) ? FN : FP;
|
||||
for (QResult<Address, Boolean> result : results) {
|
||||
Boolean ans = result.getResult();
|
||||
if (ans == null) {
|
||||
continue;
|
||||
}
|
||||
if (ans) {
|
||||
confusion[trueIndex] += 1;
|
||||
}
|
||||
else {
|
||||
confusion[falseIndex] += 1;
|
||||
errors.add(result.getItem());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private synchronized DatasetView<Label> getBag() {
|
||||
return DatasetView.createBootstrapView(trainingSet, trainingSet.size(),
|
||||
ThreadLocalRandom.current().nextLong());
|
||||
}
|
||||
|
||||
private class SingleTreeTrainer
|
||||
implements QCallback<CARTClassificationTrainer, TreeModel<Label>> {
|
||||
|
||||
@Override
|
||||
public TreeModel<Label> process(CARTClassificationTrainer trainer, TaskMonitor monitor)
|
||||
throws Exception {
|
||||
DatasetView<Label> bag = getBag();
|
||||
TreeModel<Label> tree = trainer.train(bag);
|
||||
monitor.incrementProgress(1);
|
||||
return tree;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,115 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
|
||||
import ghidra.program.model.address.*;
|
||||
import ghidra.util.exception.CancelledException;
|
||||
import ghidra.util.task.TaskMonitor;
|
||||
|
||||
/**
|
||||
* This is a utility class for generating random subsets of an {@link AddressSetView}.
|
||||
*/
|
||||
public class RandomSubsetUtils {
|
||||
|
||||
private RandomSubsetUtils() {
|
||||
//utility class
|
||||
}
|
||||
|
||||
/**
|
||||
* This method generates a random subset of size {@code k} of {@code addresses}.
|
||||
* <p>
|
||||
* The parameter {@code k} can be of type {@code long}, but you will probably run out of heap
|
||||
* space for large values.
|
||||
*
|
||||
* @param addresses addresses
|
||||
* @param k size of random subset to generate
|
||||
* @param monitor monitor
|
||||
* @return random subset of size k
|
||||
* @throws CancelledException if monitor is canceled
|
||||
*/
|
||||
public static AddressSet randomSubset(AddressSetView addresses, long k, TaskMonitor monitor)
|
||||
throws CancelledException {
|
||||
List<Long> sortedRandom = generateRandomIntegerSubset(addresses.getNumAddresses(), k);
|
||||
Collections.sort(sortedRandom);
|
||||
AddressSet randomAddresses = new AddressSet();
|
||||
AddressIterator iter = addresses.getAddresses(true);
|
||||
int addressesAdded = 0;
|
||||
int addressesVisited = 0;
|
||||
int listIndex = 0;
|
||||
while (iter.hasNext() && addressesAdded < k) {
|
||||
monitor.checkCanceled();
|
||||
Address addr = iter.next();
|
||||
if (sortedRandom.get(listIndex) == addressesVisited) {
|
||||
randomAddresses.add(addr);
|
||||
addressesAdded += 1;
|
||||
listIndex += 1;
|
||||
}
|
||||
addressesVisited += 1;
|
||||
}
|
||||
return randomAddresses;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates of random subset of size {@code k} of the set [0,1,...n-1] by generating
|
||||
* a random permutation
|
||||
* @param n size of set (must be >= 0)
|
||||
* @param k size of random subset (must be >= 0)
|
||||
* @return list of indices of elements in random subset
|
||||
*/
|
||||
public static List<Long> generateRandomIntegerSubset(long n, long k) {
|
||||
if (n < 0) {
|
||||
throw new IllegalArgumentException("n cannot be negative");
|
||||
}
|
||||
if (k < 0) {
|
||||
throw new IllegalArgumentException("k cannot be negative");
|
||||
}
|
||||
if (n < k) {
|
||||
throw new IllegalArgumentException(
|
||||
"size of subset (" + k + ") cannot be larger than size of set (" + n + ")");
|
||||
}
|
||||
Map<Long, Long> permutation = new HashMap<>();
|
||||
for (long i = 0; i < k; ++i) {
|
||||
swap(permutation, i, ThreadLocalRandom.current().nextLong(i, n));
|
||||
}
|
||||
List<Long> random = new ArrayList<>();
|
||||
for (long i = 0; i < k; i++) {
|
||||
random.add(permutation.getOrDefault(i, i));
|
||||
}
|
||||
return random;
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates a Map<Long,Long> treated as a permutation p to produce a new permutation p'
|
||||
* such that p'(i) = p(j) and p'(j) = p(i). For i not in the keySet of the map, it is
|
||||
* assumed that p(i) = i.
|
||||
* @param permutation permutation map
|
||||
* @param i index
|
||||
* @param j index
|
||||
*/
|
||||
public static void swap(Map<Long, Long> permutation, long i, long j) {
|
||||
if (i == j) {
|
||||
return;
|
||||
}
|
||||
long ith = permutation.getOrDefault(i, i);
|
||||
long jth = permutation.getOrDefault(j, j);
|
||||
permutation.put(i, jth);
|
||||
permutation.put(j, ith);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,91 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import docking.ActionContext;
|
||||
import docking.action.DockingAction;
|
||||
import docking.action.MenuData;
|
||||
import ghidra.program.model.address.Address;
|
||||
import ghidra.program.model.listing.Program;
|
||||
import ghidra.util.HelpLocation;
|
||||
import ghidra.util.table.GhidraTable;
|
||||
|
||||
/**
|
||||
* A {@link DockingAction} for showing the most similar function starts in the training
|
||||
* set to a possible function start
|
||||
*/
|
||||
public class ShowSimilarStartsAction extends DockingAction {
|
||||
private static final String MENU_TEXT = "Show Similar Function Starts";
|
||||
private static final String ACTION_NAME = "ShowSimilarStartsAction";
|
||||
private static final int NUM_NEIGHBORS = 10;
|
||||
private Program program;
|
||||
private FunctionStartTableModel model;
|
||||
private GhidraTable table;
|
||||
private RandomForestRowObject modelAndParams;
|
||||
private RandomForestFunctionFinderPlugin plugin;
|
||||
private SimilarStartsFinder finder;
|
||||
|
||||
/**
|
||||
* Constructs an action display similar function starts
|
||||
* @param plugin plugin
|
||||
* @param program source program
|
||||
* @param table table
|
||||
* @param model table with action
|
||||
*/
|
||||
public ShowSimilarStartsAction(RandomForestFunctionFinderPlugin plugin, Program program,
|
||||
GhidraTable table, FunctionStartTableModel model) {
|
||||
super(ACTION_NAME, plugin.getName());
|
||||
this.program = program;
|
||||
this.model = model;
|
||||
this.table = table;
|
||||
this.plugin = plugin;
|
||||
this.modelAndParams = model.getRandomForestRowObject();
|
||||
init();
|
||||
finder = new SimilarStartsFinder(program, modelAndParams);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isAddToPopup(ActionContext context) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isEnabledForContext(ActionContext context) {
|
||||
return table.getSelectedRowCount() == 1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void actionPerformed(ActionContext context) {
|
||||
Address potential = model.getAddress(table.getSelectedRow());
|
||||
List<SimilarStartRowObject> closeNeighbors =
|
||||
finder.getSimilarFunctionStarts(potential, NUM_NEIGHBORS);
|
||||
SimilarStartsTableProvider provider = new SimilarStartsTableProvider(plugin, program,
|
||||
potential, closeNeighbors, modelAndParams);
|
||||
plugin.addProvider(provider);
|
||||
|
||||
}
|
||||
|
||||
private void init() {
|
||||
setPopupMenuData(new MenuData(new String[] { MENU_TEXT }));
|
||||
setDescription(
|
||||
"Displays the most similar function starts in the training set to the given " +
|
||||
"potential start.");
|
||||
setHelpLocation(new HelpLocation(plugin.getName(), ACTION_NAME));
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import ghidra.program.model.address.Address;
|
||||
|
||||
/**
|
||||
* Record for storing random forest proximity information for a potential
|
||||
* function start
|
||||
* @param funcStart address
|
||||
* @param numAgreements number of agreeing trees
|
||||
*/
|
||||
public record SimilarStartRowObject(Address funcStart, int numAgreements) {
|
||||
};
|
|
@ -0,0 +1,132 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.Map.Entry;
|
||||
|
||||
import org.tribuo.Feature;
|
||||
import org.tribuo.Model;
|
||||
import org.tribuo.classification.Label;
|
||||
import org.tribuo.common.tree.Node;
|
||||
import org.tribuo.common.tree.TreeModel;
|
||||
import org.tribuo.ensemble.EnsembleModel;
|
||||
import org.tribuo.impl.ArrayExample;
|
||||
import org.tribuo.math.la.SparseVector;
|
||||
|
||||
import ghidra.program.model.address.*;
|
||||
import ghidra.program.model.listing.Program;
|
||||
|
||||
/**
|
||||
* Given a potential function start {@code S} and a random forest trained to recognize
|
||||
* function starts, this class is used to find the function starts in the training set
|
||||
* most similar to {@code S}. Here "similar" is defined in terms of proximity in a
|
||||
* random forest (i.e., proportion of trees which agree on two feature vectors).
|
||||
*/
|
||||
public class SimilarStartsFinder {
|
||||
|
||||
private RandomForestRowObject modelAndParams;
|
||||
private Program program;
|
||||
private int preBytes;
|
||||
private int initialBytes;
|
||||
private boolean includeBitFeatures;
|
||||
private Map<Address, List<Node<Label>>> startsToLeafList = new HashMap<>();
|
||||
private EnsembleModel<Label> randomForest;
|
||||
|
||||
/**
|
||||
* Creates a {@link SimilarStartsFinder} for the given program and model
|
||||
* @param program program
|
||||
* @param modelAndParams model and params
|
||||
*/
|
||||
public SimilarStartsFinder(Program program, RandomForestRowObject modelAndParams) {
|
||||
this.program = program;
|
||||
this.modelAndParams = modelAndParams;
|
||||
preBytes = modelAndParams.getNumPreBytes();
|
||||
initialBytes = modelAndParams.getNumInitialBytes();
|
||||
includeBitFeatures = modelAndParams.getIncludeBitLevelFeatures();
|
||||
randomForest = modelAndParams.getRandomForest();
|
||||
computeLeafNodeLists();
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds the functions starts in the training set that are most similar to {@code potential}
|
||||
* according to the model
|
||||
* @param potential address of potential start
|
||||
* @param numStarts max size of returned list
|
||||
* @return similar starts (in descending order)
|
||||
*/
|
||||
public List<SimilarStartRowObject> getSimilarFunctionStarts(Address potential, int numStarts) {
|
||||
List<Node<Label>> leafNodes = getLeafNodes(potential);
|
||||
List<SimilarStartRowObject> neighbors = new ArrayList<>(startsToLeafList.size());
|
||||
for (Entry<Address, List<Node<Label>>> entry : startsToLeafList.entrySet()) {
|
||||
Address start = entry.getKey();
|
||||
List<Node<Label>> leafList = entry.getValue();
|
||||
int matches = 0;
|
||||
for (int i = 0; i < randomForest.getNumModels(); ++i) {
|
||||
if (leafNodes.get(i).equals(leafList.get(i))) {
|
||||
matches++;
|
||||
}
|
||||
}
|
||||
neighbors.add(new SimilarStartRowObject(start, matches));
|
||||
}
|
||||
Collections.sort(neighbors,
|
||||
(x, y) -> Integer.compare(y.numAgreements(), x.numAgreements()));
|
||||
List<SimilarStartRowObject> closeNeighbors =
|
||||
neighbors.subList(0, Math.min(numStarts, neighbors.size()));
|
||||
return closeNeighbors;
|
||||
}
|
||||
|
||||
/**
|
||||
* For each function start in the training set and tree in the random forest,
|
||||
* run the corresponding feature vector down the tree and record the leaf node
|
||||
* reached.
|
||||
*/
|
||||
private void computeLeafNodeLists() {
|
||||
AddressSet knownStarts = modelAndParams.getTrainingPositives();
|
||||
AddressIterator addrIter = knownStarts.getAddresses(true);
|
||||
while (addrIter.hasNext()) {
|
||||
Address start = addrIter.next();
|
||||
List<Node<Label>> nodeList = getLeafNodes(start);
|
||||
startsToLeafList.put(start, nodeList);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a feature vector for {@code addr}, runs it down each tree in the forest,
|
||||
* and records the leaf node reached.
|
||||
* @param addr potential function start
|
||||
* @return list of leaf nodes
|
||||
*/
|
||||
List<Node<Label>> getLeafNodes(Address addr) {
|
||||
List<Node<Label>> leafNodes = new ArrayList<>(randomForest.getNumModels());
|
||||
List<Feature> potentialFeatureVector = ModelTrainingUtils.getFeatureVector(program, addr,
|
||||
preBytes, initialBytes, includeBitFeatures);
|
||||
ArrayExample<Label> example =
|
||||
new ArrayExample<>(RandomForestFunctionFinderPlugin.FUNC_START, potentialFeatureVector);
|
||||
SparseVector vec =
|
||||
SparseVector.createSparseVector(example, randomForest.getFeatureIDMap(), false);
|
||||
for (Model<Label> member : randomForest.getModels()) {
|
||||
TreeModel<Label> tree = (TreeModel<Label>) member;
|
||||
Node<Label> node = tree.getRoot();
|
||||
while (!node.isLeaf()) {
|
||||
node = node.getNextNode(vec);
|
||||
}
|
||||
leafNodes.add(node);
|
||||
}
|
||||
return leafNodes;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,177 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import javax.swing.JTable;
|
||||
import javax.swing.table.TableModel;
|
||||
|
||||
import docking.widgets.table.AbstractDynamicTableColumn;
|
||||
import docking.widgets.table.TableColumnDescriptor;
|
||||
import ghidra.docking.settings.Settings;
|
||||
import ghidra.framework.plugintool.PluginTool;
|
||||
import ghidra.framework.plugintool.ServiceProvider;
|
||||
import ghidra.program.model.address.Address;
|
||||
import ghidra.program.model.listing.Program;
|
||||
import ghidra.program.model.mem.MemoryAccessException;
|
||||
import ghidra.util.datastruct.Accumulator;
|
||||
import ghidra.util.exception.CancelledException;
|
||||
import ghidra.util.table.AddressBasedTableModel;
|
||||
import ghidra.util.table.column.AbstractGColumnRenderer;
|
||||
import ghidra.util.table.column.GColumnRenderer;
|
||||
import ghidra.util.task.TaskMonitor;
|
||||
|
||||
/**
|
||||
* Table model for a table displaying the closest function starts from the training set
|
||||
* to a given potential start.
|
||||
*/
|
||||
public class SimilarStartsTableModel extends AddressBasedTableModel<SimilarStartRowObject> {
|
||||
|
||||
private Address potentialStart;
|
||||
private List<SimilarStartRowObject> rows;
|
||||
private RandomForestRowObject randomForestRow;
|
||||
|
||||
/**
|
||||
* Construct a table model for a table to display the closest function starts to
|
||||
* a potential function start
|
||||
* @param plugin owning program
|
||||
* @param program program
|
||||
* @param potentialStart address of potential start
|
||||
* @param rows similar function starts
|
||||
* @param randomForestRow model and params
|
||||
*/
|
||||
public SimilarStartsTableModel(PluginTool plugin, Program program, Address potentialStart,
|
||||
List<SimilarStartRowObject> rows, RandomForestRowObject randomForestRow) {
|
||||
super("test", plugin, program, null, false);
|
||||
this.potentialStart = potentialStart;
|
||||
this.rows = rows;
|
||||
this.randomForestRow = randomForestRow;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Address getAddress(int row) {
|
||||
return getRowObject(row).funcStart();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doLoad(Accumulator<SimilarStartRowObject> accumulator, TaskMonitor monitor)
|
||||
throws CancelledException {
|
||||
//add a special row corresponding to the potential function start
|
||||
//want it in the table to facilitate (visual) byte string comparisons
|
||||
accumulator.add(new SimilarStartRowObject(potentialStart,
|
||||
randomForestRow.getRandomForest().getNumModels()));
|
||||
accumulator.addAll(rows);
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TableColumnDescriptor<SimilarStartRowObject> createTableColumnDescriptor() {
|
||||
TableColumnDescriptor<SimilarStartRowObject> descriptor = new TableColumnDescriptor<>();
|
||||
descriptor.addVisibleColumn(new AddressTableColumn());
|
||||
descriptor.addVisibleColumn(new SimilarityTableColumn(), 1, false);
|
||||
descriptor.addVisibleColumn(new ByteStringTableColumn());
|
||||
return descriptor;
|
||||
}
|
||||
|
||||
private class AddressTableColumn
|
||||
extends AbstractDynamicTableColumn<SimilarStartRowObject, String, Object> {
|
||||
|
||||
@Override
|
||||
public String getColumnName() {
|
||||
return "Address";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getValue(SimilarStartRowObject rowObject, Settings settings, Object data,
|
||||
ServiceProvider services) throws IllegalArgumentException {
|
||||
String addrString = rowObject.funcStart().toString();
|
||||
//address corresponding to the potential start should stand out in the table
|
||||
//so surround it with asterisks
|
||||
if (rowObject.funcStart().equals(potentialStart)) {
|
||||
addrString = "*" + addrString + "*";
|
||||
}
|
||||
return addrString;
|
||||
}
|
||||
}
|
||||
|
||||
private class SimilarityTableColumn
|
||||
extends AbstractDynamicTableColumn<SimilarStartRowObject, Double, Object> {
|
||||
|
||||
@Override
|
||||
public String getColumnName() {
|
||||
return "Similarity";
|
||||
}
|
||||
|
||||
@Override
|
||||
public Double getValue(SimilarStartRowObject rowObject, Settings settings, Object data,
|
||||
ServiceProvider services) throws IllegalArgumentException {
|
||||
return rowObject.numAgreements() * 1.0 /
|
||||
randomForestRow.getRandomForest().getNumModels();
|
||||
}
|
||||
}
|
||||
|
||||
private class ByteStringTableColumn
|
||||
extends AbstractDynamicTableColumn<SimilarStartRowObject, String, Object> {
|
||||
|
||||
@Override
|
||||
public String getColumnName() {
|
||||
return "Byte String";
|
||||
}
|
||||
|
||||
@Override
|
||||
public GColumnRenderer<String> getColumnRenderer() {
|
||||
final GColumnRenderer<String> monospacedRenderer = new AbstractGColumnRenderer<>() {
|
||||
@Override
|
||||
protected void configureFont(JTable table, TableModel model, int column) {
|
||||
setFont(getFixedWidthFont());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getFilterString(String t, Settings settings) {
|
||||
return t;
|
||||
}
|
||||
};
|
||||
return monospacedRenderer;
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getValue(SimilarStartRowObject rowObject, Settings settings, Object data,
|
||||
ServiceProvider services) throws IllegalArgumentException {
|
||||
Address funcStart = rowObject.funcStart();
|
||||
byte[] bytes =
|
||||
new byte[randomForestRow.getNumPreBytes() + randomForestRow.getNumInitialBytes()];
|
||||
try {
|
||||
program.getMemory()
|
||||
.getBytes(funcStart.subtract(randomForestRow.getNumPreBytes()), bytes);
|
||||
}
|
||||
catch (MemoryAccessException e) {
|
||||
return "??";
|
||||
}
|
||||
StringBuilder sb = new StringBuilder();
|
||||
for (int i = 0; i < randomForestRow.getNumPreBytes(); ++i) {
|
||||
sb.append(String.format("%02x ", bytes[i] & 0xff));
|
||||
}
|
||||
sb.append("* ");
|
||||
for (int i = randomForestRow.getNumPreBytes(); i < randomForestRow.getNumPreBytes() +
|
||||
randomForestRow.getNumInitialBytes(); ++i) {
|
||||
sb.append(String.format("%02x ", bytes[i] & 0xff));
|
||||
}
|
||||
return sb.toString();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,87 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import java.awt.BorderLayout;
|
||||
import java.awt.Dimension;
|
||||
import java.util.List;
|
||||
|
||||
import javax.swing.*;
|
||||
|
||||
import ghidra.app.services.GoToService;
|
||||
import ghidra.program.model.address.Address;
|
||||
import ghidra.program.model.listing.Program;
|
||||
import ghidra.util.HelpLocation;
|
||||
import ghidra.util.table.GhidraTable;
|
||||
import ghidra.util.table.GhidraThreadedTablePanel;
|
||||
|
||||
/**
|
||||
* Table provider for a table to display the closest function starts in the training
|
||||
* set to a potential function start
|
||||
*/
|
||||
public class SimilarStartsTableProvider extends ProgramAssociatedComponentProviderAdapter {
|
||||
private Program program;
|
||||
private Address potentialStart;
|
||||
private List<SimilarStartRowObject> rows;
|
||||
private JComponent component;
|
||||
private RandomForestRowObject randomForestRow;
|
||||
|
||||
/**
|
||||
* Create a table provider
|
||||
* @param plugin owning plugin
|
||||
* @param program program being search
|
||||
* @param potentialStart address of potential start
|
||||
* @param rows closest potential starts
|
||||
* @param randomForestRow model and params
|
||||
*/
|
||||
public SimilarStartsTableProvider(RandomForestFunctionFinderPlugin plugin, Program program,
|
||||
Address potentialStart, List<SimilarStartRowObject> rows,
|
||||
RandomForestRowObject randomForestRow) {
|
||||
super(program.getName() + ": Similar Function Starts", plugin.getName(), program, plugin);
|
||||
this.program = program;
|
||||
this.potentialStart = potentialStart;
|
||||
this.rows = rows;
|
||||
this.randomForestRow = randomForestRow;
|
||||
this.setSubTitle("Function Starts Similar to " + potentialStart.toString());
|
||||
build();
|
||||
setHelpLocation(new HelpLocation(plugin.getName(), "SimilarStartsTable"));
|
||||
}
|
||||
|
||||
@Override
|
||||
public JComponent getComponent() {
|
||||
return component;
|
||||
}
|
||||
|
||||
private void build() {
|
||||
component = new JPanel(new BorderLayout());
|
||||
SimilarStartsTableModel model =
|
||||
new SimilarStartsTableModel(tool, program, potentialStart, rows, randomForestRow);
|
||||
GhidraThreadedTablePanel<SimilarStartRowObject> similarStartsPanel =
|
||||
new GhidraThreadedTablePanel<>(model, 1000);
|
||||
GhidraTable similarStartsTable = similarStartsPanel.getTable();
|
||||
similarStartsTable.setName(
|
||||
program.getName() + ": Known Starts Similar to " + potentialStart.toString());
|
||||
GoToService goToService = tool.getService(GoToService.class);
|
||||
if (goToService != null) {
|
||||
similarStartsTable.installNavigation(goToService, goToService.getDefaultNavigatable());
|
||||
}
|
||||
similarStartsTable.setNavigateOnSelectionEnabled(true);
|
||||
similarStartsTable.setAutoResizeMode(JTable.AUTO_RESIZE_SUBSEQUENT_COLUMNS);
|
||||
similarStartsTable.setPreferredScrollableViewportSize(new Dimension(900, 300));
|
||||
component.add(similarStartsPanel, BorderLayout.CENTER);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,88 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import ghidra.program.model.address.AddressSet;
|
||||
import ghidra.util.exception.CancelledException;
|
||||
import ghidra.util.task.TaskMonitor;
|
||||
|
||||
/**
|
||||
* Container class for {@link AddressSet}s used during model training and testing
|
||||
*/
|
||||
public class TrainingAndTestData {
|
||||
private AddressSet trainingPositive;
|
||||
private AddressSet trainingNegative;
|
||||
private AddressSet testPositive;
|
||||
private AddressSet testNegative;
|
||||
|
||||
public TrainingAndTestData(AddressSet trainingPositive, AddressSet trainingNegative,
|
||||
AddressSet testPositive, AddressSet testNegative) {
|
||||
this.trainingPositive = trainingPositive;
|
||||
this.trainingNegative = trainingNegative;
|
||||
this.testPositive = testPositive;
|
||||
this.testNegative = testNegative;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the {@link AddressSet} of positive examples for training
|
||||
* @return training positive
|
||||
*/
|
||||
public AddressSet getTrainingPositive() {
|
||||
return trainingPositive;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the {@link AddressSet} of negative examples for training
|
||||
* @return training negative
|
||||
*/
|
||||
public AddressSet getTrainingNegative() {
|
||||
return trainingNegative;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the {@link AddressSet} of positive examples for testing
|
||||
* @return test positive
|
||||
*/
|
||||
public AddressSet getTestPositive() {
|
||||
return testPositive;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the {@link AddressSet} of negative examples for testing
|
||||
* @return test negative
|
||||
*/
|
||||
public AddressSet getTestNegative() {
|
||||
return testNegative;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks the sizes of the sets {@code testPositive} and {@code testNegative}. Any set
|
||||
* that is larger than {@code max} is replaced with a random subset of size {@code max}.
|
||||
*
|
||||
* @param max max size of each set
|
||||
* @param monitor task monitor
|
||||
* @throws CancelledException if the monitor is canceled
|
||||
*/
|
||||
public void reduceTestSetSize(long max, TaskMonitor monitor) throws CancelledException {
|
||||
if (testPositive.getNumAddresses() > max) {
|
||||
testPositive = RandomSubsetUtils.randomSubset(testPositive, max, monitor);
|
||||
}
|
||||
if (testNegative.getNumAddresses() > max) {
|
||||
testNegative = RandomSubsetUtils.randomSubset(testNegative, max, monitor);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,2 @@
|
|||
The "src/resources/images" directory is intended to hold all image/icon files used by
|
||||
this contrib.
|
|
@ -0,0 +1,242 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.tribuo.Example;
|
||||
import org.tribuo.Feature;
|
||||
import org.tribuo.classification.Label;
|
||||
|
||||
import ghidra.program.database.ProgramBuilder;
|
||||
import ghidra.program.model.address.Address;
|
||||
import ghidra.program.model.address.AddressSet;
|
||||
import ghidra.program.model.data.DataType;
|
||||
import ghidra.program.model.data.IntegerDataType;
|
||||
import ghidra.program.model.listing.Program;
|
||||
import ghidra.program.model.mem.MemoryBlock;
|
||||
import ghidra.test.AbstractProgramBasedTest;
|
||||
import ghidra.util.exception.CancelledException;
|
||||
import ghidra.util.task.TaskMonitor;
|
||||
|
||||
public class DataGatheringUtilsTest extends AbstractProgramBasedTest {
|
||||
|
||||
private final static String BASE_ADDRESS = "0x10000";
|
||||
private final static String ADD_R0_R1_R2_ARM = "02 00 81 e0";
|
||||
private final static String SUB_R4_R5_R6_ARM = "06 40 45 e0";
|
||||
|
||||
private final static String ADD_R0_R1_THUMB = "08 44";
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
initialize();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Program getProgram() throws Exception {
|
||||
return buildProgram();
|
||||
}
|
||||
|
||||
private Program buildProgram() throws Exception {
|
||||
ProgramBuilder builder = new ProgramBuilder("DataGatheringUtilsTest", ProgramBuilder._ARM);
|
||||
MemoryBlock block = builder.createMemory(".text", BASE_ADDRESS, 0x100);
|
||||
builder.setExecute(block, true);
|
||||
//undefined
|
||||
builder.setBytes(BASE_ADDRESS, "00 01 02 03", false);
|
||||
|
||||
builder.setBytes("0x10004", ADD_R0_R1_R2_ARM, true);
|
||||
builder.setBytes("0x10008", SUB_R4_R5_R6_ARM, true);
|
||||
|
||||
builder.setRegisterValue("TMode", "0x1000c", "0x1000f", 1);
|
||||
builder.setBytes("0x1000c", ADD_R0_R1_THUMB, true);
|
||||
builder.setBytes("0x1000e", ADD_R0_R1_THUMB, true);
|
||||
DataType intType = new IntegerDataType();
|
||||
builder.setBytes("0x10020", "00 01 02 03", false);
|
||||
builder.applyDataType("0x10020", intType);
|
||||
builder.setBytes("0x10024", "04 05 06 07", false);
|
||||
builder.applyDataType("0x10024", intType);
|
||||
builder.setBytes("0x10030", "08 09 0a 0b", false);
|
||||
builder.applyDataType("0x10030", intType);
|
||||
return builder.getProgram();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetByteValues() {
|
||||
Address base = program.getAddressFactory().getDefaultAddressSpace().getAddress(0x10008);
|
||||
List<Feature> test = ModelTrainingUtils.getFeatureVector(program, base, 1, 1, false);
|
||||
assertEquals(2, test.size());
|
||||
assertEquals("pbyte_0", test.get(0).getName());
|
||||
assertEquals(224.0d, test.get(0).getValue(), 0.0);
|
||||
assertEquals("ibyte_0", test.get(1).getName());
|
||||
assertEquals(6d, test.get(1).getValue(), 0.0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetBitValues() {
|
||||
Address base = program.getAddressFactory().getDefaultAddressSpace().getAddress(0x10008);
|
||||
List<Feature> test = ModelTrainingUtils.getFeatureVector(program, base, 1, 1, true);
|
||||
assertEquals(18, test.size());
|
||||
assertEquals("pbyte_0", test.get(0).getName());
|
||||
assertEquals("pbit_0_7", test.get(1).getName());
|
||||
assertEquals("pbit_0_6", test.get(2).getName());
|
||||
assertEquals("pbit_0_5", test.get(3).getName());
|
||||
assertEquals("pbit_0_4", test.get(4).getName());
|
||||
assertEquals("pbit_0_3", test.get(5).getName());
|
||||
assertEquals("pbit_0_2", test.get(6).getName());
|
||||
assertEquals("pbit_0_1", test.get(7).getName());
|
||||
assertEquals("pbit_0_0", test.get(8).getName());
|
||||
assertEquals("ibyte_0", test.get(9).getName());
|
||||
assertEquals("ibit_0_7", test.get(10).getName());
|
||||
assertEquals("ibit_0_6", test.get(11).getName());
|
||||
assertEquals("ibit_0_5", test.get(12).getName());
|
||||
assertEquals("ibit_0_4", test.get(13).getName());
|
||||
assertEquals("ibit_0_3", test.get(14).getName());
|
||||
assertEquals("ibit_0_2", test.get(15).getName());
|
||||
assertEquals("ibit_0_1", test.get(16).getName());
|
||||
assertEquals("ibit_0_0", test.get(17).getName());
|
||||
|
||||
//e0 06
|
||||
assertEquals(224.0d, test.get(0).getValue(), 0.0);
|
||||
assertEquals(ModelTrainingUtils.ONE, test.get(1).getValue(), 0.0);
|
||||
assertEquals(ModelTrainingUtils.ONE, test.get(2).getValue(), 0.0);
|
||||
assertEquals(ModelTrainingUtils.ONE, test.get(3).getValue(), 0.0);
|
||||
assertEquals(ModelTrainingUtils.ZERO, test.get(4).getValue(), 0.0);
|
||||
|
||||
assertEquals(ModelTrainingUtils.ZERO, test.get(5).getValue(), 0.0);
|
||||
assertEquals(ModelTrainingUtils.ZERO, test.get(6).getValue(), 0.0);
|
||||
assertEquals(ModelTrainingUtils.ZERO, test.get(7).getValue(), 0.0);
|
||||
assertEquals(ModelTrainingUtils.ZERO, test.get(8).getValue(), 0.0);
|
||||
|
||||
assertEquals(6d, test.get(9).getValue(), 0.0);
|
||||
assertEquals(ModelTrainingUtils.ZERO, test.get(10).getValue(), 0.0);
|
||||
assertEquals(ModelTrainingUtils.ZERO, test.get(11).getValue(), 0.0);
|
||||
assertEquals(ModelTrainingUtils.ZERO, test.get(12).getValue(), 0.0);
|
||||
assertEquals(ModelTrainingUtils.ZERO, test.get(13).getValue(), 0.0);
|
||||
|
||||
assertEquals(ModelTrainingUtils.ZERO, test.get(14).getValue(), 0.0);
|
||||
assertEquals(ModelTrainingUtils.ONE, test.get(15).getValue(), 0.0);
|
||||
assertEquals(ModelTrainingUtils.ONE, test.get(16).getValue(), 0.0);
|
||||
assertEquals(ModelTrainingUtils.ZERO, test.get(17).getValue(), 0.0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetFollowingAddresses() throws CancelledException {
|
||||
Address one = program.getAddressFactory().getDefaultAddressSpace().getAddress(0x10004);
|
||||
Address two = program.getAddressFactory().getDefaultAddressSpace().getAddress(0x1000c);
|
||||
AddressSet addrs = new AddressSet(one);
|
||||
addrs.add(two);
|
||||
AddressSet following =
|
||||
ModelTrainingUtils.getFollowingAddresses(program, addrs, TaskMonitor.DUMMY);
|
||||
assertEquals(2, following.getNumAddresses());
|
||||
assertTrue(following.contains(
|
||||
program.getAddressFactory().getDefaultAddressSpace().getAddress(0x10008)));
|
||||
assertTrue(following.contains(
|
||||
program.getAddressFactory().getDefaultAddressSpace().getAddress(0x1000e)));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetPrecedingAddresses() throws CancelledException {
|
||||
Address one = program.getAddressFactory().getDefaultAddressSpace().getAddress(0x10004);
|
||||
Address two = program.getAddressFactory().getDefaultAddressSpace().getAddress(0x1000c);
|
||||
AddressSet addrs = new AddressSet(one);
|
||||
addrs.add(two);
|
||||
AddressSet following =
|
||||
ModelTrainingUtils.getPrecedingAddresses(program, addrs, TaskMonitor.DUMMY);
|
||||
assertEquals(2, following.getNumAddresses());
|
||||
assertTrue(following.contains(
|
||||
program.getAddressFactory().getDefaultAddressSpace().getAddress(0x10003)));
|
||||
assertTrue(following.contains(
|
||||
program.getAddressFactory().getDefaultAddressSpace().getAddress(0x10008)));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetDefinedData() throws CancelledException {
|
||||
AddressSet data = ModelTrainingUtils.getDefinedData(program, TaskMonitor.DUMMY);
|
||||
assertEquals(12, data.getNumAddresses());
|
||||
Address start = program.getAddressFactory().getDefaultAddressSpace().getAddress(0x10020);
|
||||
Address end = program.getAddressFactory().getDefaultAddressSpace().getAddress(0x10027);
|
||||
assertTrue(data.contains(start, end));
|
||||
start = program.getAddressFactory().getDefaultAddressSpace().getAddress(0x10030);
|
||||
end = program.getAddressFactory().getDefaultAddressSpace().getAddress(0x10033);
|
||||
assertTrue(data.contains(start, end));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetVectorsFromAddresses() throws CancelledException {
|
||||
Address base = program.getAddressFactory().getDefaultAddressSpace().getAddress(0x10008);
|
||||
AddressSet testSet = new AddressSet(base);
|
||||
|
||||
List<Example<Label>> testVectorList = ModelTrainingUtils.getVectorsFromAddresses(program,
|
||||
testSet, RandomForestFunctionFinderPlugin.FUNC_START, 1, 1, true, TaskMonitor.DUMMY);
|
||||
assertEquals(1, testVectorList.size());
|
||||
Example<Label> testVector = testVectorList.get(0);
|
||||
assertEquals(RandomForestFunctionFinderPlugin.FUNC_START, testVector.getOutput());
|
||||
testFeatureVector(testVector.iterator());
|
||||
|
||||
testVectorList = ModelTrainingUtils.getVectorsFromAddresses(program, testSet,
|
||||
RandomForestFunctionFinderPlugin.NON_START, 1, 1, true, TaskMonitor.DUMMY);
|
||||
assertEquals(1, testVectorList.size());
|
||||
testVector = testVectorList.get(0);
|
||||
assertEquals(RandomForestFunctionFinderPlugin.NON_START, testVector.getOutput());
|
||||
testFeatureVector(testVector.iterator());
|
||||
}
|
||||
|
||||
private void testFeatureVector(Iterator<Feature> iter) {
|
||||
while (iter.hasNext()) {
|
||||
Feature feature = iter.next();
|
||||
switch (feature.getName()) {
|
||||
case "pbyte_0":
|
||||
assertEquals(224d, feature.getValue(), 0.0);
|
||||
break;
|
||||
case "ibyte_0":
|
||||
assertEquals(6d, feature.getValue(), 0.0);
|
||||
break;
|
||||
case "pbit_0_7":
|
||||
case "pbit_0_6":
|
||||
case "pbit_0_5":
|
||||
assertEquals(ModelTrainingUtils.ONE, feature.getValue(), 0.0);
|
||||
break;
|
||||
case "pbit_0_4":
|
||||
case "pbit_0_3":
|
||||
case "pbit_0_2":
|
||||
case "pbit_0_1":
|
||||
case "pbit_0_0":
|
||||
case "ibit_0_7":
|
||||
case "ibit_0_6":
|
||||
case "ibit_0_5":
|
||||
case "ibit_0_4":
|
||||
case "ibit_0_3":
|
||||
assertEquals(ModelTrainingUtils.ZERO, feature.getValue(), 0.0);
|
||||
break;
|
||||
case "ibit_0_2":
|
||||
case "ibit_0_1":
|
||||
assertEquals(ModelTrainingUtils.ONE, feature.getValue(), 0.0);
|
||||
break;
|
||||
case "ibit_0_0":
|
||||
assertEquals(ModelTrainingUtils.ZERO, feature.getValue(), 0.0);
|
||||
break;
|
||||
default:
|
||||
fail("Unknown feature name: " + feature.getName());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,152 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.tribuo.*;
|
||||
import org.tribuo.classification.Label;
|
||||
import org.tribuo.classification.LabelFactory;
|
||||
import org.tribuo.classification.baseline.DummyClassifierTrainer;
|
||||
import org.tribuo.classification.ensemble.VotingCombiner;
|
||||
import org.tribuo.datasource.ListDataSource;
|
||||
import org.tribuo.ensemble.EnsembleModel;
|
||||
import org.tribuo.ensemble.WeightedEnsembleModel;
|
||||
import org.tribuo.provenance.SimpleDataSourceProvenance;
|
||||
|
||||
import ghidra.program.database.ProgramBuilder;
|
||||
import ghidra.program.database.ProgramDB;
|
||||
import ghidra.program.model.address.Address;
|
||||
import ghidra.program.model.address.AddressSet;
|
||||
import ghidra.program.model.listing.Program;
|
||||
import ghidra.test.AbstractProgramBasedTest;
|
||||
import ghidra.test.ClassicSampleX86ProgramBuilder;
|
||||
import ghidra.util.task.TaskMonitor;
|
||||
|
||||
public class EnsembleEvaluatorTest extends AbstractProgramBasedTest {
|
||||
|
||||
private ProgramBuilder builder;
|
||||
|
||||
@Override
|
||||
protected Program getProgram() throws Exception {
|
||||
builder = new ClassicSampleX86ProgramBuilder();
|
||||
ProgramDB p = builder.getProgram();
|
||||
return p;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void basicEvaluatorTest() throws Exception {
|
||||
initialize();
|
||||
DummyClassifierTrainer dummyStartTrainer = DummyClassifierTrainer
|
||||
.createConstantTrainer(RandomForestFunctionFinderPlugin.FUNC_START.getLabel());
|
||||
DummyClassifierTrainer dummyNonStartTrainer = DummyClassifierTrainer
|
||||
.createConstantTrainer(RandomForestFunctionFinderPlugin.NON_START.getLabel());
|
||||
List<Example<Label>> trainingData = new ArrayList<>();
|
||||
AddressSet starts = new AddressSet();
|
||||
AddressSet nonStarts = new AddressSet();
|
||||
|
||||
//just need an address with 1 defined preByte and 1 defined byte
|
||||
Address testStart = program.getSymbolTable().getSymbols("entry").next().getAddress().add(1);
|
||||
Address testNonStart = testStart.add(1);
|
||||
starts.add(testStart);
|
||||
nonStarts.add(testNonStart);
|
||||
trainingData.addAll(ModelTrainingUtils.getVectorsFromAddresses(program, starts,
|
||||
RandomForestFunctionFinderPlugin.FUNC_START, 1, 1, true, TaskMonitor.DUMMY));
|
||||
trainingData.addAll(ModelTrainingUtils.getVectorsFromAddresses(program, nonStarts,
|
||||
RandomForestFunctionFinderPlugin.NON_START, 1, 1, true, TaskMonitor.DUMMY));
|
||||
LabelFactory lf = new LabelFactory();
|
||||
ListDataSource<Label> trainingSource =
|
||||
new ListDataSource<Label>(trainingData, lf, new SimpleDataSourceProvenance("test", lf));
|
||||
MutableDataset<Label> trainingSet = new MutableDataset<>(trainingSource);
|
||||
|
||||
//10 models, encounters 5 yes votes first
|
||||
List<Model<Label>> models = new ArrayList<>();
|
||||
for (int i = 0; i < 10; i++) {
|
||||
models.add(dummyStartTrainer.train(trainingSet));
|
||||
models.add(dummyNonStartTrainer.train(trainingSet));
|
||||
}
|
||||
|
||||
EnsembleModel<Label> ensemble = WeightedEnsembleModel
|
||||
.createEnsembleFromExistingModels("test", models, new VotingCombiner());
|
||||
EnsembleEvaluatorCallback testEval = new EnsembleEvaluatorCallback(ensemble, program, 1, 1,
|
||||
true, RandomForestFunctionFinderPlugin.FUNC_START);
|
||||
Boolean res = testEval.process(testStart, TaskMonitor.DUMMY);
|
||||
assertTrue(res);
|
||||
|
||||
//10 models, encounters 5 no votes first
|
||||
//there are also 5 voting yes, ties should go to yes
|
||||
models.clear();
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
models.add(dummyNonStartTrainer.train(trainingSet));
|
||||
models.add(dummyStartTrainer.train(trainingSet));
|
||||
}
|
||||
ensemble = WeightedEnsembleModel.createEnsembleFromExistingModels("test", models,
|
||||
new VotingCombiner());
|
||||
testEval = new EnsembleEvaluatorCallback(ensemble, program, 1, 1, true,
|
||||
RandomForestFunctionFinderPlugin.FUNC_START);
|
||||
res = testEval.process(testStart, TaskMonitor.DUMMY);
|
||||
assertTrue(res);
|
||||
|
||||
//10 models, encounters 4 yes votes then 6 no votes
|
||||
models.clear();
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
models.add(dummyStartTrainer.train(trainingSet));
|
||||
}
|
||||
for (int i = 0; i < 6; ++i) {
|
||||
models.add(dummyNonStartTrainer.train(trainingSet));
|
||||
}
|
||||
ensemble = WeightedEnsembleModel.createEnsembleFromExistingModels("test", models,
|
||||
new VotingCombiner());
|
||||
testEval = new EnsembleEvaluatorCallback(ensemble, program, 1, 1, true,
|
||||
RandomForestFunctionFinderPlugin.FUNC_START);
|
||||
res = testEval.process(testStart, TaskMonitor.DUMMY);
|
||||
assertFalse(res);
|
||||
|
||||
//11 models, encounters 5 no votes first then 6 yes votes
|
||||
models.clear();
|
||||
for (int i = 0; i < 10; i++) {
|
||||
models.add(dummyNonStartTrainer.train(trainingSet));
|
||||
models.add(dummyStartTrainer.train(trainingSet));
|
||||
}
|
||||
models.add(dummyStartTrainer.train(trainingSet));
|
||||
ensemble = WeightedEnsembleModel.createEnsembleFromExistingModels("test", models,
|
||||
new VotingCombiner());
|
||||
testEval = new EnsembleEvaluatorCallback(ensemble, program, 1, 1, true,
|
||||
RandomForestFunctionFinderPlugin.FUNC_START);
|
||||
res = testEval.process(testStart, TaskMonitor.DUMMY);
|
||||
assertTrue(res);
|
||||
|
||||
//11 models, encounters 5 yes votes first then 6 no votes
|
||||
models.clear();
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
models.add(dummyStartTrainer.train(trainingSet));
|
||||
}
|
||||
for (int i = 0; i < 6; ++i) {
|
||||
models.add(dummyNonStartTrainer.train(trainingSet));
|
||||
}
|
||||
ensemble = WeightedEnsembleModel.createEnsembleFromExistingModels("test", models,
|
||||
new VotingCombiner());
|
||||
testEval = new EnsembleEvaluatorCallback(ensemble, program, 1, 1, true,
|
||||
RandomForestFunctionFinderPlugin.FUNC_START);
|
||||
res = testEval.process(testStart, TaskMonitor.DUMMY);
|
||||
assertFalse(res);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,186 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import ghidra.program.database.ProgramBuilder;
|
||||
import ghidra.program.model.address.*;
|
||||
import ghidra.program.model.listing.Program;
|
||||
import ghidra.program.model.mem.MemoryBlock;
|
||||
import ghidra.test.AbstractProgramBasedTest;
|
||||
import ghidra.util.exception.CancelledException;
|
||||
import ghidra.util.task.TaskMonitor;
|
||||
|
||||
public class FunctionStartRFParamsProgramBasedTest extends AbstractProgramBasedTest {
|
||||
|
||||
private final static String BASE_ADDRESS = "0x10000";
|
||||
private final static String ADD_R0_R1_R2_ARM = "02 00 81 e0";
|
||||
private final static String BX_LR_ARM = "1e ff 2f e1";
|
||||
|
||||
private final static String ADD_R0_R1_THUMB = "08 44";
|
||||
private final static String BX_LR_THUMB = "70 47";
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
initialize();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Program getProgram() throws Exception {
|
||||
return buildProgram();
|
||||
}
|
||||
|
||||
private Program buildProgram() throws Exception {
|
||||
ProgramBuilder builder = new ProgramBuilder("DataGatheringUtilsTest", ProgramBuilder._ARM);
|
||||
MemoryBlock block = builder.createMemory(".text", BASE_ADDRESS, 0x100);
|
||||
builder.setExecute(block, true);
|
||||
//undefined
|
||||
builder.setBytes(BASE_ADDRESS, "00 01 02 03", false);
|
||||
|
||||
//small arm function
|
||||
builder.setBytes("0x10004", ADD_R0_R1_R2_ARM, true);
|
||||
builder.setBytes("0x10008", BX_LR_ARM, true);
|
||||
builder.createFunction("0x10004");
|
||||
|
||||
//larger arm function
|
||||
builder.setBytes("0x1000c", ADD_R0_R1_R2_ARM, true);
|
||||
builder.setBytes("0x10010", ADD_R0_R1_R2_ARM, true);
|
||||
builder.setBytes("0x10014", ADD_R0_R1_R2_ARM, true);
|
||||
builder.setBytes("0x10018", ADD_R0_R1_R2_ARM, true);
|
||||
builder.setBytes("0x1001c", BX_LR_ARM, true);
|
||||
builder.createFunction("0x1000c");
|
||||
|
||||
builder.setRegisterValue("TMode", "0x10020", "0x10036", 1);
|
||||
|
||||
//small thumb function
|
||||
builder.setBytes("0x10020", ADD_R0_R1_THUMB, true);
|
||||
builder.setBytes("0x10022", BX_LR_THUMB, true);
|
||||
builder.createFunction("0x10020");
|
||||
|
||||
//larger thumb function
|
||||
builder.setBytes("0x10024", ADD_R0_R1_THUMB, true);
|
||||
builder.setBytes("0x10026", ADD_R0_R1_THUMB, true);
|
||||
builder.setBytes("0x10028", ADD_R0_R1_THUMB, true);
|
||||
builder.setBytes("0x1002a", ADD_R0_R1_THUMB, true);
|
||||
builder.setBytes("0x1002c", ADD_R0_R1_THUMB, true);
|
||||
builder.setBytes("0x1002e", ADD_R0_R1_THUMB, true);
|
||||
builder.setBytes("0x10030", ADD_R0_R1_THUMB, true);
|
||||
builder.setBytes("0x10032", BX_LR_THUMB, true);
|
||||
builder.createFunction("0x10024");
|
||||
|
||||
return builder.getProgram();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCheckContextRegisters() {
|
||||
FunctionStartRFParams params = new FunctionStartRFParams(program);
|
||||
Address armFunc = program.getAddressFactory().getDefaultAddressSpace().getAddress(0x1000c);
|
||||
Address thumbFunc =
|
||||
program.getAddressFactory().getDefaultAddressSpace().getAddress(0x10020);
|
||||
assertTrue(params.isContextCompatible(armFunc));
|
||||
assertTrue(params.isContextCompatible(thumbFunc));
|
||||
params.setRegistersAndValues("TMode=1");
|
||||
assertFalse(params.isContextCompatible(armFunc));
|
||||
assertTrue(params.isContextCompatible(thumbFunc));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testComputeEntriesAndInteriors() throws CancelledException {
|
||||
|
||||
AddressSpace defaultSpace = program.getAddressFactory().getDefaultAddressSpace();
|
||||
|
||||
Address smallArmEntry = defaultSpace.getAddress(0x10004);
|
||||
//instruction alignment for the processor is 2, even in ARM mode
|
||||
//so we need to adjust the arm interiors
|
||||
AddressSet smallArmInterior = new AddressSet(defaultSpace.getAddress(0x10008));
|
||||
smallArmInterior.add(defaultSpace.getAddress(0x10006));
|
||||
smallArmInterior.add(defaultSpace.getAddress(0x1000a));
|
||||
|
||||
Address largeArmEntry = defaultSpace.getAddress(0x1000c);
|
||||
AddressSet largeArmInterior = new AddressSet(defaultSpace.getAddress(0x10010));
|
||||
largeArmInterior.add(defaultSpace.getAddress(0x1000e));
|
||||
largeArmInterior.add(defaultSpace.getAddress(0x10012));
|
||||
largeArmInterior.add(defaultSpace.getAddress(0x10014));
|
||||
largeArmInterior.add(defaultSpace.getAddress(0x10016));
|
||||
largeArmInterior.add(defaultSpace.getAddress(0x10018));
|
||||
largeArmInterior.add(defaultSpace.getAddress(0x1001a));
|
||||
largeArmInterior.add(defaultSpace.getAddress(0x1001c));
|
||||
largeArmInterior.add(defaultSpace.getAddress(0x1001e));
|
||||
|
||||
Address smallThumbEntry = defaultSpace.getAddress(0x10020);
|
||||
AddressSet smallThumbInterior = new AddressSet(defaultSpace.getAddress(0x10022));
|
||||
|
||||
Address largeThumbEntry = defaultSpace.getAddress(0x10024);
|
||||
AddressSet largeThumbInterior = new AddressSet(defaultSpace.getAddress(0x10026));
|
||||
largeThumbInterior.add(defaultSpace.getAddress(0x10028));
|
||||
largeThumbInterior.add(defaultSpace.getAddress(0x1002a));
|
||||
largeThumbInterior.add(defaultSpace.getAddress(0x1002c));
|
||||
largeThumbInterior.add(defaultSpace.getAddress(0x1002e));
|
||||
largeThumbInterior.add(defaultSpace.getAddress(0x10030));
|
||||
largeThumbInterior.add(defaultSpace.getAddress(0x10032));
|
||||
|
||||
FunctionStartRFParams params = new FunctionStartRFParams(program);
|
||||
params.computeFuncEntriesAndInteriors(TaskMonitor.DUMMY);
|
||||
AddressSet entries = params.getFuncEntries();
|
||||
AddressSet interiors = params.getFuncInteriors();
|
||||
assertTrue(!entries.intersects(interiors));
|
||||
assertEquals(4, entries.getNumAddresses());
|
||||
assertTrue(entries.contains(smallArmEntry));
|
||||
assertTrue(entries.contains(largeArmEntry));
|
||||
assertTrue(entries.contains(smallThumbEntry));
|
||||
assertTrue(entries.contains(largeThumbEntry));
|
||||
|
||||
assertEquals(
|
||||
smallThumbInterior.getNumAddresses() + largeThumbInterior.getNumAddresses() +
|
||||
smallArmInterior.getNumAddresses() + largeArmInterior.getNumAddresses(),
|
||||
interiors.getNumAddresses());
|
||||
assertTrue(interiors.contains(smallThumbInterior.union(largeThumbInterior)
|
||||
.union(smallArmInterior)
|
||||
.union(largeArmInterior)));
|
||||
|
||||
params = new FunctionStartRFParams(program);
|
||||
params.setMinFuncSize(10);
|
||||
params.setRegistersAndValues("TMode=0");
|
||||
params.computeFuncEntriesAndInteriors(TaskMonitor.DUMMY);
|
||||
entries = params.getFuncEntries();
|
||||
interiors = params.getFuncInteriors();
|
||||
assertTrue(!entries.intersects(interiors));
|
||||
assertEquals(1, entries.getNumAddresses());
|
||||
assertTrue(entries.contains(largeArmEntry));
|
||||
|
||||
assertTrue(interiors.contains(largeArmInterior));
|
||||
assertEquals(interiors.getNumAddresses(), largeArmInterior.getNumAddresses());
|
||||
|
||||
params = new FunctionStartRFParams(program);
|
||||
params.setMinFuncSize(10);
|
||||
params.setRegistersAndValues("TMode=1");
|
||||
params.computeFuncEntriesAndInteriors(TaskMonitor.DUMMY);
|
||||
entries = params.getFuncEntries();
|
||||
interiors = params.getFuncInteriors();
|
||||
assertTrue(!entries.intersects(interiors));
|
||||
assertEquals(1, entries.getNumAddresses());
|
||||
assertTrue(entries.contains(largeThumbEntry));
|
||||
|
||||
assertTrue(interiors.contains(largeThumbInterior));
|
||||
assertEquals(interiors.getNumAddresses(), largeThumbInterior.getNumAddresses());
|
||||
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,361 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.tribuo.*;
|
||||
import org.tribuo.classification.Label;
|
||||
import org.tribuo.classification.LabelFactory;
|
||||
import org.tribuo.classification.baseline.DummyClassifierTrainer;
|
||||
import org.tribuo.classification.ensemble.VotingCombiner;
|
||||
import org.tribuo.datasource.ListDataSource;
|
||||
import org.tribuo.ensemble.EnsembleModel;
|
||||
import org.tribuo.ensemble.WeightedEnsembleModel;
|
||||
import org.tribuo.provenance.SimpleDataSourceProvenance;
|
||||
|
||||
import generic.concurrent.QResult;
|
||||
import ghidra.program.database.ProgramBuilder;
|
||||
import ghidra.program.database.ProgramDB;
|
||||
import ghidra.program.model.address.Address;
|
||||
import ghidra.program.model.address.AddressSet;
|
||||
import ghidra.program.model.listing.Function;
|
||||
import ghidra.program.model.listing.Program;
|
||||
import ghidra.program.util.ProgramSelection;
|
||||
import ghidra.test.AbstractProgramBasedTest;
|
||||
import ghidra.test.ClassicSampleX86ProgramBuilder;
|
||||
import ghidra.util.exception.CancelledException;
|
||||
import ghidra.util.task.TaskLauncher;
|
||||
import ghidra.util.task.TaskMonitor;
|
||||
|
||||
public class RandomForestTrainingTaskTest extends AbstractProgramBasedTest {
|
||||
|
||||
private ProgramBuilder builder;
|
||||
private FunctionStartRFParams params;
|
||||
|
||||
@Before
|
||||
public void setup() throws Exception {
|
||||
initialize();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Program getProgram() throws Exception {
|
||||
builder = new ClassicSampleX86ProgramBuilder();
|
||||
ProgramDB p = builder.getProgram();
|
||||
return p;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testUpdateConfusionMatrix() throws Exception {
|
||||
//test with one true positive, one false positive, one true negative, and one false negative
|
||||
|
||||
Address truePositive = program.getSymbolTable().getSymbols("entry").next().getAddress();
|
||||
Address falsePositive = truePositive.add(1);
|
||||
Address falseNegative = falsePositive.add(1);
|
||||
Address trueNegative = falseNegative.add(1);
|
||||
int[] confusionMatrix = new int[RandomForestTrainingTask.CONFUSION_MATRIX_SIZE];
|
||||
AddressSet errors = new AddressSet();
|
||||
List<QResult<Address, Boolean>> examples = new ArrayList<>();
|
||||
examples.add(new QResult<>(truePositive, CompletableFuture.completedFuture(true)));
|
||||
examples.add(new QResult<>(falsePositive, CompletableFuture.completedFuture(false)));
|
||||
//don't need any of the fields of RandomForestTrainingTask for this test
|
||||
RandomForestTrainingTask task = new RandomForestTrainingTask(program, null, null,
|
||||
RandomForestFunctionFinderPlugin.TEST_SET_MAX_SIZE_DEFAULT);
|
||||
task.updateConfusionMatrix(examples, confusionMatrix,
|
||||
RandomForestFunctionFinderPlugin.FUNC_START, errors);
|
||||
assertEquals(1, errors.getNumAddresses());
|
||||
assertTrue(errors.contains(falsePositive));
|
||||
errors.clear();
|
||||
examples.clear();
|
||||
examples.add(new QResult<>(falseNegative, CompletableFuture.completedFuture(false)));
|
||||
examples.add(new QResult<>(trueNegative, CompletableFuture.completedFuture(true)));
|
||||
task.updateConfusionMatrix(examples, confusionMatrix,
|
||||
RandomForestFunctionFinderPlugin.NON_START, errors);
|
||||
assertEquals(1, errors.getNumAddresses());
|
||||
assertTrue(errors.contains(falseNegative));
|
||||
assertEquals(1, confusionMatrix[RandomForestTrainingTask.TP]);
|
||||
assertEquals(1, confusionMatrix[RandomForestTrainingTask.FP]);
|
||||
assertEquals(1, confusionMatrix[RandomForestTrainingTask.TN]);
|
||||
assertEquals(1, confusionMatrix[RandomForestTrainingTask.FN]);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEvaluateModel() throws CancelledException {
|
||||
DummyClassifierTrainer dummyStartTrainer = DummyClassifierTrainer
|
||||
.createConstantTrainer(RandomForestFunctionFinderPlugin.FUNC_START.getLabel());
|
||||
List<Example<Label>> trainingData = new ArrayList<>();
|
||||
AddressSet starts = new AddressSet();
|
||||
AddressSet nonStarts = new AddressSet();
|
||||
|
||||
//just need an address with 1 defined preByte and 1 defined byte
|
||||
Address testStart = program.getSymbolTable().getSymbols("entry").next().getAddress().add(1);
|
||||
Address testNonStart = testStart.add(1);
|
||||
starts.add(testStart);
|
||||
nonStarts.add(testNonStart);
|
||||
trainingData.addAll(ModelTrainingUtils.getVectorsFromAddresses(program, starts,
|
||||
RandomForestFunctionFinderPlugin.FUNC_START, 1, 1, true, TaskMonitor.DUMMY));
|
||||
trainingData.addAll(ModelTrainingUtils.getVectorsFromAddresses(program, nonStarts,
|
||||
RandomForestFunctionFinderPlugin.NON_START, 1, 1, true, TaskMonitor.DUMMY));
|
||||
LabelFactory lf = new LabelFactory();
|
||||
ListDataSource<Label> trainingSource =
|
||||
new ListDataSource<Label>(trainingData, lf, new SimpleDataSourceProvenance("test", lf));
|
||||
MutableDataset<Label> trainingSet = new MutableDataset<>(trainingSource);
|
||||
|
||||
//create ensemble with two models which always report FUNC_START
|
||||
List<Model<Label>> models = new ArrayList<>();
|
||||
models.add(dummyStartTrainer.train(trainingSet));
|
||||
models.add(dummyStartTrainer.train(trainingSet));
|
||||
WeightedEnsembleModel<Label> ensemble = WeightedEnsembleModel
|
||||
.createEnsembleFromExistingModels("test", models, new VotingCombiner());
|
||||
params = new FunctionStartRFParams(program);
|
||||
params.setIncludeBitFeatures(true);
|
||||
AddressSet errors = new AddressSet();
|
||||
RandomForestTrainingTask task = new RandomForestTrainingTask(program, params, null,
|
||||
RandomForestFunctionFinderPlugin.TEST_SET_MAX_SIZE_DEFAULT);
|
||||
int[] confusion =
|
||||
task.evaluateModel(ensemble, starts, nonStarts, errors, 1, 1, TaskMonitor.DUMMY);
|
||||
assertEquals(1, errors.getNumAddresses());
|
||||
assertTrue(errors.contains(testNonStart));
|
||||
assertEquals(1, confusion[RandomForestTrainingTask.TP]);
|
||||
assertEquals(1, confusion[RandomForestTrainingTask.FP]);
|
||||
assertEquals(0, confusion[RandomForestTrainingTask.TN]);
|
||||
assertEquals(0, confusion[RandomForestTrainingTask.FN]);
|
||||
|
||||
//create ensemble with two models which always report NON_START
|
||||
DummyClassifierTrainer dummyNonStartTrainer = DummyClassifierTrainer
|
||||
.createConstantTrainer(RandomForestFunctionFinderPlugin.NON_START.getLabel());
|
||||
models = new ArrayList<>();
|
||||
models.add(dummyNonStartTrainer.train(trainingSet));
|
||||
models.add(dummyNonStartTrainer.train(trainingSet));
|
||||
ensemble = WeightedEnsembleModel.createEnsembleFromExistingModels("test", models,
|
||||
new VotingCombiner());
|
||||
errors = new AddressSet();
|
||||
confusion =
|
||||
task.evaluateModel(ensemble, starts, nonStarts, errors, 1, 1, TaskMonitor.DUMMY);
|
||||
assertEquals(1, errors.getNumAddresses());
|
||||
assertTrue(errors.contains(testStart));
|
||||
assertEquals(0, confusion[RandomForestTrainingTask.TP]);
|
||||
assertEquals(0, confusion[RandomForestTrainingTask.FP]);
|
||||
assertEquals(1, confusion[RandomForestTrainingTask.TN]);
|
||||
assertEquals(1, confusion[RandomForestTrainingTask.FN]);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTrainModel() throws CancelledException {
|
||||
//train an ensemble on a trivial data set and verify that it contains the correct number
|
||||
//of models
|
||||
List<Example<Label>> trainingData = new ArrayList<>();
|
||||
AddressSet starts = new AddressSet();
|
||||
AddressSet nonStarts = new AddressSet();
|
||||
//just need an address with 1 defined preByte and 1 defined byte
|
||||
Address testStart = program.getSymbolTable().getSymbols("entry").next().getAddress().add(1);
|
||||
Address testNonStart = testStart.add(1);
|
||||
starts.add(testStart);
|
||||
nonStarts.add(testNonStart);
|
||||
trainingData.addAll(ModelTrainingUtils.getVectorsFromAddresses(program, starts,
|
||||
RandomForestFunctionFinderPlugin.FUNC_START, 1, 1, true, TaskMonitor.DUMMY));
|
||||
trainingData.addAll(ModelTrainingUtils.getVectorsFromAddresses(program, nonStarts,
|
||||
RandomForestFunctionFinderPlugin.NON_START, 1, 1, true, TaskMonitor.DUMMY));
|
||||
//don't need any of the fields of RandomForestTrainingTask for this test
|
||||
RandomForestTrainingTask task = new RandomForestTrainingTask(program, null, null,
|
||||
RandomForestFunctionFinderPlugin.TEST_SET_MAX_SIZE_DEFAULT);
|
||||
EnsembleModel<Label> ensemble = task.trainModel(trainingData, TaskMonitor.DUMMY);
|
||||
assertEquals(RandomForestTrainingTask.NUM_TREES, ensemble.getNumModels());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSimilarStartsFinder() {
|
||||
params = new FunctionStartRFParams(program);
|
||||
List<Integer> testList = new ArrayList<>();
|
||||
testList.add(5);
|
||||
params.setFactors(testList);
|
||||
params.setIncludeBitFeatures(true);
|
||||
params.setIncludePrecedingAndFollowing(true);
|
||||
params.setInitialBytes(testList);
|
||||
params.setMaxStarts(100);
|
||||
params.setMinFuncSize(16);
|
||||
params.setPreBytes(testList);
|
||||
List<RandomForestRowObject> rows = new ArrayList<>();
|
||||
RandomForestTrainingTask task =
|
||||
new RandomForestTrainingTask(program, params, x -> rows.add(x), 100);
|
||||
TaskLauncher.launchModal("test", task);
|
||||
assertEquals(1, rows.size());
|
||||
SimilarStartsFinder finder = new SimilarStartsFinder(program, rows.get(0));
|
||||
Address entryAddr = program.getSymbolTable().getSymbols("entry").next().getAddress();
|
||||
List<SimilarStartRowObject> res = finder.getSimilarFunctionStarts(entryAddr, 7);
|
||||
//just verify that the number of elements is correct, each element is a function start,
|
||||
//and that the list is in descending order.
|
||||
assertEquals(7, res.size());
|
||||
res.forEach(
|
||||
r -> assertTrue(program.getFunctionManager().getFunctionAt(r.funcStart()) != null));
|
||||
int currentNum = res.get(0).numAgreements();
|
||||
for (int i = 1; i < 7; ++i) {
|
||||
assertTrue(currentNum >= res.get(i).numAgreements());
|
||||
currentNum = res.get(i).numAgreements();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getTrainingAndTestDataBasicTest() throws CancelledException {
|
||||
params = new FunctionStartRFParams(program);
|
||||
params.setMaxStarts(5);
|
||||
Address begin = program.getSymbolTable().getSymbols("entry").next().getAddress();
|
||||
AddressSet entries = new AddressSet();
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
entries.add(begin.add(i));
|
||||
}
|
||||
AddressSet interiors = new AddressSet();
|
||||
for (int i = 10; i < 25; ++i) {
|
||||
interiors.add(begin.add(i));
|
||||
}
|
||||
AddressSet definedData = new AddressSet();
|
||||
for (int i = 25; i < 30; ++i) {
|
||||
definedData.add(begin.add(i));
|
||||
}
|
||||
RandomForestTrainingTask task = new RandomForestTrainingTask(program, params, null,
|
||||
RandomForestFunctionFinderPlugin.TEST_SET_MAX_SIZE_DEFAULT);
|
||||
TrainingAndTestData data =
|
||||
task.getTrainingAndTestData(entries, interiors, definedData, 2, TaskMonitor.DUMMY);
|
||||
//5 function starts chosen from 10 possible
|
||||
assertEquals(5, data.getTrainingPositive().getNumAddresses());
|
||||
//5*2 interiors chosen from 15 possible
|
||||
assertEquals(10, data.getTrainingNegative().getNumAddresses());
|
||||
//5 function starts were not chosen
|
||||
assertEquals(5, data.getTestPositive().getNumAddresses());
|
||||
//5 interiors were not chosen + 5 defined data
|
||||
assertEquals(10, data.getTestNegative().getNumAddresses());
|
||||
assertTrue(data.getTestPositive()
|
||||
.union(data.getTestNegative())
|
||||
.intersect(data.getTrainingNegative().union(data.getTrainingPositive()))
|
||||
.isEmpty());
|
||||
assertTrue(data.getTestPositive().intersect(data.getTestNegative()).isEmpty());
|
||||
assertTrue(data.getTrainingPositive().intersect(data.getTrainingNegative()).isEmpty());
|
||||
assertTrue(entries.contains(data.getTestPositive()));
|
||||
assertTrue(entries.contains(data.getTrainingPositive()));
|
||||
assertTrue(interiors.contains(data.getTrainingNegative()));
|
||||
assertTrue(interiors.contains(data.getTestNegative().subtract(definedData)));
|
||||
assertTrue(data.getTestNegative().contains(definedData));
|
||||
}
|
||||
|
||||
//FUN_010059a3 is legit
|
||||
|
||||
//
|
||||
// 100641f 00
|
||||
// 1006420 55 PUSH EBP <- entry
|
||||
// 1006421 8b ec MOV EBP, ESP
|
||||
//
|
||||
// 1006423 6a ff PUSH -0x1
|
||||
// 1006425 66 88 18 00 01 PUSH DAT_01001888
|
||||
// 100642a 68 d0 65 00 01 PUSH DAT_010065d0
|
||||
//
|
||||
// 100642f 64 a1 00 00 00 00 MOV EAX, FS:[0x0]
|
||||
// 1006435 50 PUSH EAX
|
||||
// 1006436 64 89 25 00 00 00 00 MOV dword ptr FS:[0x0],ESP
|
||||
//
|
||||
@Test
|
||||
public void getTrainingAndTestDataDeluxeTest() throws CancelledException {
|
||||
params = new FunctionStartRFParams(program);
|
||||
params.setMaxStarts(2);
|
||||
params.setIncludePrecedingAndFollowing(true);
|
||||
Address begin = program.getSymbolTable().getSymbols("entry").next().getAddress();
|
||||
|
||||
//create 3 starts, spaced out because we want to include the previous and following
|
||||
AddressSet entries = new AddressSet();
|
||||
entries.add(begin);
|
||||
entries.add(program.getAddressFactory().getDefaultAddressSpace().getAddress(0x1006425));
|
||||
entries.add(program.getAddressFactory().getDefaultAddressSpace().getAddress(0x1006435));
|
||||
|
||||
AddressSet interiors = new AddressSet();
|
||||
for (int i = 0x20; i < 0x25; ++i) {
|
||||
interiors.add(begin.add(i));
|
||||
}
|
||||
AddressSet definedData = new AddressSet();
|
||||
for (int i = 0x30; i < 0x35; ++i) {
|
||||
definedData.add(begin.add(i));
|
||||
}
|
||||
|
||||
RandomForestTrainingTask task = new RandomForestTrainingTask(program, params, null,
|
||||
RandomForestFunctionFinderPlugin.TEST_SET_MAX_SIZE_DEFAULT);
|
||||
Address otherFuncEntry =
|
||||
program.getAddressFactory().getDefaultAddressSpace().getAddress(0x10059a3);
|
||||
Function func = program.getFunctionManager().getFunctionAt(otherFuncEntry);
|
||||
task.setAdditional(
|
||||
new ProgramSelection(func.getBody().getMinAddress(), func.getBody().getMaxAddress()));
|
||||
|
||||
TrainingAndTestData data =
|
||||
task.getTrainingAndTestData(entries, interiors, definedData, 2, TaskMonitor.DUMMY);
|
||||
|
||||
//2 function starts chosen from 3 possible
|
||||
//plus the entry of the function at 0x1005913 which was added explicitly
|
||||
assertEquals(3, data.getTrainingPositive().getNumAddresses());
|
||||
assertTrue(data.getTrainingPositive().contains(otherFuncEntry));
|
||||
|
||||
//2*2 interiors chosen from 5 possible
|
||||
//2*2 plus preceding and following for two functions selected at random
|
||||
//plus size of interior of function at 0x10059a3
|
||||
assertEquals(4 + 4 + func.getBody().getNumAddresses() - 1,
|
||||
data.getTrainingNegative().getNumAddresses());
|
||||
//1 function start was not chosen
|
||||
assertEquals(1, data.getTestPositive().getNumAddresses());
|
||||
//1 interior was not chosen + (preceding and following for unchosen entry) + 5 defined data
|
||||
assertEquals(8, data.getTestNegative().getNumAddresses());
|
||||
assertTrue(data.getTestPositive()
|
||||
.union(data.getTestNegative())
|
||||
.intersect(data.getTrainingNegative().union(data.getTrainingPositive()))
|
||||
.isEmpty());
|
||||
assertTrue(data.getTestPositive().intersect(data.getTestNegative()).isEmpty());
|
||||
assertTrue(data.getTrainingPositive().intersect(data.getTrainingNegative()).isEmpty());
|
||||
assertTrue(entries.contains(data.getTestPositive()));
|
||||
assertFalse(entries.contains(data.getTrainingPositive()));
|
||||
AddressSet deluxeEntries = entries.union(new AddressSet(otherFuncEntry));
|
||||
assertTrue(deluxeEntries.contains(data.getTrainingPositive()));
|
||||
|
||||
int numContained = 0;
|
||||
Address entry_1006420 = begin;
|
||||
Address entry_1006425 = begin.add(5l);
|
||||
Address entry_1006435 = begin.add(0x15l);
|
||||
|
||||
if (data.getTrainingPositive().contains(entry_1006420)) {
|
||||
numContained += 1;
|
||||
assertTrue(data.getTrainingNegative().contains(entry_1006420.subtract(1l)));
|
||||
assertTrue(data.getTrainingNegative().contains(entry_1006420.add(1l)));
|
||||
}
|
||||
|
||||
if (data.getTrainingPositive().contains(entry_1006425)) {
|
||||
numContained += 1;
|
||||
assertTrue(data.getTrainingNegative().contains(entry_1006425.subtract(2l)));
|
||||
assertTrue(data.getTrainingNegative().contains(entry_1006425.add(5l)));
|
||||
}
|
||||
|
||||
if (data.getTrainingPositive().contains(entry_1006435)) {
|
||||
numContained += 1;
|
||||
assertTrue(data.getTrainingNegative().contains(entry_1006435.subtract(6l)));
|
||||
assertTrue(data.getTrainingNegative().contains(entry_1006435.add(1l)));
|
||||
}
|
||||
|
||||
assertEquals(2, numContained);
|
||||
|
||||
assertTrue(data.getTestNegative().contains(interiors.subtract(data.getTrainingNegative())));
|
||||
assertTrue(data.getTestNegative().contains(definedData));
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,95 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
import generic.test.AbstractGenericTest;
|
||||
|
||||
public class FunctionStartRFParamsTest extends AbstractGenericTest {
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void testBadParse1() {
|
||||
FunctionStartRFParams.parseIntegerCSV("");
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void testBadParse2() {
|
||||
FunctionStartRFParams.parseIntegerCSV(" ");
|
||||
}
|
||||
|
||||
@Test(expected = NumberFormatException.class)
|
||||
public void testBadParse3() {
|
||||
FunctionStartRFParams.parseIntegerCSV("1,,2");
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void testBadParse4() {
|
||||
FunctionStartRFParams.parseIntegerCSV("-1");
|
||||
}
|
||||
|
||||
@Test(expected = NumberFormatException.class)
|
||||
public void testBadParse5() {
|
||||
FunctionStartRFParams.parseIntegerCSV("--1");
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void testBadParse6() {
|
||||
FunctionStartRFParams.parseIntegerCSV(",");
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void testBadParse7() {
|
||||
FunctionStartRFParams.parseIntegerCSV("1,");
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void testBadParse8() {
|
||||
FunctionStartRFParams.parseIntegerCSV("1,2,3,");
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void testBadParse9() {
|
||||
FunctionStartRFParams.parseIntegerCSV(",1,2,3,");
|
||||
}
|
||||
|
||||
@Test(expected = NumberFormatException.class)
|
||||
public void testBadParse10() {
|
||||
FunctionStartRFParams.parseIntegerCSV("1,0xabcdv,3");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBasicValidParses() {
|
||||
List<Integer> results = FunctionStartRFParams.parseIntegerCSV("12345678");
|
||||
assertEquals(1, results.size());
|
||||
assertEquals(Integer.valueOf(12345678), results.get(0));
|
||||
results = FunctionStartRFParams.parseIntegerCSV("0x1,2 , 0x3, 4");
|
||||
assertEquals(4, results.size());
|
||||
assertEquals(Integer.valueOf(1), results.get(0));
|
||||
assertEquals(Integer.valueOf(2), results.get(1));
|
||||
assertEquals(Integer.valueOf(3), results.get(2));
|
||||
assertEquals(Integer.valueOf(4), results.get(3));
|
||||
results = FunctionStartRFParams.parseIntegerCSV("4,3,4,3");
|
||||
assertEquals(2, results.size());
|
||||
assertEquals(Integer.valueOf(3), results.get(0));
|
||||
assertEquals(Integer.valueOf(4), results.get(1));
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,94 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
import generic.test.AbstractGenericTest;
|
||||
import ghidra.program.model.address.AddressSet;
|
||||
import ghidra.program.model.address.TestAddress;
|
||||
import ghidra.util.exception.CancelledException;
|
||||
import ghidra.util.task.TaskMonitor;
|
||||
|
||||
public class RandomSubsetTest extends AbstractGenericTest {
|
||||
|
||||
@Test
|
||||
public void testGenerateTrivialSubsets() {
|
||||
List<Long> empty = RandomSubsetUtils.generateRandomIntegerSubset(10, 0);
|
||||
assertEquals(0, empty.size());
|
||||
empty = RandomSubsetUtils.generateRandomIntegerSubset(0, 0);
|
||||
assertEquals(0, empty.size());
|
||||
List<Long> complete = RandomSubsetUtils.generateRandomIntegerSubset(1000000, 1000000);
|
||||
Collections.sort(complete);
|
||||
Iterator<Long> iter = complete.iterator();
|
||||
long current = 0;
|
||||
while (iter.hasNext()) {
|
||||
long elem = iter.next();
|
||||
assertEquals(current++, elem);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBasicRandomSubsetOfAddresses() throws CancelledException {
|
||||
AddressSet addrs = new AddressSet();
|
||||
for (long i = 0; i < 10000; ++i) {
|
||||
addrs.add(new TestAddress(i));
|
||||
}
|
||||
AddressSet rand = RandomSubsetUtils.randomSubset(addrs, 9998, TaskMonitor.DUMMY);
|
||||
assertEquals(9998, rand.getNumAddresses());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSwap() {
|
||||
Map<Long, Long> permuted = new HashMap<>();
|
||||
assertTrue(permuted.isEmpty());
|
||||
//should do nothing
|
||||
RandomSubsetUtils.swap(permuted, 1, 1);
|
||||
assertTrue(permuted.isEmpty());
|
||||
permuted.put(0l, 5l);
|
||||
permuted.put(1l, 10l);
|
||||
RandomSubsetUtils.swap(permuted, 0, 1);
|
||||
assertEquals(2, permuted.size());
|
||||
assertEquals(Long.valueOf(5), permuted.get(1l));
|
||||
assertEquals(Long.valueOf(10), permuted.get(0l));
|
||||
RandomSubsetUtils.swap(permuted, 100l, 200L);
|
||||
assertEquals(4, permuted.size());
|
||||
assertEquals(Long.valueOf(100), permuted.get(200l));
|
||||
assertEquals(Long.valueOf(200), permuted.get(100l));
|
||||
}
|
||||
|
||||
/**
|
||||
@Test
|
||||
public void timingTest() throws CancelledException {
|
||||
AddressSet big = new AddressSet(new TestAddress(0), new TestAddress(999999));
|
||||
|
||||
long start = System.nanoTime();
|
||||
List<Long> complete = RandomSubset.generateRandomIntegerSubset(1000000, 500000);
|
||||
long end = System.nanoTime();
|
||||
Msg.info(this, "choosing random subset of integers: " +
|
||||
(end - start) / RandomForestTrainingTask.NANOSECONDS_PER_SECOND);
|
||||
start = System.nanoTime();
|
||||
AddressSet random = RandomSubset.randomSubset(big, 500000, TaskMonitor.DUMMY);
|
||||
end = System.nanoTime();
|
||||
Msg.info(this, "choosing random subset of addresses: " +
|
||||
(end - start) / RandomForestTrainingTask.NANOSECONDS_PER_SECOND);
|
||||
}*/
|
||||
|
||||
}
|
|
@ -0,0 +1,68 @@
|
|||
/* ###
|
||||
* IP: GHIDRA
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package ghidra.machinelearning.functionfinding;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import generic.test.AbstractGenericTest;
|
||||
import ghidra.program.model.address.AddressSet;
|
||||
import ghidra.program.model.address.TestAddress;
|
||||
import ghidra.util.exception.CancelledException;
|
||||
import ghidra.util.task.TaskMonitor;
|
||||
|
||||
public class TrainingAndTestDataTest extends AbstractGenericTest {
|
||||
|
||||
private TrainingAndTestData data;
|
||||
private AddressSet originalPos;
|
||||
private AddressSet originalNeg;
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
AddressSet testPositive = new AddressSet();
|
||||
testPositive.add(new TestAddress(0), new TestAddress(100));
|
||||
originalPos = new AddressSet(testPositive);
|
||||
AddressSet testNegative = new AddressSet();
|
||||
testNegative.add(new TestAddress(500), new TestAddress(1000));
|
||||
originalNeg = new AddressSet(testNegative);
|
||||
data =
|
||||
new TrainingAndTestData(new AddressSet(), new AddressSet(), testPositive, testNegative);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void reduceTest1() throws CancelledException {
|
||||
data.reduceTestSetSize(1001, TaskMonitor.DUMMY);
|
||||
assertTrue(data.getTestPositive().hasSameAddresses(originalPos));
|
||||
assertTrue(data.getTestNegative().hasSameAddresses(originalNeg));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void reduceTest2() throws CancelledException {
|
||||
data.reduceTestSetSize(250, TaskMonitor.DUMMY);
|
||||
assertTrue(data.getTestPositive().hasSameAddresses(originalPos));
|
||||
assertEquals(250, data.getTestNegative().getNumAddresses());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void reduceTest3() throws CancelledException {
|
||||
data.reduceTestSetSize(10, TaskMonitor.DUMMY);
|
||||
assertEquals(10, data.getTestPositive().getNumAddresses());
|
||||
assertEquals(10, data.getTestNegative().getNumAddresses());
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in a new issue