os: make MkdirAll support volume names

MkdirAll fails to create directories under root paths using volume
names (e.g. //?/Volume{GUID}/foo). This is because fixRootDirectory
only handle extended length paths using drive letters (e.g. //?/C:/foo).

This CL fixes that issue by also detecting volume names without path
separator.

Updates #22230
Fixes #39785

Change-Id: I813fdc0b968ce71a4297f69245b935558e6cd789
Reviewed-on: https://go-review.googlesource.com/c/go/+/517015
Run-TryBot: Quim Muntal <quimmuntal@gmail.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Bryan Mills <bcmills@google.com>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
This commit is contained in:
qmuntal 2023-08-08 15:31:43 +02:00 committed by Quim Muntal
parent f617a6c8bf
commit cd589c8a73
7 changed files with 118 additions and 64 deletions

View file

@ -398,6 +398,7 @@ type FILE_ID_BOTH_DIR_INFO struct {
}
//sys GetVolumeInformationByHandle(file syscall.Handle, volumeNameBuffer *uint16, volumeNameSize uint32, volumeNameSerialNumber *uint32, maximumComponentLength *uint32, fileSystemFlags *uint32, fileSystemNameBuffer *uint16, fileSystemNameSize uint32) (err error) = GetVolumeInformationByHandleW
//sys GetVolumeNameForVolumeMountPoint(volumeMountPoint *uint16, volumeName *uint16, bufferlength uint32) (err error) = GetVolumeNameForVolumeMountPointW
//sys RtlLookupFunctionEntry(pc uintptr, baseAddress *uintptr, table *byte) (ret uintptr) = kernel32.RtlLookupFunctionEntry
//sys RtlVirtualUnwind(handlerType uint32, baseAddress uintptr, pc uintptr, entry uintptr, ctxt uintptr, data *uintptr, frame *uintptr, ctxptrs *byte) (ret uintptr) = kernel32.RtlVirtualUnwind

View file

@ -45,46 +45,47 @@ var (
moduserenv = syscall.NewLazyDLL(sysdll.Add("userenv.dll"))
modws2_32 = syscall.NewLazyDLL(sysdll.Add("ws2_32.dll"))
procAdjustTokenPrivileges = modadvapi32.NewProc("AdjustTokenPrivileges")
procDuplicateTokenEx = modadvapi32.NewProc("DuplicateTokenEx")
procImpersonateSelf = modadvapi32.NewProc("ImpersonateSelf")
procLookupPrivilegeValueW = modadvapi32.NewProc("LookupPrivilegeValueW")
procOpenSCManagerW = modadvapi32.NewProc("OpenSCManagerW")
procOpenServiceW = modadvapi32.NewProc("OpenServiceW")
procOpenThreadToken = modadvapi32.NewProc("OpenThreadToken")
procQueryServiceStatus = modadvapi32.NewProc("QueryServiceStatus")
procRevertToSelf = modadvapi32.NewProc("RevertToSelf")
procSetTokenInformation = modadvapi32.NewProc("SetTokenInformation")
procSystemFunction036 = modadvapi32.NewProc("SystemFunction036")
procGetAdaptersAddresses = modiphlpapi.NewProc("GetAdaptersAddresses")
procCreateEventW = modkernel32.NewProc("CreateEventW")
procGetACP = modkernel32.NewProc("GetACP")
procGetComputerNameExW = modkernel32.NewProc("GetComputerNameExW")
procGetConsoleCP = modkernel32.NewProc("GetConsoleCP")
procGetCurrentThread = modkernel32.NewProc("GetCurrentThread")
procGetFileInformationByHandleEx = modkernel32.NewProc("GetFileInformationByHandleEx")
procGetFinalPathNameByHandleW = modkernel32.NewProc("GetFinalPathNameByHandleW")
procGetModuleFileNameW = modkernel32.NewProc("GetModuleFileNameW")
procGetTempPath2W = modkernel32.NewProc("GetTempPath2W")
procGetVolumeInformationByHandleW = modkernel32.NewProc("GetVolumeInformationByHandleW")
procLockFileEx = modkernel32.NewProc("LockFileEx")
procModule32FirstW = modkernel32.NewProc("Module32FirstW")
procModule32NextW = modkernel32.NewProc("Module32NextW")
procMoveFileExW = modkernel32.NewProc("MoveFileExW")
procMultiByteToWideChar = modkernel32.NewProc("MultiByteToWideChar")
procRtlLookupFunctionEntry = modkernel32.NewProc("RtlLookupFunctionEntry")
procRtlVirtualUnwind = modkernel32.NewProc("RtlVirtualUnwind")
procSetFileInformationByHandle = modkernel32.NewProc("SetFileInformationByHandle")
procUnlockFileEx = modkernel32.NewProc("UnlockFileEx")
procVirtualQuery = modkernel32.NewProc("VirtualQuery")
procNetShareAdd = modnetapi32.NewProc("NetShareAdd")
procNetShareDel = modnetapi32.NewProc("NetShareDel")
procNetUserGetLocalGroups = modnetapi32.NewProc("NetUserGetLocalGroups")
procGetProcessMemoryInfo = modpsapi.NewProc("GetProcessMemoryInfo")
procCreateEnvironmentBlock = moduserenv.NewProc("CreateEnvironmentBlock")
procDestroyEnvironmentBlock = moduserenv.NewProc("DestroyEnvironmentBlock")
procGetProfilesDirectoryW = moduserenv.NewProc("GetProfilesDirectoryW")
procWSASocketW = modws2_32.NewProc("WSASocketW")
procAdjustTokenPrivileges = modadvapi32.NewProc("AdjustTokenPrivileges")
procDuplicateTokenEx = modadvapi32.NewProc("DuplicateTokenEx")
procImpersonateSelf = modadvapi32.NewProc("ImpersonateSelf")
procLookupPrivilegeValueW = modadvapi32.NewProc("LookupPrivilegeValueW")
procOpenSCManagerW = modadvapi32.NewProc("OpenSCManagerW")
procOpenServiceW = modadvapi32.NewProc("OpenServiceW")
procOpenThreadToken = modadvapi32.NewProc("OpenThreadToken")
procQueryServiceStatus = modadvapi32.NewProc("QueryServiceStatus")
procRevertToSelf = modadvapi32.NewProc("RevertToSelf")
procSetTokenInformation = modadvapi32.NewProc("SetTokenInformation")
procSystemFunction036 = modadvapi32.NewProc("SystemFunction036")
procGetAdaptersAddresses = modiphlpapi.NewProc("GetAdaptersAddresses")
procCreateEventW = modkernel32.NewProc("CreateEventW")
procGetACP = modkernel32.NewProc("GetACP")
procGetComputerNameExW = modkernel32.NewProc("GetComputerNameExW")
procGetConsoleCP = modkernel32.NewProc("GetConsoleCP")
procGetCurrentThread = modkernel32.NewProc("GetCurrentThread")
procGetFileInformationByHandleEx = modkernel32.NewProc("GetFileInformationByHandleEx")
procGetFinalPathNameByHandleW = modkernel32.NewProc("GetFinalPathNameByHandleW")
procGetModuleFileNameW = modkernel32.NewProc("GetModuleFileNameW")
procGetTempPath2W = modkernel32.NewProc("GetTempPath2W")
procGetVolumeInformationByHandleW = modkernel32.NewProc("GetVolumeInformationByHandleW")
procGetVolumeNameForVolumeMountPointW = modkernel32.NewProc("GetVolumeNameForVolumeMountPointW")
procLockFileEx = modkernel32.NewProc("LockFileEx")
procModule32FirstW = modkernel32.NewProc("Module32FirstW")
procModule32NextW = modkernel32.NewProc("Module32NextW")
procMoveFileExW = modkernel32.NewProc("MoveFileExW")
procMultiByteToWideChar = modkernel32.NewProc("MultiByteToWideChar")
procRtlLookupFunctionEntry = modkernel32.NewProc("RtlLookupFunctionEntry")
procRtlVirtualUnwind = modkernel32.NewProc("RtlVirtualUnwind")
procSetFileInformationByHandle = modkernel32.NewProc("SetFileInformationByHandle")
procUnlockFileEx = modkernel32.NewProc("UnlockFileEx")
procVirtualQuery = modkernel32.NewProc("VirtualQuery")
procNetShareAdd = modnetapi32.NewProc("NetShareAdd")
procNetShareDel = modnetapi32.NewProc("NetShareDel")
procNetUserGetLocalGroups = modnetapi32.NewProc("NetUserGetLocalGroups")
procGetProcessMemoryInfo = modpsapi.NewProc("GetProcessMemoryInfo")
procCreateEnvironmentBlock = moduserenv.NewProc("CreateEnvironmentBlock")
procDestroyEnvironmentBlock = moduserenv.NewProc("DestroyEnvironmentBlock")
procGetProfilesDirectoryW = moduserenv.NewProc("GetProfilesDirectoryW")
procWSASocketW = modws2_32.NewProc("WSASocketW")
)
func adjustTokenPrivileges(token syscall.Token, disableAllPrivileges bool, newstate *TOKEN_PRIVILEGES, buflen uint32, prevstate *TOKEN_PRIVILEGES, returnlen *uint32) (ret uint32, err error) {
@ -279,6 +280,14 @@ func GetVolumeInformationByHandle(file syscall.Handle, volumeNameBuffer *uint16,
return
}
func GetVolumeNameForVolumeMountPoint(volumeMountPoint *uint16, volumeName *uint16, bufferlength uint32) (err error) {
r1, _, e1 := syscall.Syscall(procGetVolumeNameForVolumeMountPointW.Addr(), 3, uintptr(unsafe.Pointer(volumeMountPoint)), uintptr(unsafe.Pointer(volumeName)), uintptr(bufferlength))
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func LockFileEx(file syscall.Handle, flags uint32, reserved uint32, bytesLow uint32, bytesHigh uint32, overlapped *syscall.Overlapped) (err error) {
r1, _, e1 := syscall.Syscall6(procLockFileEx.Addr(), 6, uintptr(file), uintptr(flags), uintptr(reserved), uintptr(bytesLow), uintptr(bytesHigh), uintptr(unsafe.Pointer(overlapped)))
if r1 == 0 {

View file

@ -26,19 +26,25 @@ func MkdirAll(path string, perm FileMode) error {
}
// Slow path: make sure parent exists and then call Mkdir for path.
i := len(path)
for i > 0 && IsPathSeparator(path[i-1]) { // Skip trailing path separator.
// Extract the parent folder from path by first removing any trailing
// path separator and then scanning backward until finding a path
// separator or reaching the beginning of the string.
i := len(path) - 1
for i >= 0 && IsPathSeparator(path[i]) {
i--
}
j := i
for j > 0 && !IsPathSeparator(path[j-1]) { // Scan backward over element.
j--
for i >= 0 && !IsPathSeparator(path[i]) {
i--
}
if i < 0 {
i = 0
}
if j > 1 {
// Create parent.
err = MkdirAll(fixRootDirectory(path[:j-1]), perm)
// If there is a parent directory, and it is not the volume name,
// recurse to ensure parent directory exists.
if parent := path[:i]; len(parent) > len(volumeName(path)) {
err = MkdirAll(parent, perm)
if err != nil {
return err
}

View file

@ -14,6 +14,6 @@ func IsPathSeparator(c uint8) bool {
return PathSeparator == c
}
func fixRootDirectory(p string) string {
return p
func volumeName(p string) string {
return ""
}

View file

@ -70,6 +70,6 @@ func splitPath(path string) (string, string) {
return dirname, basename
}
func fixRootDirectory(p string) string {
return p
func volumeName(p string) string {
return ""
}

View file

@ -214,14 +214,3 @@ func fixLongPath(path string) string {
}
return string(pathbuf[:w])
}
// fixRootDirectory fixes a reference to a drive's root directory to
// have the required trailing slash.
func fixRootDirectory(p string) string {
if len(p) == len(`\\?\c:`) {
if IsPathSeparator(p[0]) && IsPathSeparator(p[1]) && p[2] == '?' && IsPathSeparator(p[3]) && p[5] == ':' {
return p + `\`
}
}
return p
}

View file

@ -5,7 +5,11 @@
package os_test
import (
"fmt"
"internal/syscall/windows"
"internal/testenv"
"os"
"path/filepath"
"strings"
"syscall"
"testing"
@ -106,3 +110,48 @@ func TestOpenRootSlash(t *testing.T) {
dir.Close()
}
}
func testMkdirAllAtRoot(t *testing.T, root string) {
// Create a unique-enough directory name in root.
base := fmt.Sprintf("%s-%d", t.Name(), os.Getpid())
path := filepath.Join(root, base)
if err := os.MkdirAll(path, 0777); err != nil {
t.Fatalf("MkdirAll(%q) failed: %v", path, err)
}
// Clean up
if err := os.RemoveAll(path); err != nil {
t.Fatal(err)
}
}
func TestMkdirAllExtendedLengthAtRoot(t *testing.T) {
if testenv.Builder() == "" {
t.Skipf("skipping non-hermetic test outside of Go builders")
}
const prefix = `\\?\`
vol := filepath.VolumeName(t.TempDir()) + `\`
if len(vol) < 4 || vol[:4] != prefix {
vol = prefix + vol
}
testMkdirAllAtRoot(t, vol)
}
func TestMkdirAllVolumeNameAtRoot(t *testing.T) {
if testenv.Builder() == "" {
t.Skipf("skipping non-hermetic test outside of Go builders")
}
vol, err := syscall.UTF16PtrFromString(filepath.VolumeName(t.TempDir()) + `\`)
if err != nil {
t.Fatal(err)
}
const maxVolNameLen = 50
var buf [maxVolNameLen]uint16
err = windows.GetVolumeNameForVolumeMountPoint(vol, &buf[0], maxVolNameLen)
if err != nil {
t.Fatal(err)
}
volName := syscall.UTF16ToString(buf[:])
testMkdirAllAtRoot(t, volName)
}