os/user: lookup Linux users and groups via systemd userdb

Fetch usernames and groups via systemd userdb if available.
Otherwise fall back to parsing /etc/passwd, etc.

Fixes #38810

Co-authored-by: Michael Stapelberg <stapelberg@google.com>
This commit is contained in:
Ananth Bhaskararaman 2022-12-26 01:09:10 +05:30
parent bcd82125f8
commit 1a627cc9a1
No known key found for this signature in database
GPG key ID: 66A20525398E18D7
7 changed files with 1370 additions and 0 deletions

View file

@ -9,11 +9,13 @@ package user
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
"os"
"strconv"
"time"
)
func listGroupsFromReader(u *User, r io.Reader) ([]string, error) {
@ -99,6 +101,13 @@ func listGroupsFromReader(u *User, r io.Reader) ([]string, error) {
}
func listGroups(u *User) ([]string, error) {
if defaultUserdbClient.isUsable() {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
if ids, ok, err := defaultUserdbClient.lookupGroupIds(ctx, u.Username); ok {
return ids, err
}
}
f, err := os.Open(groupFile)
if err != nil {
return nil, err

View file

@ -9,11 +9,13 @@ package user
import (
"bufio"
"bytes"
"context"
"errors"
"io"
"os"
"strconv"
"strings"
"time"
)
// lineFunc returns a value, an error, or (nil, nil) to skip the row.
@ -198,6 +200,13 @@ func findUsername(name string, r io.Reader) (*User, error) {
}
func lookupGroup(groupname string) (*Group, error) {
if defaultUserdbClient.isUsable() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if g, ok, err := defaultUserdbClient.lookupGroup(ctx, groupname); ok {
return g, err
}
}
f, err := os.Open(groupFile)
if err != nil {
return nil, err
@ -207,6 +216,13 @@ func lookupGroup(groupname string) (*Group, error) {
}
func lookupGroupId(id string) (*Group, error) {
if defaultUserdbClient.isUsable() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if g, ok, err := defaultUserdbClient.lookupGroupId(ctx, id); ok {
return g, err
}
}
f, err := os.Open(groupFile)
if err != nil {
return nil, err
@ -216,6 +232,13 @@ func lookupGroupId(id string) (*Group, error) {
}
func lookupUser(username string) (*User, error) {
if defaultUserdbClient.isUsable() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if u, ok, err := defaultUserdbClient.lookupUser(ctx, username); ok {
return u, err
}
}
f, err := os.Open(userFile)
if err != nil {
return nil, err
@ -225,6 +248,13 @@ func lookupUser(username string) (*User, error) {
}
func lookupUserId(uid string) (*User, error) {
if defaultUserdbClient.isUsable() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if u, ok, err := defaultUserdbClient.lookupUserId(ctx, uid); ok {
return u, err
}
}
f, err := os.Open(userFile)
if err != nil {
return nil, err

View file

@ -11,6 +11,10 @@ One is written in pure Go and parses /etc/passwd and /etc/group. The other
is cgo-based and relies on the standard C library (libc) routines such as
getpwuid_r, getgrnam_r, and getgrouplist.
For Linux, the pure Go implementation queries the systemd-userdb service first.
If the service is not available, it falls back to parsing /etc/passwd and
/etc/group.
When cgo is available, and the required routines are implemented in libc
for a particular platform, cgo-based (libc-backed) code is used.
This can be overridden by using osusergo build tag, which enforces

View file

@ -0,0 +1,22 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package user
// userdbClient queries the io.systemd.UserDatabase VARLINK interface provided by
// systemd-userdbd.service(8) on Linux for obtaining full user/group details
// even when cgo is not available.
// VARLINK protocol: https://varlink.org
// Systemd userdb VARLINK interface https://systemd.io/USER_GROUP_API
// dir contains multiple varlink service sockets implementing the userdb interface.
type userdbClient struct {
dir string
}
// IsUsable checks if the client can be used to make queries.
func (cl userdbClient) isUsable() bool {
return len(cl.dir) != 0
}
var defaultUserdbClient userdbClient

View file

@ -0,0 +1,772 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux
package user
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"io/fs"
"os"
"strconv"
"strings"
"sync"
"syscall"
"unicode/utf16"
"unicode/utf8"
)
const (
// Well known multiplexer service.
svcMultiplexer = "io.systemd.Multiplexer"
userdbNamespace = "io.systemd.UserDatabase"
// io.systemd.UserDatabase VARLINK interface methods.
mGetGroupRecord = userdbNamespace + ".GetGroupRecord"
mGetUserRecord = userdbNamespace + ".GetUserRecord"
mGetMemberships = userdbNamespace + ".GetMemberships"
// io.systemd.UserDatabase VARLINK interface errors.
errNoRecordFound = userdbNamespace + ".NoRecordFound"
errServiceNotAvailable = userdbNamespace + ".ServiceNotAvailable"
)
func init() {
defaultUserdbClient.dir = "/run/systemd/userdb"
}
// userdbCall represents a VARLINK service call sent to systemd-userdb.
// method is the VARLINK method to call.
// parameters are the VARLINK parameters to pass.
// more indicates if more responses are expected.
// fastest indicates if only the fastest response should be returned.
type userdbCall struct {
method string
parameters callParameters
more bool
fastest bool
}
func (u userdbCall) marshalJSON(service string) ([]byte, error) {
params, err := u.parameters.marshalJSON(service)
if err != nil {
return nil, err
}
var data bytes.Buffer
data.WriteString(`{"method":"`)
data.WriteString(u.method)
data.WriteString(`","parameters":`)
data.Write(params)
if u.more {
data.WriteString(`,"more":true`)
}
data.WriteString(`}`)
return data.Bytes(), nil
}
type callParameters struct {
uid *int64
userName string
gid *int64
groupName string
}
func (c callParameters) marshalJSON(service string) ([]byte, error) {
var data bytes.Buffer
data.WriteString(`{"service":"`)
data.WriteString(service)
data.WriteString(`"`)
if c.uid != nil {
data.WriteString(`,"uid":`)
data.WriteString(strconv.FormatInt(*c.uid, 10))
}
if c.userName != "" {
data.WriteString(`,"userName":"`)
data.WriteString(c.userName)
data.WriteString(`"`)
}
if c.gid != nil {
data.WriteString(`,"gid":`)
data.WriteString(strconv.FormatInt(*c.gid, 10))
}
if c.groupName != "" {
data.WriteString(`,"groupName":"`)
data.WriteString(c.groupName)
data.WriteString(`"`)
}
data.WriteString(`}`)
return data.Bytes(), nil
}
type userdbReply struct {
continues bool
errorStr string
}
func (u *userdbReply) unmarshalJSON(data []byte) error {
var (
kContinues = []byte(`"continues"`)
kError = []byte(`"error"`)
)
if i := bytes.Index(data, kContinues); i != -1 {
continues, err := parseJSONBoolean(data[i+len(kContinues):])
if err != nil {
return err
}
u.continues = continues
}
if i := bytes.Index(data, kError); i != -1 {
errStr, err := parseJSONString(data[i+len(kError):])
if err != nil {
return err
}
u.errorStr = errStr
}
return nil
}
// response is the parsed reply from a method call to systemd-userdb.
// data is one or more VARLINK response parameters separated by 0.
// handled indicates if the call was handled by systemd-userdb.
// err is any error encountered.
type response struct {
data []byte
handled bool
err error
}
// querySocket calls the io.systemd.UserDatabase VARLINK interface at sock with request.
// Multiple replies can be read by setting more to true in the request.
// Reply parameters are accumulated separated by 0, if there are many.
// Replies with io.systemd.UserDatabase.NoRecordFound errors are skipped.
// Other UserDatabase errors are returned as is.
// If the socket does not exist, or if the io.systemd.UserDatabase.ServiceNotAvailable
// error is seen in a response, the query is considered unhandled.
func querySocket(ctx context.Context, sock string, request []byte) response {
sockFd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
if err != nil {
return response{err: err}
}
defer syscall.Close(sockFd)
if err := syscall.Connect(sockFd, &syscall.SockaddrUnix{Name: sock}); err != nil {
if errors.Is(err, os.ErrNotExist) {
return response{err: err}
}
return response{handled: true, err: err}
}
// Null terminate request.
if request[len(request)-1] != 0 {
request = append(request, 0)
}
// Write request to socket.
written := 0
for written < len(request) {
if ctx.Err() != nil {
return response{handled: true, err: ctx.Err()}
}
if n, err := syscall.Write(sockFd, request[written:]); err != nil {
return response{handled: true, err: err}
} else {
written += n
}
}
// Read response.
var resp bytes.Buffer
for {
if ctx.Err() != nil {
return response{handled: true, err: ctx.Err()}
}
buf := make([]byte, 4096)
if n, err := syscall.Read(sockFd, buf); err != nil {
return response{handled: true, err: err}
} else if n > 0 {
resp.Write(buf[:n])
if buf[n-1] == 0 {
break
}
} else {
// EOF
break
}
}
if resp.Len() == 0 {
return response{handled: true}
}
buf := resp.Bytes()
// Remove trailing 0.
buf = buf[:len(buf)-1]
// Split into VARLINK messages.
msgs := bytes.Split(buf, []byte{0})
// Parse VARLINK messages.
for _, m := range msgs {
var resp userdbReply
if err := resp.unmarshalJSON(m); err != nil {
return response{handled: true, err: err}
}
// Handle VARLINK message errors.
switch e := resp.errorStr; e {
case "":
case errNoRecordFound: // Ignore not found error.
continue
case errServiceNotAvailable:
return response{}
default:
return response{handled: true, err: errors.New(e)}
}
if !resp.continues {
break
}
}
return response{data: buf, handled: true, err: ctx.Err()}
}
// queryMany calls the io.systemd.UserDatabase VARLINK interface on many services at once.
// ss is a slice of userdb services to call. Each service must have a socket in cl.dir.
// c is sent to all services in ss. If c.fastest is true, only the fastest reply is read.
// Otherwise all replies are aggregated. um is called with aggregated reply parameters.
// queryMany returns the first error encountered. The first result is false if no userdb
// socket is available or if all requests time out.
func (cl userdbClient) queryMany(ctx context.Context, ss []string, c *userdbCall, um jsonUnmarshaler) (bool, error) {
responseCh := make(chan response, len(ss))
ctx, cancel := context.WithCancel(ctx)
defer cancel()
// Query all services in parallel.
var workers sync.WaitGroup
for _, svc := range ss {
data, err := c.marshalJSON(svc)
if err != nil {
return true, err
}
// Spawn worker to query service.
workers.Add(1)
go func(sock string, data []byte) {
defer workers.Done()
responseCh <- querySocket(ctx, sock, data)
}(cl.dir+"/"+svc, data)
}
go func() {
// Clean up workers.
workers.Wait()
close(responseCh)
}()
var result bytes.Buffer
var notOk int
RecvResponses:
for {
select {
case resp, ok := <-responseCh:
if !ok {
// Responses channel is closed so stop reading.
break RecvResponses
}
if resp.err != nil {
// querySocket only returns unrecoverable errors,
// so return the first one received.
return true, resp.err
}
if !resp.handled {
notOk++
continue
}
first := result.Len() == 0
result.Write(resp.data)
if first && c.fastest {
// Return the fastest response.
break RecvResponses
}
case <-ctx.Done():
// If requests time out, userdb is unavailable.
return ctx.Err() != context.DeadlineExceeded, nil
}
}
// If all sockets are not ok, userdb is unavailable.
if notOk == len(ss) {
return false, nil
}
return true, um.unmarshalJSON(result.Bytes())
}
// services enumerates userdb service sockets in dir.
// If ok is false, io.systemd.UserDatabase service does not exist.
func (cl userdbClient) services() (s []string, ok bool, err error) {
var entries []fs.DirEntry
if entries, err = os.ReadDir(cl.dir); err != nil {
ok = !os.IsNotExist(err)
return
}
ok = true
for _, ent := range entries {
s = append(s, ent.Name())
}
return
}
// query looks up users/groups on the io.systemd.UserDatabase VARLINK interface.
// If the multiplexer service is available, the call is sent only to it.
// Otherwise, the call is sent simultaneously to all UserDatabase services in cl.dir.
// The fastest reply is read and parsed. All other requests are cancelled.
// If the service is unavailable, the first result is false.
// The service is considered unavailable if the requests time-out as well.
func (cl userdbClient) query(ctx context.Context, call *userdbCall, um jsonUnmarshaler) (bool, error) {
services := []string{svcMultiplexer}
if _, err := os.Stat(cl.dir + "/" + svcMultiplexer); err != nil {
// No mux service so call all available services.
var ok bool
if services, ok, err = cl.services(); !ok || err != nil {
return ok, err
}
}
call.fastest = true
if ok, err := cl.queryMany(ctx, services, call, um); !ok || err != nil {
return ok, err
}
return true, nil
}
type jsonUnmarshaler interface {
unmarshalJSON([]byte) error
}
func isSpace(c byte) bool {
return c == ' ' || c == '\t' || c == '\r' || c == '\n'
}
// findElementStart returns a slice of r that starts at the next JSON element.
// It skips over valid JSON space characters and checks for the colon separator.
func findElementStart(r []byte) ([]byte, error) {
var idx int
var b byte
colon := byte(':')
var seenColon bool
for idx, b = range r {
if isSpace(b) {
continue
}
if !seenColon && b == colon {
seenColon = true
continue
}
// Spotted colon and b is not a space, so value starts here.
if seenColon {
break
}
return nil, errors.New("expected colon, got invalid character: " + string(b))
}
if !seenColon {
return nil, errors.New("expected colon, got end of input")
}
return r[idx:], nil
}
// parseJSONString reads a JSON string from r.
func parseJSONString(r []byte) (string, error) {
r, err := findElementStart(r)
if err != nil {
return "", err
}
// Smallest valid string is `""`.
if l := len(r); l < 2 {
return "", errors.New("unexpected end of input")
} else if l == 2 {
if bytes.Equal(r, []byte(`""`)) {
return "", nil
}
return "", errors.New("invalid string")
}
if c := r[0]; c != '"' {
return "", errors.New(`expected " got ` + string(c))
}
// Advance over opening quote.
r = r[1:]
var value strings.Builder
var inEsc bool
var inUEsc bool
var strEnds bool
reader := bytes.NewReader(r)
for {
if value.Len() > 4096 {
return "", errors.New("string too large")
}
// Parse unicode escape sequences.
if inUEsc {
maybeRune := make([]byte, 4)
n, err := reader.Read(maybeRune)
if err != nil || n != 4 {
return "", fmt.Errorf("invalid unicode escape sequence \\u%s", string(maybeRune))
}
prn, err := strconv.ParseUint(string(maybeRune), 16, 32)
if err != nil {
return "", fmt.Errorf("invalid unicode escape sequence \\u%s", string(maybeRune))
}
rn := rune(prn)
if !utf16.IsSurrogate(rn) {
value.WriteRune(rn)
inUEsc = false
continue
}
// rn maybe a high surrogate; read the low surrogate.
maybeRune = make([]byte, 6)
n, err = reader.Read(maybeRune)
if err != nil || n != 6 || maybeRune[0] != '\\' || maybeRune[1] != 'u' {
// Not a valid UTF-16 surrogate pair.
if _, err := reader.Seek(int64(-n), io.SeekCurrent); err != nil {
return "", err
}
// Invalid low surrogate; write the replacement character.
value.WriteRune(utf8.RuneError)
} else {
rn1, err := strconv.ParseUint(string(maybeRune[2:]), 16, 32)
if err != nil {
return "", fmt.Errorf("invalid unicode escape sequence %s", string(maybeRune))
}
// Check if rn and rn1 are valid UTF-16 surrogate pairs.
if dec := utf16.DecodeRune(rn, rune(rn1)); dec != utf8.RuneError {
n = utf8.EncodeRune(maybeRune, dec)
// Write the decoded rune.
value.Write(maybeRune[:n])
}
}
inUEsc = false
continue
}
if inEsc {
b, err := reader.ReadByte()
if err != nil {
return "", err
}
switch b {
case 'b':
value.WriteByte('\b')
case 'f':
value.WriteByte('\f')
case 'n':
value.WriteByte('\n')
case 'r':
value.WriteByte('\r')
case 't':
value.WriteByte('\t')
case 'u':
inUEsc = true
case '/':
value.WriteByte('/')
case '\\':
value.WriteByte('\\')
case '"':
value.WriteByte('"')
default:
return "", errors.New("unexpected character in escape sequence " + string(b))
}
inEsc = false
continue
} else {
rn, _, err := reader.ReadRune()
if err != nil {
if err == io.EOF {
break
}
return "", err
}
if rn == '\\' {
inEsc = true
continue
}
if rn == '"' {
// String ends on un-escaped quote.
strEnds = true
break
}
value.WriteRune(rn)
}
}
if !strEnds {
return "", errors.New("unexpected end of input")
}
return value.String(), nil
}
// parseJSONInt64 reads a 64 bit integer from r.
func parseJSONInt64(r []byte) (int64, error) {
r, err := findElementStart(r)
if err != nil {
return 0, err
}
var num strings.Builder
for _, b := range r {
// int64 max is 19 digits long.
if num.Len() == 20 {
return 0, errors.New("number too large")
}
if strings.ContainsRune("0123456789", rune(b)) {
num.WriteByte(b)
} else {
break
}
}
n, err := strconv.ParseInt(num.String(), 10, 64)
return int64(n), err
}
// parseJSONBoolean reads a boolean from r.
func parseJSONBoolean(r []byte) (bool, error) {
r, err := findElementStart(r)
if err != nil {
return false, err
}
if bytes.HasPrefix(r, []byte("true")) {
return true, nil
}
if bytes.HasPrefix(r, []byte("false")) {
return false, nil
}
return false, errors.New("unable to parse boolean value")
}
type groupRecord struct {
groupName string
gid int64
}
func (g *groupRecord) unmarshalJSON(data []byte) error {
var (
kGroupName = []byte(`"groupName"`)
kGid = []byte(`"gid"`)
)
if i := bytes.Index(data, kGroupName); i != -1 {
groupname, err := parseJSONString(data[i+len(kGroupName):])
if err != nil {
return err
}
g.groupName = groupname
}
if i := bytes.Index(data, kGid); i != -1 {
gid, err := parseJSONInt64(data[i+len(kGid):])
if err != nil {
return err
}
g.gid = gid
}
return nil
}
// queryGroupDb queries the userdb interface for a gid, groupname, or both.
func (cl userdbClient) queryGroupDb(ctx context.Context, gid *int64, groupname string) (*Group, bool, error) {
group := groupRecord{}
request := userdbCall{
method: mGetGroupRecord,
parameters: callParameters{gid: gid, groupName: groupname},
}
if ok, err := cl.query(ctx, &request, &group); !ok || err != nil {
return nil, ok, fmt.Errorf("error querying systemd-userdb group record: %s", err)
}
return &Group{
Name: group.groupName,
Gid: strconv.FormatInt(group.gid, 10),
}, true, nil
}
type userRecord struct {
userName string
realName string
uid int64
gid int64
homeDirectory string
}
func (u *userRecord) unmarshalJSON(data []byte) error {
var (
kUserName = []byte(`"userName"`)
kRealName = []byte(`"realName"`)
kUid = []byte(`"uid"`)
kGid = []byte(`"gid"`)
kHomeDirectory = []byte(`"homeDirectory"`)
)
if i := bytes.Index(data, kUserName); i != -1 {
username, err := parseJSONString(data[i+len(kUserName):])
if err != nil {
return err
}
u.userName = username
}
if i := bytes.Index(data, kRealName); i != -1 {
realname, err := parseJSONString(data[i+len(kRealName):])
if err != nil {
return err
}
u.realName = realname
}
if i := bytes.Index(data, kUid); i != -1 {
uid, err := parseJSONInt64(data[i+len(kUid):])
if err != nil {
return err
}
u.uid = uid
}
if i := bytes.Index(data, kGid); i != -1 {
gid, err := parseJSONInt64(data[i+len(kGid):])
if err != nil {
return err
}
u.gid = gid
}
if i := bytes.Index(data, kHomeDirectory); i != -1 {
homedir, err := parseJSONString(data[i+len(kHomeDirectory):])
if err != nil {
return err
}
u.homeDirectory = homedir
}
return nil
}
// queryUserDb queries the userdb interface for a uid, username, or both.
func (cl userdbClient) queryUserDb(ctx context.Context, uid *int64, username string) (*User, bool, error) {
user := userRecord{}
request := userdbCall{
method: mGetUserRecord,
parameters: callParameters{
uid: uid,
userName: username,
},
}
if ok, err := cl.query(ctx, &request, &user); !ok || err != nil {
return nil, ok, fmt.Errorf("error querying systemd-userdb user record: %s", err)
}
return &User{
Uid: strconv.FormatInt(user.uid, 10),
Gid: strconv.FormatInt(user.gid, 10),
Username: user.userName,
Name: user.realName,
HomeDir: user.homeDirectory,
}, true, nil
}
func (cl userdbClient) lookupGroup(ctx context.Context, groupname string) (*Group, bool, error) {
return cl.queryGroupDb(ctx, nil, groupname)
}
func (cl userdbClient) lookupGroupId(ctx context.Context, id string) (*Group, bool, error) {
gid, err := strconv.ParseInt(id, 10, 64)
if err != nil {
return nil, true, err
}
return cl.queryGroupDb(ctx, &gid, "")
}
func (cl userdbClient) lookupUser(ctx context.Context, username string) (*User, bool, error) {
return cl.queryUserDb(ctx, nil, username)
}
func (cl userdbClient) lookupUserId(ctx context.Context, id string) (*User, bool, error) {
uid, err := strconv.ParseInt(id, 10, 64)
if err != nil {
return nil, true, err
}
return cl.queryUserDb(ctx, &uid, "")
}
type memberships struct {
// Keys are groupNames and values are sets of userNames.
groupUsers map[string]map[string]struct{}
}
// unmarshalJSON expects many (userName, groupName) records separated by a null byte.
// This is used to build a membership map.
func (m *memberships) unmarshalJSON(data []byte) error {
if m.groupUsers == nil {
m.groupUsers = make(map[string]map[string]struct{})
}
var (
kUserName = []byte(`"userName"`)
kGroupName = []byte(`"groupName"`)
)
// Split records by null terminator.
records := bytes.Split(data, []byte{byte(0)})
for _, rec := range records {
if len(rec) == 0 {
continue
}
var groupName string
var userName string
var err error
if i := bytes.Index(rec, kGroupName); i != -1 {
if groupName, err = parseJSONString(rec[i+len(kGroupName):]); err != nil {
return err
}
}
if i := bytes.Index(rec, kUserName); i != -1 {
if userName, err = parseJSONString(rec[i+len(kUserName):]); err != nil {
return err
}
}
// Associate userName with groupName.
if groupName != "" && userName != "" {
if _, ok := m.groupUsers[groupName]; ok {
m.groupUsers[groupName][userName] = struct{}{}
} else {
m.groupUsers[groupName] = map[string]struct{}{userName: {}}
}
}
}
return nil
}
func (cl userdbClient) lookupGroupIds(ctx context.Context, username string) ([]string, bool, error) {
services, ok, err := cl.services()
if !ok || err != nil {
return nil, ok, err
}
// Fetch group memberships for username.
var ms memberships
request := userdbCall{
method: mGetMemberships,
parameters: callParameters{userName: username},
more: true,
}
if ok, err := cl.queryMany(ctx, services, &request, &ms); !ok || err != nil {
return nil, ok, fmt.Errorf("error querying systemd-userdb memberships record: %s", err)
}
// Fetch user group gid.
var group groupRecord
request = userdbCall{
method: mGetGroupRecord,
parameters: callParameters{groupName: username},
}
if ok, err := cl.query(ctx, &request, &group); !ok || err != nil {
return nil, ok, err
}
gids := []string{strconv.FormatInt(group.gid, 10)}
// Fetch group records for each group.
for g := range ms.groupUsers {
var group groupRecord
request.parameters.groupName = g
// Query group for gid.
if ok, err := cl.query(ctx, &request, &group); !ok || err != nil {
return nil, ok, fmt.Errorf("error querying systemd-userdb group record: %s", err)
}
gids = append(gids, strconv.FormatInt(group.gid, 10))
}
return gids, true, nil
}

View file

@ -0,0 +1,504 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux
package user
import (
"bytes"
"context"
"errors"
"reflect"
"sort"
"strconv"
"strings"
"sync"
"syscall"
"testing"
"time"
"unicode/utf8"
)
func TestQueryNoUserdb(t *testing.T) {
cl := &userdbClient{dir: "/non/existent"}
if _, ok, err := cl.lookupGroup(context.Background(), "stdlibcontrib"); ok {
t.Fatalf("should fail but lookup has been handled or error is nil: %v", err)
}
}
type userdbTestData map[string]udbResponse
type udbResponse struct {
data []byte
delay time.Duration
}
func userdbServer(t *testing.T, sockFn string, data userdbTestData) {
ready := make(chan struct{})
go func() {
if err := serveUserdb(ready, sockFn, data); err != nil {
t.Error(err)
}
}()
<-ready
}
func (u userdbTestData) String() string {
var s strings.Builder
for k, v := range u {
s.WriteString("Request:\n")
s.WriteString(k)
s.WriteString("\nResponse:\n")
if v.delay > 0 {
s.WriteString("Delay: ")
s.WriteString(v.delay.String())
s.WriteString("\n")
}
s.WriteString("Data:\n")
s.Write(v.data)
s.WriteString("\n")
}
return s.String()
}
// serverUserdb is a simple userdb server that replies to VARLINK method calls.
// A message is sent on the ready channel when the server is ready to accept calls.
// The server will reply to each request in the data map. If a request is not
// found in the map, the server will return an error.
func serveUserdb(ready chan<- struct{}, sockFn string, data userdbTestData) error {
sockFd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
if err != nil {
return err
}
defer syscall.Close(sockFd)
if err := syscall.Bind(sockFd, &syscall.SockaddrUnix{Name: sockFn}); err != nil {
return err
}
if err := syscall.Listen(sockFd, 1); err != nil {
return err
}
// Send ready signal.
ready <- struct{}{}
var srvGroup sync.WaitGroup
srvErrs := make(chan error, len(data))
for len(data) != 0 {
nfd, _, err := syscall.Accept(sockFd)
if err != nil {
syscall.Close(nfd)
return err
}
// Read request.
buf := make([]byte, 4096)
n, err := syscall.Read(nfd, buf)
if err != nil {
syscall.Close(nfd)
return err
}
if n == 0 {
// Client went away.
continue
}
if buf[n-1] != 0 {
syscall.Close(nfd)
return errors.New("request not null terminated")
}
// Remove null terminator.
buf = buf[:n-1]
got := string(buf)
// Fetch response for request.
response, ok := data[got]
if !ok {
syscall.Close(nfd)
msg := "unexpected request:\n" + got + "\n\ndata:\n" + data.String()
return errors.New(msg)
}
delete(data, got)
srvGroup.Add(1)
go func() {
defer srvGroup.Done()
if err := serveClient(nfd, response); err != nil {
srvErrs <- err
}
}()
}
srvGroup.Wait()
// Combine serve errors if any.
if len(srvErrs) > 0 {
var errs []error
for err := range srvErrs {
errs = append(errs, err)
}
return errors.Join(errs...)
}
return nil
}
func serveClient(fd int, response udbResponse) error {
defer syscall.Close(fd)
time.Sleep(response.delay)
data := response.data
if len(data) != 0 && data[len(data)-1] != 0 {
data = append(data, 0)
}
written := 0
for written < len(data) {
if n, err := syscall.Write(fd, data[written:]); err != nil {
return err
} else {
written += n
}
}
return nil
}
func TestSlowUserdbLookup(t *testing.T) {
tmpdir := t.TempDir()
data := userdbTestData{
`{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"stdlibcontrib"}}`: udbResponse{
delay: time.Hour,
},
}
userdbServer(t, tmpdir+"/"+svcMultiplexer, data)
cl := &userdbClient{dir: tmpdir}
// Lookup should timeout.
ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
defer cancel()
if _, ok, _ := cl.lookupGroup(ctx, "stdlibcontrib"); ok {
t.Fatalf("lookup should not be handled but was")
}
}
func TestFastestUserdbLookup(t *testing.T) {
tmpdir := t.TempDir()
fastData := userdbTestData{
`{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"fast","groupName":"stdlibcontrib"}}`: udbResponse{
data: []byte(
`{"parameters":{"record":{"groupName":"stdlibcontrib","gid":181,"members":["stdlibcontrib"],"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`,
),
},
}
slowData := userdbTestData{
`{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"slow","groupName":"stdlibcontrib"}}`: udbResponse{
delay: 50 * time.Millisecond,
data: []byte(
`{"parameters":{"record":{"groupName":"stdlibcontrib","gid":182,"members":["stdlibcontrib"],"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`,
),
},
}
userdbServer(t, tmpdir+"/"+"fast", fastData)
userdbServer(t, tmpdir+"/"+"slow", slowData)
cl := &userdbClient{dir: tmpdir}
group, ok, err := cl.lookupGroup(context.Background(), "stdlibcontrib")
if !ok {
t.Fatalf("lookup should be handled but was not")
}
if err != nil {
t.Fatalf("lookup should not fail but did: %v", err)
}
if group.Gid != "181" {
t.Fatalf("lookup should return group 181 but returned %s", group.Gid)
}
}
func TestUserdbLookupGroup(t *testing.T) {
tmpdir := t.TempDir()
data := userdbTestData{
`{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"stdlibcontrib"}}`: udbResponse{
data: []byte(
`{"parameters":{"record":{"groupName":"stdlibcontrib","gid":181,"members":["stdlibcontrib"],"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`,
),
},
}
userdbServer(t, tmpdir+"/"+svcMultiplexer, data)
groupname := "stdlibcontrib"
want := &Group{
Name: "stdlibcontrib",
Gid: "181",
}
cl := &userdbClient{dir: tmpdir}
got, ok, err := cl.lookupGroup(context.Background(), groupname)
if !ok {
t.Fatal("lookup should have been handled")
}
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(got, want) {
t.Fatalf("lookupGroup(%s) = %v, want %v", groupname, got, want)
}
}
func TestUserdbLookupUser(t *testing.T) {
tmpdir := t.TempDir()
data := userdbTestData{
`{"method":"io.systemd.UserDatabase.GetUserRecord","parameters":{"service":"io.systemd.Multiplexer","userName":"stdlibcontrib"}}`: udbResponse{
data: []byte(
`{"parameters":{"record":{"userName":"stdlibcontrib","uid":181,"gid":181,"realName":"Stdlib Contrib","homeDirectory":"/home/stdlibcontrib","status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`,
),
},
}
userdbServer(t, tmpdir+"/"+svcMultiplexer, data)
username := "stdlibcontrib"
want := &User{
Uid: "181",
Gid: "181",
Username: "stdlibcontrib",
Name: "Stdlib Contrib",
HomeDir: "/home/stdlibcontrib",
}
cl := &userdbClient{dir: tmpdir}
got, ok, err := cl.lookupUser(context.Background(), username)
if !ok {
t.Fatal("lookup should have been handled")
}
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(got, want) {
t.Fatalf("lookupUser(%s) = %v, want %v", username, got, want)
}
}
func TestUserdbLookupGroupIds(t *testing.T) {
tmpdir := t.TempDir()
data := userdbTestData{
`{"method":"io.systemd.UserDatabase.GetMemberships","parameters":{"service":"io.systemd.Multiplexer","userName":"stdlibcontrib"},"more":true}`: udbResponse{
data: []byte(
`{"parameters":{"userName":"stdlibcontrib","groupName":"stdlib"},"continues":true}` + "\x00" + `{"parameters":{"userName":"stdlibcontrib","groupName":"contrib"}}`,
),
},
// group records
`{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"stdlibcontrib"}}`: udbResponse{
data: []byte(
`{"parameters":{"record":{"groupName":"stdlibcontrib","members":["stdlibcontrib"],"gid":181,"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`,
),
},
`{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"stdlib"}}`: udbResponse{
data: []byte(
`{"parameters":{"record":{"groupName":"stdlib","members":["stdlibcontrib"],"gid":182,"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`,
),
},
`{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"contrib"}}`: udbResponse{
data: []byte(
`{"parameters":{"record":{"groupName":"contrib","members":["stdlibcontrib"],"gid":183,"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`,
),
},
}
userdbServer(t, tmpdir+"/"+svcMultiplexer, data)
username := "stdlibcontrib"
want := []string{"181", "182", "183"}
cl := &userdbClient{dir: tmpdir}
got, ok, err := cl.lookupGroupIds(context.Background(), username)
if !ok {
t.Fatal("lookup should have been handled")
}
if err != nil {
t.Fatal(err)
}
// Result order is not specified so sort it.
sort.Strings(got)
if !reflect.DeepEqual(got, want) {
t.Fatalf("lookupGroupIds(%s) = %v, want %v", username, got, want)
}
}
var findElementStartTestCases = []struct {
in []byte
want []byte
err bool
}{
{in: []byte(`:`), want: []byte(``)},
{in: []byte(`: `), want: []byte(``)},
{in: []byte(`:"foo"`), want: []byte(`"foo"`)},
{in: []byte(` :"foo"`), want: []byte(`"foo"`)},
{in: []byte(` 1231 :"foo"`), err: true},
{in: []byte(``), err: true},
{in: []byte(`"foo"`), err: true},
{in: []byte(`foo`), err: true},
}
func TestFindElementStart(t *testing.T) {
for i, tc := range findElementStartTestCases {
t.Run("#"+strconv.Itoa(i), func(t *testing.T) {
got, err := findElementStart(tc.in)
if tc.err && err == nil {
t.Errorf("want err for findElementStart(%s), got nil", tc.in)
}
if !tc.err {
if err != nil {
t.Errorf("findElementStart(%s) unexpected error: %s", tc.in, err.Error())
}
if !bytes.Contains(tc.in, got) {
t.Errorf("%s should contain %s but does not", tc.in, got)
}
}
})
}
}
func FuzzFindElementStart(f *testing.F) {
for _, tc := range findElementStartTestCases {
if !tc.err {
f.Add(tc.in)
}
}
f.Fuzz(func(t *testing.T, b []byte) {
if out, err := findElementStart(b); err == nil && !bytes.Contains(b, out) {
t.Errorf("%s, %v", out, err)
}
})
}
var parseJSONStringTestCases = []struct {
in []byte
want string
err bool
}{
{in: []byte(`:""`)},
{in: []byte(`:"\n"`), want: "\n"},
{in: []byte(`: "\""`), want: "\""},
{in: []byte(`:"\t \\"`), want: "\t \\"},
{in: []byte(`:"\\\\"`), want: `\\`},
{in: []byte(`::`), err: true},
{in: []byte(`""`), err: true},
{in: []byte(`"`), err: true},
{in: []byte(":\"0\xE5"), err: true},
{in: []byte{':', '"', 0xFE, 0xFE, 0xFF, 0xFF, '"'}, want: "\uFFFD\uFFFD\uFFFD\uFFFD"},
{in: []byte(`:"\u0061a"`), want: "aa"},
{in: []byte(`:"\u0159\u0170"`), want: "řŰ"},
{in: []byte(`:"\uD800\uDC00"`), want: "\U00010000"},
{in: []byte(`:"\uD800"`), want: "\uFFFD"},
{in: []byte(`:"\u000"`), err: true},
{in: []byte(`:"\u00MF"`), err: true},
{in: []byte(`:"\uD800\uDC0"`), err: true},
}
func TestParseJSONString(t *testing.T) {
for i, tc := range parseJSONStringTestCases {
t.Run("#"+strconv.Itoa(i), func(t *testing.T) {
got, err := parseJSONString(tc.in)
if tc.err && err == nil {
t.Errorf("want err for parseJSONString(%s), got nil", tc.in)
}
if !tc.err {
if err != nil {
t.Errorf("parseJSONString(%s) unexpected error: %s", tc.in, err.Error())
}
if tc.want != got {
t.Errorf("parseJSONString(%s) = %s, want %s", tc.in, got, tc.want)
}
}
})
}
}
func FuzzParseJSONString(f *testing.F) {
for _, tc := range parseJSONStringTestCases {
f.Add(tc.in)
}
f.Fuzz(func(t *testing.T, b []byte) {
if out, err := parseJSONString(b); err == nil && !utf8.ValidString(out) {
t.Errorf("parseJSONString(%s) = %s, invalid string", b, out)
}
})
}
var parseJSONInt64TestCases = []struct {
in []byte
want int64
err bool
}{
{in: []byte(":1235"), want: 1235},
{in: []byte(": 123"), want: 123},
{in: []byte(":0")},
{in: []byte(":5012313123131231"), want: 5012313123131231},
{in: []byte("1231"), err: true},
}
func TestParseJSONInt64(t *testing.T) {
for i, tc := range parseJSONInt64TestCases {
t.Run("#"+strconv.Itoa(i), func(t *testing.T) {
got, err := parseJSONInt64(tc.in)
if tc.err && err == nil {
t.Errorf("want err for parseJSONInt64(%s), got nil", tc.in)
}
if !tc.err {
if err != nil {
t.Errorf("parseJSONInt64(%s) unexpected error: %s", tc.in, err.Error())
}
if tc.want != got {
t.Errorf("parseJSONInt64(%s) = %d, want %d", tc.in, got, tc.want)
}
}
})
}
}
func FuzzParseJSONInt64(f *testing.F) {
for _, tc := range parseJSONInt64TestCases {
f.Add(tc.in)
}
f.Fuzz(func(t *testing.T, b []byte) {
if out, err := parseJSONInt64(b); err == nil &&
!bytes.Contains(b, []byte(strconv.FormatInt(out, 10))) {
t.Errorf("parseJSONInt64(%s) = %d, %v", b, out, err)
}
})
}
var parseJSONBooleanTestCases = []struct {
in []byte
want bool
err bool
}{
{in: []byte(": true "), want: true},
{in: []byte(":true "), want: true},
{in: []byte(": false "), want: false},
{in: []byte(":false "), want: false},
{in: []byte("true"), err: true},
{in: []byte("false"), err: true},
{in: []byte("foo"), err: true},
}
func TestParseJSONBoolean(t *testing.T) {
for i, tc := range parseJSONBooleanTestCases {
t.Run("#"+strconv.Itoa(i), func(t *testing.T) {
got, err := parseJSONBoolean(tc.in)
if tc.err && err == nil {
t.Errorf("want err for parseJSONBoolean(%s), got nil", tc.in)
}
if !tc.err {
if err != nil {
t.Errorf("parseJSONBoolean(%s) unexpected error: %s", tc.in, err.Error())
}
if tc.want != got {
t.Errorf("parseJSONBoolean(%s) = %t, want %t", tc.in, got, tc.want)
}
}
})
}
}
func FuzzParseJSONBoolean(f *testing.F) {
for _, tc := range parseJSONBooleanTestCases {
f.Add(tc.in)
}
f.Fuzz(func(t *testing.T, b []byte) {
if out, err := parseJSONBoolean(b); err == nil && !bytes.Contains(b, []byte(strconv.FormatBool(out))) {
t.Errorf("parseJSONBoolean(%s) = %t, %v", b, out, err)
}
})
}

View file

@ -0,0 +1,29 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !linux
package user
import "context"
func (cl userdbClient) lookupGroup(_ context.Context, _ string) (*Group, bool, error) {
return nil, false, nil
}
func (cl userdbClient) lookupGroupId(_ context.Context, _ string) (*Group, bool, error) {
return nil, false, nil
}
func (cl userdbClient) lookupUser(_ context.Context, _ string) (*User, bool, error) {
return nil, false, nil
}
func (cl userdbClient) lookupUserId(_ context.Context, _ string) (*User, bool, error) {
return nil, false, nil
}
func (cl userdbClient) lookupGroupIds(_ context.Context, _ string) ([]string, bool, error) {
return nil, false, nil
}