mirror of
https://github.com/gravitational/teleport
synced 2024-10-19 00:33:50 +00:00
(web) adding CSRF protection to OIDC and SAML callbacks
This commit is contained in:
parent
e58b9c23f0
commit
e86fffd28f
|
@ -68,18 +68,17 @@ func (s *AuthServer) CreateOIDCAuthRequest(req services.OIDCAuthRequest) (*servi
|
|||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
token, err := utils.CryptoRandomHex(TokenLenBytes)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
req.StateToken = token
|
||||
|
||||
oauthClient, err := oidcClient.OAuthClient()
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
stateToken, err := utils.CryptoRandomHex(TokenLenBytes)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
req.StateToken = stateToken
|
||||
// online is OIDC online scope, "select_account" forces user to always select account
|
||||
req.RedirectURL = oauthClient.AuthCodeURL(req.StateToken, "online", "select_account")
|
||||
|
||||
|
|
|
@ -40,72 +40,86 @@ const (
|
|||
// it implements "double submit cookie" approach to check against CSRF attacks
|
||||
// https://www.owasp.org/index.php/Cross-Site_Request_Forgery_%28CSRF%29_Prevention_Cheat_Sheet#Double_Submit_Cookie
|
||||
func AddCSRFProtection(w http.ResponseWriter, r *http.Request) (string, error) {
|
||||
encodedToken := ""
|
||||
token, err := extractFromCookie(r)
|
||||
token, err := ExtractTokenFromCookie(r)
|
||||
// if there was an error retrieving the token, the token doesn't exist
|
||||
if err != nil || len(token) == 0 {
|
||||
encodedToken, err = utils.CryptoRandomHex(tokenLenBytes)
|
||||
token, err = utils.CryptoRandomHex(tokenLenBytes)
|
||||
if err != nil {
|
||||
return "", trace.Wrap(err)
|
||||
}
|
||||
} else {
|
||||
encodedToken = hex.EncodeToString(token)
|
||||
}
|
||||
|
||||
save(encodedToken, w)
|
||||
return encodedToken, nil
|
||||
save(token, w)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// VerifyToken checks if the cookie value and request value match.
|
||||
func VerifyToken(w http.ResponseWriter, r *http.Request) error {
|
||||
realToken, err := extractFromCookie(r)
|
||||
// VerifyHTTPHeader checks if HTTP header value matches the cookie.
|
||||
func VerifyHTTPHeader(r *http.Request) error {
|
||||
token := r.Header.Get(HeaderName)
|
||||
if len(token) == 0 {
|
||||
return trace.BadParameter("cannot retrieve CSRF token from HTTP header %q", HeaderName)
|
||||
}
|
||||
|
||||
err := VerifyToken(token, r)
|
||||
if err != nil {
|
||||
return trace.BadParameter("cannot retrieve CSRF token from cookie", err)
|
||||
}
|
||||
|
||||
if len(realToken) != tokenLenBytes {
|
||||
return trace.BadParameter("invalid CSRF cookie token length, expected %v, got %v", tokenLenBytes, len(realToken))
|
||||
}
|
||||
|
||||
requestToken, err := extractFromRequest(r)
|
||||
if err != nil {
|
||||
return trace.BadParameter("cannot retrieve CSRF token from HTTP header", err)
|
||||
}
|
||||
|
||||
// compare the request token against the real token
|
||||
if !compareTokens(requestToken, realToken) {
|
||||
return trace.BadParameter("request and cookie CSRF tokens do not match")
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractFromCookie retrieves a CSRF token from the session cookie.
|
||||
func extractFromCookie(r *http.Request) ([]byte, error) {
|
||||
cookie, err := r.Cookie(CookieName)
|
||||
// VerifyToken validates given token based on HTTP request cookie
|
||||
func VerifyToken(token string, r *http.Request) error {
|
||||
realToken, err := ExtractTokenFromCookie(r)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
return trace.Wrap(err, "unable to extract CSRF token from cookie")
|
||||
}
|
||||
|
||||
token, err := decode(cookie.Value)
|
||||
decodedTokenA, err := decode(token)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
return trace.Wrap(err, "unable to decode CSRF token")
|
||||
}
|
||||
|
||||
return token, nil
|
||||
decodedTokenB, err := decode(realToken)
|
||||
if err != nil {
|
||||
return trace.Wrap(err, "unable to decode cookie CSRF token")
|
||||
}
|
||||
|
||||
if !compareTokens(decodedTokenA, decodedTokenB) {
|
||||
return trace.BadParameter("CSRF tokens do not match")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractFromRequest returns the issued token from HTTP header.
|
||||
func extractFromRequest(r *http.Request) ([]byte, error) {
|
||||
issued := r.Header.Get(HeaderName)
|
||||
decoded, err := decode(issued)
|
||||
// ExtractTokenFromCookie retrieves a CSRF token from the session cookie.
|
||||
func ExtractTokenFromCookie(r *http.Request) (string, error) {
|
||||
cookie, err := r.Cookie(CookieName)
|
||||
if err != nil {
|
||||
return "", trace.Wrap(err)
|
||||
}
|
||||
|
||||
return cookie.Value, nil
|
||||
}
|
||||
|
||||
// decode decodes a cookie using base64.
|
||||
func decode(token string) ([]byte, error) {
|
||||
decoded, err := hex.DecodeString(token)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
if len(decoded) != tokenLenBytes {
|
||||
return nil, trace.BadParameter("invalid CSRF token byte length, expected %v, got %v", tokenLenBytes, len(decoded))
|
||||
}
|
||||
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// compareTokens securely (constant-time) compares CSRF tokens
|
||||
func compareTokens(a, b []byte) bool {
|
||||
return subtle.ConstantTimeCompare(a, b) == 1
|
||||
}
|
||||
|
||||
// save stores encoded CSRF token in the session cookie.
|
||||
func save(encodedToken string, w http.ResponseWriter) string {
|
||||
cookie := &http.Cookie{
|
||||
|
@ -122,20 +136,3 @@ func save(encodedToken string, w http.ResponseWriter) string {
|
|||
w.Header().Add("Vary", "Cookie")
|
||||
return encodedToken
|
||||
}
|
||||
|
||||
// compare securely (constant-time) compares request token against the real token
|
||||
// from the session.
|
||||
func compareTokens(a, b []byte) bool {
|
||||
// this is required as subtle.ConstantTimeCompare does not check for equal
|
||||
// lengths in Go versions prior to 1.3.
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
|
||||
return subtle.ConstantTimeCompare(a, b) == 1
|
||||
}
|
||||
|
||||
// decode decodes a cookie using base64.
|
||||
func decode(value string) ([]byte, error) {
|
||||
return hex.DecodeString(value)
|
||||
}
|
||||
|
|
88
lib/httplib/httpheaders.go
Normal file
88
lib/httplib/httpheaders.go
Normal file
|
@ -0,0 +1,88 @@
|
|||
/*
|
||||
Copyright 2015 Gravitational, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Package httplib implements common utility functions for writing
|
||||
// classic HTTP handlers
|
||||
package httplib
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SetNoCacheHeaders tells proxies and browsers do not cache the content
|
||||
func SetNoCacheHeaders(h http.Header) {
|
||||
h.Set("Cache-Control", "no-cache, no-store, must-revalidate")
|
||||
h.Set("Pragma", "no-cache")
|
||||
h.Set("Expires", "0")
|
||||
}
|
||||
|
||||
// SetStaticFileHeaders sets security header flags for static non-html resources
|
||||
func SetStaticFileHeaders(h http.Header) {
|
||||
SetSameOriginIFrame(h)
|
||||
SetNoSniff(h)
|
||||
}
|
||||
|
||||
// SetIndexHTMLHeaders sets security header flags for main index.html page
|
||||
func SetIndexHTMLHeaders(h http.Header) {
|
||||
SetNoCacheHeaders(h)
|
||||
SetSameOriginIFrame(h)
|
||||
SetNoSniff(h)
|
||||
|
||||
// X-Frame-Options indicates that the page can only be displayed in iframe on the same origin as the page itself
|
||||
h.Set("X-Frame-Options", "SAMEORIGIN")
|
||||
|
||||
// X-XSS-Protection is a feature of Internet Explorer, Chrome and Safari that stops pages
|
||||
// from loading when they detect reflected cross-site scripting (XSS) attacks.
|
||||
h.Set("X-XSS-Protection", "1; mode=block")
|
||||
|
||||
// Once a supported browser receives this header that browser will prevent any communications from
|
||||
// being sent over HTTP to the specified domain and will instead send all communications over HTTPS.
|
||||
// It also prevents HTTPS click through prompts on browsers
|
||||
h.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
||||
|
||||
// Prevent web browsers from using content sniffing to discover a file’s MIME type
|
||||
h.Set("X-Content-Type-Options", "nosniff")
|
||||
|
||||
// Set content policy flags
|
||||
var cspValue = strings.Join([]string{
|
||||
"script-src 'self'",
|
||||
// 'unsafe-inline' needed for reactjs inline styles
|
||||
"style-src 'self' 'unsafe-inline'",
|
||||
"object-src 'none'",
|
||||
"img-src 'self' data: blob:",
|
||||
}, ";")
|
||||
|
||||
h.Set("Content-Security-Policy", cspValue)
|
||||
}
|
||||
|
||||
// SetSameOriginIFrame sets X-Frame-Options flag
|
||||
func SetSameOriginIFrame(h http.Header) {
|
||||
// X-Frame-Options indicates that the page can only be displayed in iframe on the same origin as the page itself
|
||||
h.Set("X-Frame-Options", "SAMEORIGIN")
|
||||
}
|
||||
|
||||
// SetNoSniff sets X-Content-Type-Options flag
|
||||
func SetNoSniff(h http.Header) {
|
||||
// Prevent web browsers from using content sniffing to discover a file’s MIME type
|
||||
h.Set("X-Content-Type-Options", "nosniff")
|
||||
}
|
||||
|
||||
// SetWebConfigHeaders sets headers for webConfig.js
|
||||
func SetWebConfigHeaders(h http.Header) {
|
||||
SetStaticFileHeaders(h)
|
||||
h.Set("Content-Type", "application/javascript")
|
||||
}
|
|
@ -25,11 +25,14 @@ import (
|
|||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gravitational/teleport/lib/httplib/csrf"
|
||||
|
||||
"github.com/gravitational/roundtrip"
|
||||
"github.com/gravitational/trace"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// HandlerFunc specifies HTTP handler function that returns error
|
||||
|
@ -69,6 +72,20 @@ func MakeStdHandler(fn StdHandlerFunc) http.HandlerFunc {
|
|||
}
|
||||
}
|
||||
|
||||
// WithCSRFProtection ensures that request to unauthenticated API is checked against CSRF attacks
|
||||
func WithCSRFProtection(fn HandlerFunc) httprouter.Handle {
|
||||
hanlderFn := MakeHandler(fn)
|
||||
return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
|
||||
err := csrf.VerifyHTTPHeader(r)
|
||||
if err != nil {
|
||||
log.Warningf("unable to validate CSRF token %v", err)
|
||||
trace.WriteError(w, trace.AccessDenied("access denied"))
|
||||
return
|
||||
}
|
||||
hanlderFn(w, r, p)
|
||||
}
|
||||
}
|
||||
|
||||
// ReadJSON reads HTTP json request and unmarshals it
|
||||
// into passed interface{} obj
|
||||
func ReadJSON(r *http.Request, val interface{}) error {
|
||||
|
@ -94,45 +111,6 @@ func ConvertResponse(re *roundtrip.Response, err error) (*roundtrip.Response, er
|
|||
return re, trace.ReadError(re.Code(), re.Bytes())
|
||||
}
|
||||
|
||||
// SetNoCacheHeaders tells proxies and browsers do not cache the content
|
||||
func SetNoCacheHeaders(h http.Header) {
|
||||
h.Set("Cache-Control", "no-cache, no-store, must-revalidate")
|
||||
h.Set("Pragma", "no-cache")
|
||||
h.Set("Expires", "0")
|
||||
}
|
||||
|
||||
// SetIndexHTMLHeaders sets security header flags for main index.html page
|
||||
func SetIndexHTMLHeaders(h http.Header) {
|
||||
// Disable caching
|
||||
SetNoCacheHeaders(h)
|
||||
|
||||
// X-Frame-Options indicates that the page can only be displayed in iframe on the same origin as the page itself
|
||||
h.Set("X-Frame-Options", "SAMEORIGIN")
|
||||
|
||||
// X-XSS-Protection is a feature of Internet Explorer, Chrome and Safari that stops pages
|
||||
// from loading when they detect reflected cross-site scripting (XSS) attacks.
|
||||
h.Set("X-XSS-Protection", "1; mode=block")
|
||||
|
||||
// Once a supported browser receives this header that browser will prevent any communications from
|
||||
// being sent over HTTP to the specified domain and will instead send all communications over HTTPS.
|
||||
// It also prevents HTTPS click through prompts on browsers
|
||||
h.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
||||
|
||||
// Prevent web browsers from using content sniffing to discover a file’s MIME type
|
||||
h.Set("X-Content-Type-Options", "nosniff")
|
||||
|
||||
// Set content policy flags
|
||||
var cspValue = strings.Join([]string{
|
||||
"script-src 'self'",
|
||||
// 'unsafe-inline' needed for reactjs inline styles
|
||||
"style-src 'self' 'unsafe-inline'",
|
||||
"object-src 'none'",
|
||||
"img-src 'self' data: blob:",
|
||||
}, ";")
|
||||
|
||||
h.Set("Content-Security-Policy", cspValue)
|
||||
}
|
||||
|
||||
// ParseBool will parse boolean variable from url query
|
||||
// returns value, ok, error
|
||||
func ParseBool(q url.Values, name string) (bool, bool, error) {
|
||||
|
|
|
@ -266,6 +266,9 @@ type OIDCAuthRequest struct {
|
|||
// reuqest coming from
|
||||
StateToken string `json:"state_token"`
|
||||
|
||||
// CSRFToken is associated with user web session token
|
||||
CSRFToken string `json:"csrf_token"`
|
||||
|
||||
// RedirectURL will be used by browser
|
||||
RedirectURL string `json:"redirect_url"`
|
||||
|
||||
|
@ -336,6 +339,9 @@ type SAMLAuthRequest struct {
|
|||
// CertTTL is the TTL of the certificate user wants to get
|
||||
CertTTL time.Duration `json:"cert_ttl"`
|
||||
|
||||
// CSRFToken is associated with user web session token
|
||||
CSRFToken string `json:"csrf_token"`
|
||||
|
||||
// CreateWebSession indicates if user wants to generate a web
|
||||
// session after successful authentication
|
||||
CreateWebSession bool `json:"create_web_session"`
|
||||
|
|
|
@ -151,7 +151,7 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*RewritingHandler, error) {
|
|||
h.GET("/webapi/ping/:connector", httplib.MakeHandler(h.pingWithConnector))
|
||||
|
||||
// Web sessions
|
||||
h.POST("/webapi/sessions", h.WithCSRFProtection(httplib.MakeHandler(h.createSession)))
|
||||
h.POST("/webapi/sessions", httplib.WithCSRFProtection(h.createSession))
|
||||
h.DELETE("/webapi/sessions", h.WithAuth(h.deleteSession))
|
||||
h.POST("/webapi/sessions/renew", h.WithAuth(h.renewSession))
|
||||
|
||||
|
@ -212,9 +212,8 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*RewritingHandler, error) {
|
|||
|
||||
// if Web UI is enabled, check the assets dir:
|
||||
var (
|
||||
writeSettings http.HandlerFunc
|
||||
indexPage *template.Template
|
||||
staticFS http.FileSystem
|
||||
indexPage *template.Template
|
||||
staticFS http.FileSystem
|
||||
)
|
||||
if !cfg.DisableUI {
|
||||
staticFS, err = NewStaticFileSystem(isDebugMode())
|
||||
|
@ -235,7 +234,8 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*RewritingHandler, error) {
|
|||
if err != nil {
|
||||
return nil, trace.BadParameter("failed parsing index.html template: %v", err)
|
||||
}
|
||||
writeSettings = httplib.MakeStdHandler(h.getConfigurationSettings)
|
||||
|
||||
h.Handle("GET", "/web/config.js", httplib.MakeHandler(h.getConfigurationSettings))
|
||||
}
|
||||
|
||||
routingHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -259,11 +259,9 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*RewritingHandler, error) {
|
|||
|
||||
// serve Web UI:
|
||||
if strings.HasPrefix(r.URL.Path, "/web/app") {
|
||||
httplib.SetStaticFileHeaders(w.Header())
|
||||
http.StripPrefix("/web", http.FileServer(staticFS)).ServeHTTP(w, r)
|
||||
|
||||
} else if strings.HasPrefix(r.URL.Path, "/web/config.js") {
|
||||
writeSettings.ServeHTTP(w, r)
|
||||
} else if strings.HasPrefix(r.URL.Path, "/web") {
|
||||
} else if strings.HasPrefix(r.URL.Path, "/web/") || r.URL.Path == "/web" {
|
||||
csrfToken, err := csrf.AddCSRFProtection(w, r)
|
||||
if err != nil {
|
||||
log.Errorf("failed to generate CSRF token %v", err)
|
||||
|
@ -515,7 +513,8 @@ type webConfig struct {
|
|||
}
|
||||
|
||||
// getConfigurationSettings returns configuration for the web application.
|
||||
func (h *Handler) getConfigurationSettings(w http.ResponseWriter, r *http.Request) (interface{}, error) {
|
||||
func (h *Handler) getConfigurationSettings(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) {
|
||||
httplib.SetWebConfigHeaders(w.Header())
|
||||
as, err := defaultAuthenticationSettings(h.cfg.ProxyClient)
|
||||
if err != nil {
|
||||
log.Infof("Cannot retrieve cluster auth preferences: %v", err)
|
||||
|
@ -547,8 +546,16 @@ func (h *Handler) oidcLoginWeb(w http.ResponseWriter, r *http.Request, p httprou
|
|||
if connectorID == "" {
|
||||
return nil, trace.BadParameter("missing connector_id query parameter")
|
||||
}
|
||||
|
||||
csrfToken, err := csrf.ExtractTokenFromCookie(r)
|
||||
if err != nil {
|
||||
log.Warningf("unable to extract CSRF token from cookie", err)
|
||||
return nil, trace.AccessDenied("access denied")
|
||||
}
|
||||
|
||||
response, err := h.cfg.ProxyClient.CreateOIDCAuthRequest(
|
||||
services.OIDCAuthRequest{
|
||||
CSRFToken: csrfToken,
|
||||
ConnectorID: connectorID,
|
||||
CreateWebSession: true,
|
||||
ClientRedirectURL: clientRedirectURL,
|
||||
|
@ -608,6 +615,12 @@ func (h *Handler) oidcCallback(w http.ResponseWriter, r *http.Request, p httprou
|
|||
}
|
||||
// if we created web session, set session cookie and redirect to original url
|
||||
if response.Req.CreateWebSession {
|
||||
err = csrf.VerifyToken(response.Req.CSRFToken, r)
|
||||
if err != nil {
|
||||
log.Warningf("[OIDC] unable to verify CSRF token", err)
|
||||
return nil, trace.AccessDenied("access denied")
|
||||
}
|
||||
|
||||
log.Infof("oidcCallback redirecting to web browser")
|
||||
if err := SetSession(w, response.Username, response.Session.GetName()); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
|
@ -1731,18 +1744,6 @@ func (h *Handler) WithAuth(fn ContextHandler) httprouter.Handle {
|
|||
})
|
||||
}
|
||||
|
||||
// WithCSRFProtection ensures that request to unauthenticated API is checked against CSRF attacks
|
||||
func (h *Handler) WithCSRFProtection(fn httprouter.Handle) httprouter.Handle {
|
||||
return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
|
||||
err := csrf.VerifyToken(w, r)
|
||||
if err != nil {
|
||||
trace.WriteError(w, trace.AccessDenied("failed to validate CSRF token", err))
|
||||
return
|
||||
}
|
||||
fn(w, r, p)
|
||||
}
|
||||
}
|
||||
|
||||
// AuthenticateRequest authenticates request using combination of a session cookie
|
||||
// and bearer token
|
||||
func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, checkBearerToken bool) (*SessionContext, error) {
|
||||
|
|
|
@ -474,8 +474,17 @@ func (s *WebSuite) TestSAMLSuccess(c *C) {
|
|||
c.Assert(err, IsNil)
|
||||
s.authServer.SetClock(clockwork.NewFakeClockAt(time.Date(2017, 05, 10, 18, 53, 0, 0, time.UTC)))
|
||||
clt := s.clientNoRedirects()
|
||||
re, err := clt.Get(clt.Endpoint("webapi", "saml", "sso"),
|
||||
url.Values{"redirect_url": []string{"http://localhost/after"}, "connector_id": []string{connector.GetName()}})
|
||||
|
||||
csrfToken := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992"
|
||||
|
||||
baseURL, err := url.Parse(clt.Endpoint("webapi", "saml", "sso") + `?redirect_url=http://localhost/after;connector_id=` + connector.GetName())
|
||||
c.Assert(err, IsNil)
|
||||
req, err := http.NewRequest("GET", baseURL.String(), nil)
|
||||
addCSRFCookieToReq(req, csrfToken)
|
||||
re, err := clt.Client.RoundTrip(func() (*http.Response, error) {
|
||||
return clt.Client.HTTPClient().Do(req)
|
||||
})
|
||||
|
||||
// we got a redirect
|
||||
locationURL := re.Headers().Get("Location")
|
||||
u, err := url.Parse(locationURL)
|
||||
|
@ -494,8 +503,10 @@ func (s *WebSuite) TestSAMLSuccess(c *C) {
|
|||
identity := local.NewIdentityService(s.bk)
|
||||
authRequest, err := identity.GetSAMLAuthRequest(id.Value)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
// now swap the request id to the hardcoded one in fixtures
|
||||
authRequest.ID = fixtures.SAMLOktaAuthRequestID
|
||||
authRequest.CSRFToken = csrfToken
|
||||
identity.CreateSAMLAuthRequest(*authRequest, backend.Forever)
|
||||
|
||||
// now respond with pre-recorded request to the POST url
|
||||
|
@ -513,12 +524,14 @@ func (s *WebSuite) TestSAMLSuccess(c *C) {
|
|||
// now send the response to the server to exchange it for auth session
|
||||
form := url.Values{}
|
||||
form.Add("SAMLResponse", encodedResponse)
|
||||
req, err := http.NewRequest("POST", clt.Endpoint("webapi", "saml", "acs"), strings.NewReader(form.Encode()))
|
||||
req, err = http.NewRequest("POST", clt.Endpoint("webapi", "saml", "acs"), strings.NewReader(form.Encode()))
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
addCSRFCookieToReq(req, csrfToken)
|
||||
c.Assert(err, IsNil)
|
||||
authRe, err := clt.Client.RoundTrip(func() (*http.Response, error) {
|
||||
return clt.Client.HTTPClient().Do(req)
|
||||
})
|
||||
|
||||
c.Assert(err, IsNil)
|
||||
comment := Commentf("Response: %v", string(authRe.Bytes()))
|
||||
c.Assert(authRe.Code(), Equals, http.StatusFound, comment)
|
||||
|
@ -560,19 +573,11 @@ func (s *WebSuite) authPackFromResponse(c *C, re *roundtrip.Response) *authPack
|
|||
}
|
||||
}
|
||||
|
||||
// authPack returns new authenticated package consisting
|
||||
// of created valid user, hotp token, created web session and
|
||||
// authenticated client
|
||||
func (s *WebSuite) authPack(c *C) *authPack {
|
||||
user := s.user
|
||||
pass := "abc123"
|
||||
rawSecret := "def456"
|
||||
otpSecret := base32.StdEncoding.EncodeToString([]byte(rawSecret))
|
||||
|
||||
func (s *WebSuite) createUser(c *C, user string, pass string, otpSecret string) {
|
||||
teleUser, err := services.NewUser(user)
|
||||
c.Assert(err, IsNil)
|
||||
role := services.RoleForUser(teleUser)
|
||||
role.SetLogins(services.Allow, []string{s.user})
|
||||
role.SetLogins(services.Allow, []string{user})
|
||||
err = s.roleAuth.UpsertRole(role, backend.Forever)
|
||||
c.Assert(err, IsNil)
|
||||
teleUser.AddRole(role.GetName())
|
||||
|
@ -585,6 +590,17 @@ func (s *WebSuite) authPack(c *C) *authPack {
|
|||
|
||||
err = s.roleAuth.UpsertTOTP(user, otpSecret)
|
||||
c.Assert(err, IsNil)
|
||||
}
|
||||
|
||||
// authPack returns new authenticated package consisting
|
||||
// of created valid user, hotp token, created web session and
|
||||
// authenticated client
|
||||
func (s *WebSuite) authPack(c *C) *authPack {
|
||||
user := s.user
|
||||
pass := "abc123"
|
||||
rawSecret := "def456"
|
||||
otpSecret := base32.StdEncoding.EncodeToString([]byte(rawSecret))
|
||||
s.createUser(c, user, pass, otpSecret)
|
||||
|
||||
// create a valid otp token
|
||||
validToken, err := totp.GenerateCode(otpSecret, time.Now())
|
||||
|
@ -652,12 +668,27 @@ func (s *WebSuite) TestNamespace(c *C) {
|
|||
c.Assert(err, IsNil)
|
||||
}
|
||||
|
||||
func (s *WebSuite) TestCRSF(c *C) {
|
||||
func (s *WebSuite) TestCSRF(c *C) {
|
||||
type input struct {
|
||||
reqToken string
|
||||
cookieToken string
|
||||
}
|
||||
|
||||
// create a valid user
|
||||
user := "csrfuser"
|
||||
pass := "abc123"
|
||||
otpSecret := base32.StdEncoding.EncodeToString([]byte("def456"))
|
||||
s.createUser(c, user, pass, otpSecret)
|
||||
|
||||
// create a valid login form request
|
||||
validToken, err := totp.GenerateCode(otpSecret, time.Now())
|
||||
c.Assert(err, IsNil)
|
||||
loginForm := createSessionReq{
|
||||
User: user,
|
||||
Pass: pass,
|
||||
SecondFactorToken: validToken,
|
||||
}
|
||||
|
||||
encodedToken1 := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992"
|
||||
encodedToken2 := "bf355921bbf3ef3672a03e410d4194077dfa5fe863c652521763b3e7f81e7b11"
|
||||
invalid := []input{
|
||||
|
@ -668,9 +699,14 @@ func (s *WebSuite) TestCRSF(c *C) {
|
|||
}
|
||||
|
||||
clt := s.client()
|
||||
|
||||
// valid
|
||||
_, err = s.login(clt, encodedToken1, encodedToken1, loginForm)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
// invalid
|
||||
for i := range invalid {
|
||||
_, err := s.login(clt, invalid[i].cookieToken, invalid[i].reqToken, nil)
|
||||
_, err := s.login(clt, invalid[i].cookieToken, invalid[i].reqToken, loginForm)
|
||||
c.Assert(err, NotNil)
|
||||
c.Assert(trace.IsAccessDenied(err), Equals, true)
|
||||
}
|
||||
|
@ -1380,15 +1416,18 @@ func (s *WebSuite) login(clt *client.WebClient, cookieToken string, reqToken str
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cookie := &http.Cookie{
|
||||
Name: csrf.CookieName,
|
||||
Value: cookieToken,
|
||||
}
|
||||
|
||||
req.AddCookie(cookie)
|
||||
addCSRFCookieToReq(req, cookieToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set(csrf.HeaderName, reqToken)
|
||||
return clt.HTTPClient().Do(req)
|
||||
}))
|
||||
}
|
||||
|
||||
func addCSRFCookieToReq(req *http.Request, token string) {
|
||||
cookie := &http.Cookie{
|
||||
Name: csrf.CookieName,
|
||||
Value: token,
|
||||
}
|
||||
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
|
||||
"github.com/gravitational/teleport/lib/client"
|
||||
"github.com/gravitational/teleport/lib/httplib"
|
||||
"github.com/gravitational/teleport/lib/httplib/csrf"
|
||||
"github.com/gravitational/teleport/lib/services"
|
||||
|
||||
"github.com/gravitational/form"
|
||||
|
@ -26,9 +27,17 @@ func (m *Handler) samlSSO(w http.ResponseWriter, r *http.Request, p httprouter.P
|
|||
if connectorID == "" {
|
||||
return nil, trace.BadParameter("missing connector_id query parameter")
|
||||
}
|
||||
|
||||
csrfToken, err := csrf.ExtractTokenFromCookie(r)
|
||||
if err != nil {
|
||||
log.Warningf("unable to extract CSRF token from cookie %v", err)
|
||||
return nil, trace.AccessDenied("access denied")
|
||||
}
|
||||
|
||||
response, err := m.cfg.ProxyClient.CreateSAMLAuthRequest(
|
||||
services.SAMLAuthRequest{
|
||||
ConnectorID: connectorID,
|
||||
CSRFToken: csrfToken,
|
||||
CreateWebSession: true,
|
||||
ClientRedirectURL: clientRedirectURL,
|
||||
})
|
||||
|
@ -90,9 +99,16 @@ func (m *Handler) samlACS(w http.ResponseWriter, r *http.Request, p httprouter.P
|
|||
http.Redirect(w, r, pathToError.String(), http.StatusFound)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// if we created web session, set session cookie and redirect to original url
|
||||
if response.Req.CreateWebSession {
|
||||
log.Debugf("redirecting to web browser")
|
||||
err = csrf.VerifyToken(response.Req.CSRFToken, r)
|
||||
if err != nil {
|
||||
l.Warningf("unable to verify CSRF token", err)
|
||||
return nil, trace.AccessDenied("access denied")
|
||||
}
|
||||
|
||||
if err := SetSession(w, response.Username, response.Session.GetName()); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
|
|
@ -19,8 +19,8 @@ package web
|
|||
import (
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/gravitational/teleport"
|
||||
|
||||
|
@ -53,22 +53,28 @@ func (s *StaticSuite) TestLocalFS(c *check.C) {
|
|||
c.Assert(err, check.IsNil)
|
||||
c.Assert(fs, check.NotNil)
|
||||
|
||||
checkFS(fs, c)
|
||||
f, err := fs.Open("/index.html")
|
||||
c.Assert(err, check.IsNil)
|
||||
bytes, err := ioutil.ReadAll(f)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
html := string(bytes[:])
|
||||
c.Assert(f.Close(), check.IsNil)
|
||||
c.Assert(strings.Contains(html, `<script src="/web/config.js"></script>`), check.Equals, true)
|
||||
c.Assert(strings.Contains(html, `content="{{ .XCSRF }}"`), check.Equals, true)
|
||||
}
|
||||
|
||||
func (s *StaticSuite) TestZipFS(c *check.C) {
|
||||
fs, err := readZipArchive("../../fixtures/assets.zip")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(fs, check.NotNil)
|
||||
}
|
||||
|
||||
func checkFS(fs http.FileSystem, c *check.C) {
|
||||
// test simple full read:
|
||||
f, err := fs.Open("/index.html")
|
||||
c.Assert(err, check.IsNil)
|
||||
bytes, err := ioutil.ReadAll(f)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(bytes), check.Equals, 880)
|
||||
c.Assert(len(bytes), check.Equals, 813)
|
||||
c.Assert(f.Close(), check.IsNil)
|
||||
|
||||
// seek + read
|
||||
|
@ -82,7 +88,7 @@ func checkFS(fs http.FileSystem, c *check.C) {
|
|||
|
||||
bytes, err = ioutil.ReadAll(f)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(bytes), check.Equals, 870)
|
||||
c.Assert(len(bytes), check.Equals, 803)
|
||||
|
||||
n, err = f.Seek(-50, io.SeekEnd)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
|
|
@ -44,26 +44,44 @@ type Server struct {
|
|||
Labels []Label `json:"tags"`
|
||||
}
|
||||
|
||||
// sortedLabels is a sort wrapper that sorts labels by name
|
||||
type sortedLabels []Label
|
||||
|
||||
func (s sortedLabels) Len() int {
|
||||
return len(s)
|
||||
}
|
||||
|
||||
func (s sortedLabels) Less(i, j int) bool {
|
||||
return s[i].Name < s[j].Name
|
||||
}
|
||||
|
||||
func (s sortedLabels) Swap(i, j int) {
|
||||
s[i], s[j] = s[j], s[i]
|
||||
}
|
||||
|
||||
// MakeServers creates server objects for webapp
|
||||
func MakeServers(clusterName string, servers []services.Server) []Server {
|
||||
uiServers := []Server{}
|
||||
for _, server := range servers {
|
||||
serverLabels := server.GetLabels()
|
||||
labelNames := []string{}
|
||||
for name := range serverLabels {
|
||||
labelNames = append(labelNames, name)
|
||||
}
|
||||
|
||||
// sort labels by name
|
||||
sort.Strings(labelNames)
|
||||
uiLabels := []Label{}
|
||||
for _, name := range labelNames {
|
||||
serverLabels := server.GetLabels()
|
||||
for name, value := range serverLabels {
|
||||
uiLabels = append(uiLabels, Label{
|
||||
Name: name,
|
||||
Value: serverLabels[name],
|
||||
Value: value,
|
||||
})
|
||||
}
|
||||
|
||||
serverCmdLabels := server.GetCmdLabels()
|
||||
for name, cmd := range serverCmdLabels {
|
||||
uiLabels = append(uiLabels, Label{
|
||||
Name: name,
|
||||
Value: cmd.GetResult(),
|
||||
})
|
||||
}
|
||||
|
||||
sort.Sort(sortedLabels(uiLabels))
|
||||
|
||||
uiServers = append(uiServers, Server{
|
||||
ClusterName: clusterName,
|
||||
Name: server.GetName(),
|
||||
|
|
Loading…
Reference in a new issue