diff --git a/src/cmd/compile/internal/typecheck/builtin.go b/src/cmd/compile/internal/typecheck/builtin.go index 875af37215e..e452f23ff05 100644 --- a/src/cmd/compile/internal/typecheck/builtin.go +++ b/src/cmd/compile/internal/typecheck/builtin.go @@ -205,6 +205,8 @@ var runtimeDecls = [...]struct { {"libfuzzerTraceConstCmp2", funcTag, 146}, {"libfuzzerTraceConstCmp4", funcTag, 147}, {"libfuzzerTraceConstCmp8", funcTag, 148}, + {"libfuzzerHookStrCmp", funcTag, 149}, + {"libfuzzerHookEqualFold", funcTag, 149}, {"x86HasPOPCNT", varTag, 6}, {"x86HasSSE41", varTag, 6}, {"x86HasFMA", varTag, 6}, @@ -228,7 +230,7 @@ func params(tlist ...*types.Type) []*types.Field { } func runtimeTypes() []*types.Type { - var typs [149]*types.Type + var typs [150]*types.Type typs[0] = types.ByteType typs[1] = types.NewPtr(typs[0]) typs[2] = types.Types[types.TANY] @@ -378,5 +380,6 @@ func runtimeTypes() []*types.Type { typs[146] = newSig(params(typs[60], typs[60]), nil) typs[147] = newSig(params(typs[62], typs[62]), nil) typs[148] = newSig(params(typs[24], typs[24]), nil) + typs[149] = newSig(params(typs[28], typs[28], typs[15]), nil) return typs[:] } diff --git a/src/cmd/compile/internal/typecheck/builtin/runtime.go b/src/cmd/compile/internal/typecheck/builtin/runtime.go index dd19eefa298..97b8318f7ff 100644 --- a/src/cmd/compile/internal/typecheck/builtin/runtime.go +++ b/src/cmd/compile/internal/typecheck/builtin/runtime.go @@ -267,6 +267,8 @@ func libfuzzerTraceConstCmp1(uint8, uint8) func libfuzzerTraceConstCmp2(uint16, uint16) func libfuzzerTraceConstCmp4(uint32, uint32) func libfuzzerTraceConstCmp8(uint64, uint64) +func libfuzzerHookStrCmp(string, string, int) +func libfuzzerHookEqualFold(string, string, int) // architecture variants var x86HasPOPCNT bool diff --git a/src/cmd/compile/internal/walk/compare.go b/src/cmd/compile/internal/walk/compare.go index d271698c519..b02cf22acff 100644 --- a/src/cmd/compile/internal/walk/compare.go +++ b/src/cmd/compile/internal/walk/compare.go @@ -5,7 +5,11 @@ package walk import ( + "encoding/binary" + "fmt" "go/constant" + "hash/fnv" + "io" "cmd/compile/internal/base" "cmd/compile/internal/compare" @@ -16,6 +20,22 @@ import ( "cmd/compile/internal/types" ) +func fakePC(n ir.Node) ir.Node { + // In order to get deterministic IDs, we include the package path, absolute filename, line number, column number + // in the calculation of the fakePC for the IR node. + hash := fnv.New32() + // We ignore the errors here because the `io.Writer` in the `hash.Hash` interface never returns an error. + io.WriteString(hash, base.Ctxt.Pkgpath) + io.WriteString(hash, base.Ctxt.PosTable.Pos(n.Pos()).AbsFilename()) + binary.Write(hash, binary.LittleEndian, int64(n.Pos().Line())) + binary.Write(hash, binary.LittleEndian, int64(n.Pos().Col())) + // We also include the string representation of the node to distinguish autogenerated expression since + // those get the same `src.XPos` + io.WriteString(hash, fmt.Sprintf("%v", n)) + + return ir.NewInt(int64(hash.Sum32())) +} + // The result of walkCompare MUST be assigned back to n, e.g. // // n.Left = walkCompare(n.Left, init) @@ -290,6 +310,15 @@ func walkCompareInterface(n *ir.BinaryExpr, init *ir.Nodes) ir.Node { } func walkCompareString(n *ir.BinaryExpr, init *ir.Nodes) ir.Node { + if base.Debug.Libfuzzer != 0 { + if !ir.IsConst(n.X, constant.String) || !ir.IsConst(n.Y, constant.String) { + fn := "libfuzzerHookStrCmp" + n.X = cheapExpr(n.X, init) + n.Y = cheapExpr(n.Y, init) + paramType := types.Types[types.TSTRING] + init.Append(mkcall(fn, nil, init, tracecmpArg(n.X, paramType, init), tracecmpArg(n.Y, paramType, init), fakePC(n))) + } + } // Rewrite comparisons to short constant strings as length+byte-wise comparisons. var cs, ncs ir.Node // const string, non-const string switch { diff --git a/src/cmd/compile/internal/walk/expr.go b/src/cmd/compile/internal/walk/expr.go index 9aabf916795..803a07ae73e 100644 --- a/src/cmd/compile/internal/walk/expr.go +++ b/src/cmd/compile/internal/walk/expr.go @@ -496,6 +496,16 @@ func walkAddString(n *ir.AddStringExpr, init *ir.Nodes) ir.Node { return r1 } +type hookInfo struct { + paramType types.Kind + argsNum int + runtimeFunc string +} + +var hooks = map[string]hookInfo{ + "strings.EqualFold": {paramType: types.TSTRING, argsNum: 2, runtimeFunc: "libfuzzerHookEqualFold"}, +} + // walkCall walks an OCALLFUNC or OCALLINTER node. func walkCall(n *ir.CallExpr, init *ir.Nodes) ir.Node { if n.Op() == ir.OCALLMETH { @@ -591,6 +601,20 @@ func walkCall1(n *ir.CallExpr, init *ir.Nodes) { } n.Args = args + funSym := n.X.Sym() + if base.Debug.Libfuzzer != 0 && funSym != nil { + if hook, found := hooks[funSym.Pkg.Path+"."+funSym.Name]; found { + if len(args) != hook.argsNum { + panic(fmt.Sprintf("%s.%s expects %d arguments, but received %d", funSym.Pkg.Path, funSym.Name, hook.argsNum, len(args))) + } + var hookArgs []ir.Node + for _, arg := range args { + hookArgs = append(hookArgs, tracecmpArg(arg, types.Types[hook.paramType], init)) + } + hookArgs = append(hookArgs, fakePC(n)) + init.Append(mkcall(hook.runtimeFunc, nil, init, hookArgs...)) + } + } } // walkDivMod walks an ODIV or OMOD node. diff --git a/src/cmd/internal/goobj/builtinlist.go b/src/cmd/internal/goobj/builtinlist.go index 608c0d72223..2d13222984d 100644 --- a/src/cmd/internal/goobj/builtinlist.go +++ b/src/cmd/internal/goobj/builtinlist.go @@ -195,6 +195,8 @@ var builtins = [...]struct { {"runtime.libfuzzerTraceConstCmp2", 1}, {"runtime.libfuzzerTraceConstCmp4", 1}, {"runtime.libfuzzerTraceConstCmp8", 1}, + {"runtime.libfuzzerHookStrCmp", 1}, + {"runtime.libfuzzerHookEqualFold", 1}, {"runtime.x86HasPOPCNT", 0}, {"runtime.x86HasSSE41", 0}, {"runtime.x86HasFMA", 0}, diff --git a/src/internal/fuzz/trace.go b/src/internal/fuzz/trace.go index cab0838fab4..3aa684b49cf 100644 --- a/src/internal/fuzz/trace.go +++ b/src/internal/fuzz/trace.go @@ -18,6 +18,9 @@ import _ "unsafe" // for go:linkname //go:linkname libfuzzerTraceConstCmp4 runtime.libfuzzerTraceConstCmp4 //go:linkname libfuzzerTraceConstCmp8 runtime.libfuzzerTraceConstCmp8 +//go:linkname libfuzzerHookStrCmp runtime.libfuzzerHookStrCmp +//go:linkname libfuzzerHookEqualFold runtime.libfuzzerHookEqualFold + func libfuzzerTraceCmp1(arg0, arg1 uint8) {} func libfuzzerTraceCmp2(arg0, arg1 uint16) {} func libfuzzerTraceCmp4(arg0, arg1 uint32) {} @@ -27,3 +30,6 @@ func libfuzzerTraceConstCmp1(arg0, arg1 uint8) {} func libfuzzerTraceConstCmp2(arg0, arg1 uint16) {} func libfuzzerTraceConstCmp4(arg0, arg1 uint32) {} func libfuzzerTraceConstCmp8(arg0, arg1 uint64) {} + +func libfuzzerHookStrCmp(arg0, arg1 string, fakePC int) {} +func libfuzzerHookEqualFold(arg0, arg1 string, fakePC int) {} diff --git a/src/runtime/libfuzzer.go b/src/runtime/libfuzzer.go index 920ac575f58..c136eaf5fef 100644 --- a/src/runtime/libfuzzer.go +++ b/src/runtime/libfuzzer.go @@ -9,6 +9,7 @@ package runtime import "unsafe" func libfuzzerCallWithTwoByteBuffers(fn, start, end *byte) +func libfuzzerCall4(fn *byte, fakePC uintptr, s1, s2 unsafe.Pointer, result uintptr) func libfuzzerCall(fn *byte, arg0, arg1 uintptr) func libfuzzerTraceCmp1(arg0, arg1 uint8) { @@ -59,6 +60,31 @@ func init() { libfuzzerCallWithTwoByteBuffers(&__sanitizer_cov_pcs_init, &pcTables[0], &pcTables[size-1]) } +// We call libFuzzer's __sanitizer_weak_hook_strcmp function +// which takes the following four arguments: +// 1- caller_pc: location of string comparison call site +// 2- s1: first string used in the comparison +// 3- s2: second string used in the comparison +// 4- result: an integer representing the comparison result. Libfuzzer only distinguishes between two cases: +// - 0 means that the strings are equal and the comparison will be ignored by libfuzzer. +// - Any other value means that strings are not equal and libfuzzer takes the comparison into consideration. +// Here, we pass 1 when the strings are not equal. +func libfuzzerHookStrCmp(s1, s2 string, fakePC int) { + if s1 != s2 { + libfuzzerCall4(&__sanitizer_weak_hook_strcmp, uintptr(fakePC), cstring(s1), cstring(s2), uintptr(1)) + } + // if s1 == s2 we could call the hook with a last argument of 0 but this is unnecessary since this case will be then + // ignored by libfuzzer +} + +// This function has now the same implementation as libfuzzerHookStrCmp because we lack better checks +// for case-insensitive string equality in the runtime package. +func libfuzzerHookEqualFold(s1, s2 string, fakePC int) { + if s1 != s2 { + libfuzzerCall4(&__sanitizer_weak_hook_strcmp, uintptr(fakePC), cstring(s1), cstring(s2), uintptr(1)) + } +} + //go:linkname __sanitizer_cov_trace_cmp1 __sanitizer_cov_trace_cmp1 //go:cgo_import_static __sanitizer_cov_trace_cmp1 var __sanitizer_cov_trace_cmp1 byte @@ -106,3 +132,7 @@ var __stop___sancov_cntrs byte //go:linkname __sanitizer_cov_pcs_init __sanitizer_cov_pcs_init //go:cgo_import_static __sanitizer_cov_pcs_init var __sanitizer_cov_pcs_init byte + +//go:linkname __sanitizer_weak_hook_strcmp __sanitizer_weak_hook_strcmp +//go:cgo_import_static __sanitizer_weak_hook_strcmp +var __sanitizer_weak_hook_strcmp byte diff --git a/src/runtime/libfuzzer_amd64.s b/src/runtime/libfuzzer_amd64.s index 5ea77f59de2..032821fbbcf 100644 --- a/src/runtime/libfuzzer_amd64.s +++ b/src/runtime/libfuzzer_amd64.s @@ -13,12 +13,41 @@ #ifdef GOOS_windows #define RARG0 CX #define RARG1 DX +#define RARG0 R8 +#define RARG1 R9 #else #define RARG0 DI #define RARG1 SI +#define RARG2 DX +#define RARG3 CX #endif -// void runtime·libfuzzerCall(fn, arg0, arg1 uintptr) +// void runtime·libfuzzerCall4(fn, hookId int, s1, s2 unsafe.Pointer, result uintptr) +// Calls C function fn from libFuzzer and passes 4 arguments to it. +TEXT runtime·libfuzzerCall4(SB), NOSPLIT, $0-40 + MOVQ fn+0(FP), AX + MOVQ hookId+8(FP), RARG0 + MOVQ s1+16(FP), RARG1 + MOVQ s2+24(FP), RARG2 + MOVQ result+32(FP), RARG3 + + get_tls(R12) + MOVQ g(R12), R14 + MOVQ g_m(R14), R13 + + // Switch to g0 stack. + MOVQ SP, R12 // callee-saved, preserved across the CALL + MOVQ m_g0(R13), R10 + CMPQ R10, R14 + JE call // already on g0 + MOVQ (g_sched+gobuf_sp)(R10), SP +call: + ANDQ $~15, SP // alignment for gcc ABI + CALL AX + MOVQ R12, SP + RET + +// void runtime·libfuzzerCallTraceInit(fn, start, end *byte) // Calls C function fn from libFuzzer and passes 2 arguments to it. TEXT runtime·libfuzzerCall(SB), NOSPLIT, $0-24 MOVQ fn+0(FP), AX diff --git a/src/runtime/libfuzzer_arm64.s b/src/runtime/libfuzzer_arm64.s index b0146682a26..f9b67913e2a 100644 --- a/src/runtime/libfuzzer_arm64.s +++ b/src/runtime/libfuzzer_arm64.s @@ -9,12 +9,40 @@ // Based on race_arm64.s; see commentary there. +#define RARG0 R0 +#define RARG1 R1 +#define RARG2 R2 +#define RARG3 R3 + +// void runtime·libfuzzerCall4(fn, hookId int, s1, s2 unsafe.Pointer, result uintptr) +// Calls C function fn from libFuzzer and passes 4 arguments to it. +TEXT runtime·libfuzzerCall4(SB), NOSPLIT, $0-40 + MOVD fn+0(FP), R9 + MOVD hookId+8(FP), RARG0 + MOVD s1+16(FP), RARG1 + MOVD s2+24(FP), RARG2 + MOVD result+32(FP), RARG3 + + MOVD g_m(g), R10 + + // Switch to g0 stack. + MOVD RSP, R19 // callee-saved, preserved across the CALL + MOVD m_g0(R10), R11 + CMP R11, g + BEQ call // already on g0 + MOVD (g_sched+gobuf_sp)(R11), R12 + MOVD R12, RSP +call: + BL R9 + MOVD R19, RSP + RET + // func runtime·libfuzzerCall(fn, arg0, arg1 uintptr) // Calls C function fn from libFuzzer and passes 2 arguments to it. TEXT runtime·libfuzzerCall(SB), NOSPLIT, $0-24 MOVD fn+0(FP), R9 - MOVD arg0+8(FP), R0 - MOVD arg1+16(FP), R1 + MOVD arg0+8(FP), RARG0 + MOVD arg1+16(FP), RARG1 MOVD g_m(g), R10