mirror of
https://github.com/gravitational/teleport
synced 2024-10-19 16:53:57 +00:00
Add hidden cli command: wait-no-resolve (#19277)
Part of https://github.com/gravitational/teleport/pull/18274.
This commit introduces a new hidden `wait` CLI subcommand:
- `teleport wait no-resolve <domain-name>` resolves a domain name and exits only when no IPs are resolved. This CLI command should be used in the Helm chart, as an init container, to block proxies from rolling out until all auth pods have been successfully rolled out.
- `teleport wait duration 30s` has the same behaviour as `sleep 30`. Due to image hardening we won't have `sleep` available, but waiting 30 seconds in a preStop hook is required to ensure a 100% seamless pod rollout on kube-proxy-based clusters.
This commit is contained in:
parent
3fd74ae3fd
commit
44f57bf346
|
@ -746,6 +746,11 @@ const (
|
|||
// SFTPSubCommand is the sub-command Teleport uses to re-exec itself to
|
||||
// handle SFTP connections.
|
||||
SFTPSubCommand = "sftp"
|
||||
|
||||
// WaitSubCommand is the sub-command Teleport uses to wait
|
||||
// until a domain name stops resolving. Its main use is to ensure no
|
||||
// auth instances are still running the previous major version.
|
||||
WaitSubCommand = "wait"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
@ -76,6 +76,7 @@ func Run(options Options) (app *kingpin.Application, executedCommand string, con
|
|||
configureDiscoveryBootstrapFlags configureDiscoveryBootstrapFlags
|
||||
dbConfigCreateFlags createDatabaseConfigFlags
|
||||
systemdInstallFlags installSystemdFlags
|
||||
waitFlags waitFlags
|
||||
)
|
||||
|
||||
// define commands:
|
||||
|
@ -376,6 +377,14 @@ func Run(options Options) (app *kingpin.Application, executedCommand string, con
|
|||
dumpNodeConfigure.Flag("join-method", "Method to use to join the cluster (token, iam, ec2, kubernetes)").Default("token").EnumVar(&dumpFlags.JoinMethod, "token", "iam", "ec2", "kubernetes")
|
||||
dumpNodeConfigure.Flag("node-name", "Name for the teleport node.").StringVar(&dumpFlags.NodeName)
|
||||
|
||||
waitCmd := app.Command(teleport.WaitSubCommand, "Used internally by Teleport to onWait until a specific condition is reached.").Hidden()
|
||||
waitNoResolveCmd := waitCmd.Command("no-resolve", "Used internally to onWait until a domain stops resolving IP addresses.")
|
||||
waitNoResolveCmd.Arg("domain", "Domain that is resolved.").StringVar(&waitFlags.domain)
|
||||
waitNoResolveCmd.Flag("period", "Resolution try period. A jitter is applied.").Default(waitNoResolveDefaultPeriod).DurationVar(&waitFlags.period)
|
||||
waitNoResolveCmd.Flag("timeout", "Stops waiting after this duration and exits in error.").Default(waitNoResolveDefaultTimeout).DurationVar(&waitFlags.timeout)
|
||||
waitDurationCmd := waitCmd.Command("duration", "Used internally to onWait a given duration before exiting.")
|
||||
waitDurationCmd.Arg("duration", "Duration to onWait before exit.").DurationVar(&waitFlags.duration)
|
||||
|
||||
// parse CLI commands+flags:
|
||||
utils.UpdateAppUsageTemplate(app, options.Args)
|
||||
command, err := app.Parse(options.Args)
|
||||
|
@ -431,6 +440,10 @@ func Run(options Options) (app *kingpin.Application, executedCommand string, con
|
|||
srv.RunAndExit(teleport.CheckHomeDirSubCommand)
|
||||
case park.FullCommand():
|
||||
srv.RunAndExit(teleport.ParkSubCommand)
|
||||
case waitNoResolveCmd.FullCommand():
|
||||
err = onWaitNoResolve(waitFlags)
|
||||
case waitDurationCmd.FullCommand():
|
||||
err = onWaitDuration(waitFlags)
|
||||
case ver.FullCommand():
|
||||
utils.PrintVersion()
|
||||
case dbConfigureCreate.FullCommand():
|
||||
|
|
164
tool/teleport/common/wait.go
Normal file
164
tool/teleport/common/wait.go
Normal file
|
@ -0,0 +1,164 @@
|
|||
/*
|
||||
Copyright 2022-2023 Gravitational, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gravitational/trace"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/gravitational/teleport/api/utils/retryutils"
|
||||
"github.com/gravitational/teleport/lib/utils"
|
||||
"github.com/gravitational/teleport/lib/utils/interval"
|
||||
)
|
||||
|
||||
const (
	// waitNoResolveDefaultPeriod is the default value of the `--period` flag
	// of the `wait no-resolve` command. It is a string, not a time.Duration,
	// because it is handed to the flag's Default() and parsed from there.
	waitNoResolveDefaultPeriod = "10s"
	// waitNoResolveDefaultTimeout is the default value of the `--timeout`
	// flag of the `wait no-resolve` command, expressed as a duration string
	// for the same reason.
	waitNoResolveDefaultTimeout = "10m"
)
|
||||
|
||||
// waitFlags gathers the CLI arguments and flags of the `wait` sub-commands.
// A single value is shared by both sub-commands; each one only reads the
// fields it needs.
type waitFlags struct {
	// duration is how long `wait duration` waits before exiting.
	duration time.Duration
	// domain is the domain name `wait no-resolve` keeps resolving.
	domain string
	// period is the delay between two resolution attempts, and timeout is
	// the total time after which `wait no-resolve` gives up with an error.
	period, timeout time.Duration
}
|
||||
|
||||
func onWaitDuration(flags waitFlags) error {
|
||||
utils.InitLogger(utils.LoggingForCLI, log.DebugLevel)
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT, os.Interrupt)
|
||||
defer cancel()
|
||||
|
||||
return trace.Wrap(waitDuration(ctx, flags.duration))
|
||||
}
|
||||
|
||||
func onWaitNoResolve(flags waitFlags) error {
|
||||
utils.InitLogger(utils.LoggingForCLI, log.DebugLevel)
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT, os.Interrupt)
|
||||
defer cancel()
|
||||
|
||||
return trace.Wrap(waitNoResolve(ctx, flags.domain, flags.period, flags.timeout))
|
||||
}
|
||||
|
||||
func waitDuration(ctx context.Context, duration time.Duration) error {
|
||||
if duration == 0 {
|
||||
return trace.BadParameter("no duration provided")
|
||||
}
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, duration)
|
||||
defer cancel()
|
||||
|
||||
<-timeoutCtx.Done()
|
||||
|
||||
err := timeoutCtx.Err()
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func waitNoResolve(ctx context.Context, domain string, period, timeout time.Duration) error {
|
||||
if domain == "" {
|
||||
return trace.BadParameter("no domain provided")
|
||||
}
|
||||
|
||||
if period == 0 {
|
||||
return trace.BadParameter("no period provided")
|
||||
}
|
||||
|
||||
if timeout == 0 {
|
||||
return trace.BadParameter("no timeout provided")
|
||||
}
|
||||
|
||||
var err error
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
// We resolve the previous auth service until there's no IP returned.
|
||||
// This means all pods got rollout, and we don't risk connecting to
|
||||
// an auth pod running the previous version
|
||||
periodic := interval.New(interval.Config{
|
||||
Duration: period,
|
||||
FirstDuration: time.Millisecond,
|
||||
Jitter: retryutils.NewSeventhJitter(),
|
||||
})
|
||||
defer periodic.Stop()
|
||||
|
||||
exit := false
|
||||
for !exit {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Context has been canceled, either we reached the timeout
|
||||
// or something else happened to the parent context
|
||||
err = ctx.Err()
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return trace.LimitExceeded(
|
||||
"timeout (%s) reached, but domain '%s' is still resolving",
|
||||
timeout,
|
||||
domain,
|
||||
)
|
||||
}
|
||||
return trace.Wrap(err)
|
||||
|
||||
case <-periodic.Next():
|
||||
exit, err = checkDomainNoResolve(domain)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("no endpoints found, exiting with success code")
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkDomainNoResolve(domainName string) (exit bool, err error) {
|
||||
endpoints, err := countEndpoints(domainName)
|
||||
if err != nil {
|
||||
dnsErr, ok := err.(*net.DNSError)
|
||||
if !ok {
|
||||
log.Errorf("unexpected error when resolving domain %s : %s", domainName, err)
|
||||
return false, trace.Wrap(err)
|
||||
}
|
||||
if dnsErr.Temporary() {
|
||||
log.Warnf("temporary error when resolving domain %s : %s", domainName, err)
|
||||
return false, nil
|
||||
}
|
||||
if dnsErr.IsNotFound {
|
||||
log.Infof("domain %s not found", domainName)
|
||||
return true, nil
|
||||
}
|
||||
log.Errorf("error when resolving domain %s : %s", domainName, err)
|
||||
return false, nil
|
||||
}
|
||||
log.Infof("%d endpoints found when resolving domain %s", endpoints, domainName)
|
||||
return endpoints == 0, nil
|
||||
}
|
||||
|
||||
func countEndpoints(serviceName string) (int, error) {
|
||||
ips, err := net.LookupIP(serviceName)
|
||||
if err != nil {
|
||||
return 0, trace.Wrap(err)
|
||||
}
|
||||
return len(ips), nil
|
||||
}
|
Loading…
Reference in a new issue