mirror of
https://github.com/gravitational/teleport
synced 2024-10-21 01:34:01 +00:00
Fix Azure join for identities across resource groups (#28927)
This change fixes a bug in the Azure join method where a VM's identity can't be verified if it's in a different resource group from its managed identity.
This commit is contained in:
parent
9d7f553bf3
commit
9ee7e5774f
|
@ -268,7 +268,7 @@ func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken
|
|||
// If the token is from a user-assigned managed identity, the resource ID is
|
||||
// for the identity and we need to look the VM up by VM ID.
|
||||
} else {
|
||||
vm, err = vmClient.GetByVMID(ctx, resourceID.ResourceGroupName, vmID)
|
||||
vm, err = vmClient.GetByVMID(ctx, types.Wildcard, vmID)
|
||||
if err != nil {
|
||||
if trace.IsNotFound(err) {
|
||||
return nil, trace.AccessDenied("no VM found with matching VM ID")
|
||||
|
|
|
@ -508,6 +508,25 @@ func (m *ARMComputeMock) NewListPager(resourceGroup string, _ *armcompute.Virtua
|
|||
})
|
||||
}
|
||||
|
||||
func (m *ARMComputeMock) NewListAllPager(_ *armcompute.VirtualMachinesClientListAllOptions) *runtime.Pager[armcompute.VirtualMachinesClientListAllResponse] {
|
||||
var vms []*armcompute.VirtualMachine
|
||||
for _, resourceGroupVMs := range m.VirtualMachines {
|
||||
vms = append(vms, resourceGroupVMs...)
|
||||
}
|
||||
return runtime.NewPager(runtime.PagingHandler[armcompute.VirtualMachinesClientListAllResponse]{
|
||||
More: func(page armcompute.VirtualMachinesClientListAllResponse) bool {
|
||||
return page.NextLink != nil && len(*page.NextLink) > 0
|
||||
},
|
||||
Fetcher: func(ctx context.Context, page *armcompute.VirtualMachinesClientListAllResponse) (armcompute.VirtualMachinesClientListAllResponse, error) {
|
||||
return armcompute.VirtualMachinesClientListAllResponse{
|
||||
VirtualMachineListResult: armcompute.VirtualMachineListResult{
|
||||
Value: vms,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (m *ARMComputeMock) Get(_ context.Context, _ string, _ string, _ *armcompute.VirtualMachinesClientGetOptions) (armcompute.VirtualMachinesClientGetResponse, error) {
|
||||
return armcompute.VirtualMachinesClientGetResponse{
|
||||
VirtualMachine: m.GetResult,
|
||||
|
|
|
@ -23,27 +23,31 @@ import (
|
|||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v3"
|
||||
"github.com/gravitational/trace"
|
||||
|
||||
"github.com/gravitational/teleport/api/types"
|
||||
)
|
||||
|
||||
// armCompute provides an interface for an Azure Virtual Machine client.
|
||||
// armCompute provides an interface for an Azure virtual machine client.
|
||||
type armCompute interface {
|
||||
// Get retrieves information about an Azure Virtual Machine.
|
||||
// Get retrieves information about an Azure virtual machine.
|
||||
Get(ctx context.Context, resourceGroupName string, vmName string, options *armcompute.VirtualMachinesClientGetOptions) (armcompute.VirtualMachinesClientGetResponse, error)
|
||||
// NewListPagers lists Azure Virtual Machines.
|
||||
// NewListPagers lists Azure virtual Machines.
|
||||
NewListPager(resourceGroup string, opts *armcompute.VirtualMachinesClientListOptions) *runtime.Pager[armcompute.VirtualMachinesClientListResponse]
|
||||
// NewListAllPager lists Azure virtual machines in any resource group.
|
||||
NewListAllPager(opts *armcompute.VirtualMachinesClientListAllOptions) *runtime.Pager[armcompute.VirtualMachinesClientListAllResponse]
|
||||
}
|
||||
|
||||
// VirtualMachinesClient is a client for Azure Virtual Machines.
|
||||
// VirtualMachinesClient is a client for Azure virtual machines.
|
||||
type VirtualMachinesClient interface {
|
||||
// Get returns the Virtual Machine for the given resource ID.
|
||||
// Get returns the virtual machine for the given resource ID.
|
||||
Get(ctx context.Context, resourceID string) (*VirtualMachine, error)
|
||||
// GetByVMID returns the Virtual Machine for a given VM ID.
|
||||
// GetByVMID returns the virtual machine for a given VM ID.
|
||||
GetByVMID(ctx context.Context, resourceGroup, vmID string) (*VirtualMachine, error)
|
||||
// ListVirtualMachines gets all of the virtual machines in the given resource group.
|
||||
ListVirtualMachines(ctx context.Context, resourceGroup string) ([]*armcompute.VirtualMachine, error)
|
||||
}
|
||||
|
||||
// VirtualMachine represents an Azure Virtual Machine.
|
||||
// VirtualMachine represents an Azure virtual machine.
|
||||
type VirtualMachine struct {
|
||||
// ID resource ID.
|
||||
ID string `json:"id,omitempty"`
|
||||
|
@ -59,18 +63,18 @@ type VirtualMachine struct {
|
|||
Identities []Identity
|
||||
}
|
||||
|
||||
// Identitiy represents an Azure Virtual Machine identity.
|
||||
// Identitiy represents an Azure virtual machine identity.
|
||||
type Identity struct {
|
||||
// ResourceID the identity resource ID.
|
||||
ResourceID string
|
||||
}
|
||||
|
||||
type vmClient struct {
|
||||
// api is the Azure Virtual Machine client.
|
||||
// api is the Azure virtual machine client.
|
||||
api armCompute
|
||||
}
|
||||
|
||||
// NewVirtualMachinesClient creates a new Azure Virtual Machines client by
|
||||
// NewVirtualMachinesClient creates a new Azure virtual machines client by
|
||||
// subscription and credentials.
|
||||
func NewVirtualMachinesClient(subscription string, cred azcore.TokenCredential, options *arm.ClientOptions) (VirtualMachinesClient, error) {
|
||||
computeAPI, err := armcompute.NewVirtualMachinesClient(subscription, cred, options)
|
||||
|
@ -81,7 +85,7 @@ func NewVirtualMachinesClient(subscription string, cred azcore.TokenCredential,
|
|||
return NewVirtualMachinesClientByAPI(computeAPI), nil
|
||||
}
|
||||
|
||||
// NewVirtualMachinesClientByAPI creates a new Azure Virtual Machines client by
|
||||
// NewVirtualMachinesClientByAPI creates a new Azure virtual machines client by
|
||||
// ARM API client.
|
||||
func NewVirtualMachinesClientByAPI(api armCompute) VirtualMachinesClient {
|
||||
return &vmClient{
|
||||
|
@ -121,7 +125,7 @@ func parseVirtualMachine(vm *armcompute.VirtualMachine) (*VirtualMachine, error)
|
|||
}, nil
|
||||
}
|
||||
|
||||
// Get returns the Virtual Machine for the given resource ID.
|
||||
// Get returns the virtual machine for the given resource ID.
|
||||
func (c *vmClient) Get(ctx context.Context, resourceID string) (*VirtualMachine, error) {
|
||||
parsedResourceID, err := arm.ParseResourceID(resourceID)
|
||||
if err != nil {
|
||||
|
@ -137,7 +141,7 @@ func (c *vmClient) Get(ctx context.Context, resourceID string) (*VirtualMachine,
|
|||
return vm, trace.Wrap(err)
|
||||
}
|
||||
|
||||
// GetByVMID returns the Virtual Machine for a given VM ID.
|
||||
// GetByVMID returns the virtual machine for a given VM ID.
|
||||
func (c *vmClient) GetByVMID(ctx context.Context, resourceGroup, vmID string) (*VirtualMachine, error) {
|
||||
vms, err := c.ListVirtualMachines(ctx, resourceGroup)
|
||||
if err != nil {
|
||||
|
@ -152,17 +156,48 @@ func (c *vmClient) GetByVMID(ctx context.Context, resourceGroup, vmID string) (*
|
|||
return nil, trace.NotFound("no VM with ID %q", vmID)
|
||||
}
|
||||
|
||||
// ListVirtualMachines lists all virtual machines in a given resource group using the Azure Virtual Machines API.
|
||||
type vmPager struct {
|
||||
more func() bool
|
||||
nextPage func(context.Context) ([]*armcompute.VirtualMachine, error)
|
||||
}
|
||||
|
||||
func newListPager(azurePager *runtime.Pager[armcompute.VirtualMachinesClientListResponse]) vmPager {
|
||||
return vmPager{
|
||||
more: azurePager.More,
|
||||
nextPage: func(ctx context.Context) ([]*armcompute.VirtualMachine, error) {
|
||||
res, err := azurePager.NextPage(ctx)
|
||||
return res.Value, trace.Wrap(err)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newListAllPager(azurePager *runtime.Pager[armcompute.VirtualMachinesClientListAllResponse]) vmPager {
|
||||
return vmPager{
|
||||
more: azurePager.More,
|
||||
nextPage: func(ctx context.Context) ([]*armcompute.VirtualMachine, error) {
|
||||
res, err := azurePager.NextPage(ctx)
|
||||
return res.Value, trace.Wrap(err)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ListVirtualMachines lists all virtual machines in a given resource group
|
||||
// using the Azure virtual machines API. If resourceGroup is "*", it lists
|
||||
// all virtual machines in any resource group.
|
||||
func (c *vmClient) ListVirtualMachines(ctx context.Context, resourceGroup string) ([]*armcompute.VirtualMachine, error) {
|
||||
pagerOpts := &armcompute.VirtualMachinesClientListOptions{}
|
||||
pager := c.api.NewListPager(resourceGroup, pagerOpts)
|
||||
var pager vmPager
|
||||
if resourceGroup == types.Wildcard {
|
||||
pager = newListAllPager(c.api.NewListAllPager(&armcompute.VirtualMachinesClientListAllOptions{}))
|
||||
} else {
|
||||
pager = newListPager(c.api.NewListPager(resourceGroup, &armcompute.VirtualMachinesClientListOptions{}))
|
||||
}
|
||||
var virtualMachines []*armcompute.VirtualMachine
|
||||
for pager.More() {
|
||||
res, err := pager.NextPage(ctx)
|
||||
for pager.more() {
|
||||
res, err := pager.nextPage(ctx)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(ConvertResponseError(err))
|
||||
}
|
||||
virtualMachines = append(virtualMachines, res.Value...)
|
||||
virtualMachines = append(virtualMachines, res...)
|
||||
}
|
||||
|
||||
return virtualMachines, nil
|
||||
|
|
|
@ -22,6 +22,8 @@ import (
|
|||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v3"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/gravitational/teleport/api/types"
|
||||
)
|
||||
|
||||
func TestGetVirtualMachine(t *testing.T) {
|
||||
|
@ -169,6 +171,11 @@ func TestListVirtualMachines(t *testing.T) {
|
|||
resourceGroup: "rgfake",
|
||||
wantIDs: []string{},
|
||||
},
|
||||
{
|
||||
name: "all resource groups",
|
||||
resourceGroup: types.Wildcard,
|
||||
wantIDs: []string{"vm1", "vm2", "vm3", "vm4"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
|
|
Loading…
Reference in a new issue