diff --git a/lib/auth/apiserver.go b/lib/auth/apiserver.go index 520583036ca..c67511e286e 100644 --- a/lib/auth/apiserver.go +++ b/lib/auth/apiserver.go @@ -236,7 +236,11 @@ func (s *APIServer) upsertServer(auth ClientI, role teleport.Role, w http.Respon } switch role { case teleport.RoleNode: - server.SetNamespace(p.ByName("namespace")) + namespace := p.ByName("namespace") + if !services.IsValidNamespace(namespace) { + return nil, trace.BadParameter("invalid namespace %q", namespace) + } + server.SetNamespace(namespace) if err := auth.UpsertNode(server); err != nil { return nil, trace.Wrap(err) } @@ -259,7 +263,11 @@ func (s *APIServer) upsertNode(auth ClientI, w http.ResponseWriter, r *http.Requ // getNodes returns registered SSH nodes func (s *APIServer) getNodes(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { - servers, err := auth.GetNodes(p.ByName("namespace")) + namespace := p.ByName("namespace") + if !services.IsValidNamespace(namespace) { + return nil, trace.BadParameter("invalid namespace %q", namespace) + } + servers, err := auth.GetNodes(namespace) if err != nil { return nil, trace.Wrap(err) } @@ -866,7 +874,11 @@ func (s *APIServer) createSession(auth ClientI, w http.ResponseWriter, r *http.R if err := httplib.ReadJSON(r, &req); err != nil { return nil, trace.Wrap(err) } - req.Session.Namespace = p.ByName("namespace") + namespace := p.ByName("namespace") + if !services.IsValidNamespace(namespace) { + return nil, trace.BadParameter("invalid namespace %q", namespace) + } + req.Session.Namespace = namespace if err := auth.CreateSession(req.Session); err != nil { return nil, trace.Wrap(err) } @@ -882,7 +894,11 @@ func (s *APIServer) updateSession(auth ClientI, w http.ResponseWriter, r *http.R if err := httplib.ReadJSON(r, &req); err != nil { return nil, trace.Wrap(err) } - req.Update.Namespace = p.ByName("namespace") + namespace := p.ByName("namespace") + if !services.IsValidNamespace(namespace) { + return nil, trace.BadParameter("invalid namespace %q", namespace) + } + req.Update.Namespace = namespace if err := auth.UpdateSession(req.Update); err != nil { return nil, trace.Wrap(err) } @@ -890,7 +906,11 @@ func (s *APIServer) updateSession(auth ClientI, w http.ResponseWriter, r *http.R } func (s *APIServer) getSessions(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { - sessions, err := auth.GetSessions(p.ByName("namespace")) + namespace := p.ByName("namespace") + if !services.IsValidNamespace(namespace) { + return nil, trace.BadParameter("invalid namespace %q", namespace) + } + sessions, err := auth.GetSessions(namespace) if err != nil { return nil, trace.Wrap(err) } @@ -902,7 +922,11 @@ func (s *APIServer) getSession(auth ClientI, w http.ResponseWriter, r *http.Requ if err != nil { return nil, trace.Wrap(err) } - se, err := auth.GetSession(p.ByName("namespace"), *sid) + namespace := p.ByName("namespace") + if !services.IsValidNamespace(namespace) { + return nil, trace.BadParameter("invalid namespace %q", namespace) + } + se, err := auth.GetSession(namespace, *sid) if err != nil { return nil, trace.Wrap(err) } @@ -1210,6 +1234,9 @@ func (s *APIServer) postSessionChunk(auth ClientI, w http.ResponseWriter, r *htt return nil, trace.Wrap(err) } namespace := p.ByName("namespace") + if !services.IsValidNamespace(namespace) { + return nil, trace.BadParameter("invalid namespace %q", namespace) + } if err = auth.PostSessionChunk(namespace, *sid, r.Body); err != nil { return nil, trace.Wrap(err) } @@ -1226,6 +1253,9 @@ func (s *APIServer) getSessionChunk(auth ClientI, w http.ResponseWriter, r *http return nil, trace.BadParameter("missing parameter id") } namespace := p.ByName("namespace") + if !services.IsValidNamespace(namespace) { + return nil, trace.BadParameter("invalid namespace %q", namespace) + } // "offset bytes" query param offsetBytes, err := strconv.Atoi(r.URL.Query().Get("offset")) @@ -1260,6 +1290,9 @@ func (s *APIServer) getSessionEvents(auth ClientI, w http.ResponseWriter, r *htt return nil, trace.Wrap(err) } namespace := p.ByName("namespace") + if !services.IsValidNamespace(namespace) { + return nil, trace.BadParameter("invalid namespace %q", namespace) + } afterN, err := strconv.Atoi(r.URL.Query().Get("after")) if err != nil { afterN = 0 @@ -1292,6 +1325,10 @@ func (s *APIServer) getNamespaces(auth ClientI, w http.ResponseWriter, r *http.R func (s *APIServer) getNamespace(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { name := p.ByName("namespace") + if !services.IsValidNamespace(name) { + return nil, trace.BadParameter("invalid namespace %q", name) + } + namespace, err := auth.GetNamespace(name) if err != nil { return nil, trace.Wrap(err) @@ -1301,6 +1338,10 @@ func (s *APIServer) getNamespace(auth ClientI, w http.ResponseWriter, r *http.Re func (s *APIServer) deleteNamespace(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { name := p.ByName("namespace") + if !services.IsValidNamespace(name) { + return nil, trace.BadParameter("invalid namespace %q", name) + } + err := auth.DeleteNamespace(name) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/services/namespace.go b/lib/services/namespace.go index 8bd66d521d0..c14d31ac048 100644 --- a/lib/services/namespace.go +++ b/lib/services/namespace.go @@ -18,6 +18,7 @@ package services import ( "fmt" + "regexp" "github.com/gravitational/teleport/lib/utils" @@ -41,6 +42,11 @@ func (n *Namespace) CheckAndSetDefaults() error { if err := n.Metadata.Check(); err != nil { return trace.Wrap(err) } + isValid := IsValidNamespace(n.Metadata.Name) + if !isValid { + return trace.BadParameter("namespace %q is invalid", n.Metadata.Name) + } + return nil } @@ -103,3 +109,9 @@ func (s SortedNamespaces) Less(i, j int) bool { func (s SortedNamespaces) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +func IsValidNamespace(s string) bool { + return validNamespace.MatchString(s) +} + +var validNamespace = regexp.MustCompile(`[A-Za-z0-9]+`) diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 78b54bcc54c..3ff1e4469a7 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -1039,6 +1039,9 @@ func (m *Handler) getSiteNodes(w http.ResponseWriter, r *http.Request, p httprou return nil, trace.Wrap(err) } namespace := p.ByName("namespace") + if !services.IsValidNamespace(namespace) { + return nil, trace.BadParameter("invalid namespace %q", namespace) + } servers, err := clt.GetNodes(namespace) if err != nil { return nil, trace.Wrap(err) @@ -1088,6 +1091,11 @@ func (m *Handler) siteNodeConnect( ctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { + namespace := p.ByName("namespace") + if !services.IsValidNamespace(namespace) { + return nil, trace.BadParameter("invalid namespace %q", namespace) + } + q := r.URL.Query() params := q.Get("params") if params == "" { @@ -1101,7 +1109,7 @@ func (m *Handler) siteNodeConnect( log.Debugf("[WEB] new terminal request for ns=%s, server=%s, login=%s", req.Namespace, req.ServerID, req.Login) - req.Namespace = p.ByName("namespace") + req.Namespace = namespace req.ProxyHostPort = m.ProxyHostPort() term, err := newTerminal(*req, ctx, site) @@ -1137,7 +1145,12 @@ func (m *Handler) siteSessionStream(w http.ResponseWriter, r *http.Request, p ht return nil, trace.Wrap(err) } - connect, err := newSessionStreamHandler(p.ByName("namespace"), + namespace := p.ByName("namespace") + if !services.IsValidNamespace(namespace) { + return nil, trace.BadParameter("invalid namespace %q", namespace) + } + + connect, err := newSessionStreamHandler(namespace, *sessionID, ctx, site, m.sessionStreamPollPeriod) if err != nil { return nil, trace.Wrap(err) @@ -1175,6 +1188,11 @@ type siteSessionGenerateResponse struct { // {"session": {"id": "session-id", "terminal_params": {"w": 100, "h": 100}, "login": "centos"}} // func (m *Handler) siteSessionGenerate(w http.ResponseWriter, r *http.Request, p httprouter.Params, ctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { + namespace := p.ByName("namespace") + if !services.IsValidNamespace(namespace) { + return nil, trace.BadParameter("invalid namespace %q", namespace) + } + var req *siteSessionGenerateReq if err := httplib.ReadJSON(r, &req); err != nil { return nil, trace.Wrap(err) @@ -1182,7 +1200,7 @@ func (m *Handler) siteSessionGenerate(w http.ResponseWriter, r *http.Request, p req.Session.ID = session.NewID() req.Session.Created = time.Now().UTC() req.Session.LastActive = time.Now().UTC() - req.Session.Namespace = p.ByName("namespace") + req.Session.Namespace = namespace log.Infof("Generated session: %#v", req.Session) return siteSessionGenerateResponse{Session: req.Session}, nil } @@ -1220,7 +1238,12 @@ func (m *Handler) siteSessionUpdate(w http.ResponseWriter, r *http.Request, p ht return nil, trace.Wrap(err) } - err = ctx.UpdateSessionTerminal(siteAPI, p.ByName("namespace"), *sessionID, req.TerminalParams) + namespace := p.ByName("namespace") + if !services.IsValidNamespace(namespace) { + return nil, trace.BadParameter("invalid namespace %q", namespace) + } + + err = ctx.UpdateSessionTerminal(siteAPI, namespace, *sessionID, req.TerminalParams) if err != nil { log.Error(err) return nil, trace.Wrap(err) @@ -1245,7 +1268,13 @@ func (m *Handler) siteSessionsGet(w http.ResponseWriter, r *http.Request, p http if err != nil { return nil, trace.Wrap(err) } - sessions, err := clt.GetSessions(p.ByName("namespace")) + + namespace := p.ByName("namespace") + if !services.IsValidNamespace(namespace) { + return nil, trace.BadParameter("invalid namespace %q", namespace) + } + + sessions, err := clt.GetSessions(namespace) if err != nil { return nil, trace.Wrap(err) } @@ -1270,7 +1299,13 @@ func (m *Handler) siteSessionGet(w http.ResponseWriter, r *http.Request, p httpr if err != nil { return nil, trace.Wrap(err) } - sess, err := clt.GetSession(p.ByName("namespace"), *sessionID) + + namespace := p.ByName("namespace") + if !services.IsValidNamespace(namespace) { + return nil, trace.BadParameter("invalid namespace %q", namespace) + } + + sess, err := clt.GetSession(namespace, *sessionID) if err != nil { return nil, trace.Wrap(err) } @@ -1399,8 +1434,13 @@ func (m *Handler) siteSessionStreamGet(w http.ResponseWriter, r *http.Request, p if max > maxStreamBytes { max = maxStreamBytes } + namespace := p.ByName("namespace") + if !services.IsValidNamespace(namespace) { + onError(trace.BadParameter("invalid namespace %q", namespace)) + return + } // call the site API to get the chunk: - bytes, err := clt.GetSessionChunk(p.ByName("namespace"), *sid, offset, max) + bytes, err := clt.GetSessionChunk(namespace, *sid, offset, max) if err != nil { onError(trace.Wrap(err)) return @@ -1452,7 +1492,11 @@ func (m *Handler) siteSessionEventsGet(w http.ResponseWriter, r *http.Request, p if err != nil { afterN = 0 } - e, err := clt.GetSessionEvents(p.ByName("namespace"), *sessionID, afterN) + namespace := p.ByName("namespace") + if !services.IsValidNamespace(namespace) { + return nil, trace.BadParameter("invalid namespace %q", namespace) + } + e, err := clt.GetSessionEvents(namespace, *sessionID, afterN) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index d1e501f0542..b95f8bc5f96 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -514,6 +514,16 @@ func (s *WebSuite) TestWebSessionsCRUD(c *C) { c.Assert(trace.IsAccessDenied(err), Equals, true) } +func (s *WebSuite) TestNamespace(c *C) { + pack := s.authPack(c) + + _, err := pack.clt.Get(pack.clt.Endpoint("webapi", "sites", s.domainName, "namespaces", "..%252fevents%3f", "nodes"), url.Values{}) + c.Assert(err, NotNil) + + _, err = pack.clt.Get(pack.clt.Endpoint("webapi", "sites", s.domainName, "namespaces", "default", "nodes"), url.Values{}) + c.Assert(err, IsNil) +} + func (s *WebSuite) TestWebSessionsRenew(c *C) { pack := s.authPack(c)