mirror of
https://github.com/gravitational/teleport
synced 2024-10-21 01:34:01 +00:00
5f142d3758
* Additional safety with `X-Forwarded-Host` handling This adds `utils.GetSingleHeader` as a common way to make sure that additional headers are not being inserted into the request. We use this in the `aws`, `azure`, and `gcp` handler as part of verifying the source of the request. In addition `alpnproxy/local_proxy.go` contains a fix where an invalid `Host` header can allow an arbitrary `X-Forwarded-Host` value to pass through unchanged. * Apply PR feedback around error type and testing
142 lines
3.5 KiB
Go
142 lines
3.5 KiB
Go
/*
|
|
Copyright 2023 Gravitational, Inc.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
*/
|
|
|
|
package utils
|
|
|
|
import (
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestRenameHeaders(t *testing.T) {
|
|
header := make(http.Header)
|
|
header.Add("aaa", "a1")
|
|
header.Add("aaa", "a2")
|
|
header.Add("bbb", "b1")
|
|
header.Add("ccc", "c1")
|
|
|
|
RenameHeader(header, "aaa", "aaaa")
|
|
RenameHeader(header, "bbb", "bbbb")
|
|
RenameHeader(header, "ccc", "ccc")
|
|
require.Equal(t, http.Header{
|
|
"Aaaa": []string{"a1", "a2"},
|
|
"Bbbb": []string{"b1"},
|
|
"Ccc": []string{"c1"},
|
|
}, header)
|
|
}
|
|
|
|
func TestGetAnyHeader(t *testing.T) {
|
|
header := make(http.Header)
|
|
header.Set("aaa", "a1")
|
|
header.Set("bbb", "b1")
|
|
|
|
require.Equal(t, "", GetAnyHeader(header))
|
|
require.Equal(t, "", GetAnyHeader(header, "ccc"))
|
|
require.Equal(t, "a1", GetAnyHeader(header, "aaa"))
|
|
require.Equal(t, "a1", GetAnyHeader(header, "ccc", "aaa"))
|
|
require.Equal(t, "b1", GetAnyHeader(header, "bbb", "aaa"))
|
|
}
|
|
|
|
func TestGetSingleHeader(t *testing.T) {
|
|
t.Run("NoValue", func(t *testing.T) {
|
|
t.Parallel()
|
|
headers := make(http.Header)
|
|
|
|
result, err := GetSingleHeader(headers, "key")
|
|
require.Empty(t, result)
|
|
require.Error(t, err)
|
|
})
|
|
t.Run("SingleValue", func(t *testing.T) {
|
|
t.Parallel()
|
|
headers := make(http.Header)
|
|
key := "key"
|
|
value := "value"
|
|
headers.Set(key, value)
|
|
|
|
result, err := GetSingleHeader(headers, key)
|
|
require.NoError(t, err)
|
|
require.Equal(t, value, result)
|
|
})
|
|
t.Run("DuplicateValue", func(t *testing.T) {
|
|
t.Parallel()
|
|
headers := make(http.Header)
|
|
key := "key"
|
|
value := "value1"
|
|
headers.Add(key, value)
|
|
headers.Add(key, "value2")
|
|
|
|
result, err := GetSingleHeader(headers, key)
|
|
require.Empty(t, result)
|
|
require.Error(t, err)
|
|
})
|
|
t.Run("DuplicateCaseValue", func(t *testing.T) {
|
|
t.Parallel()
|
|
headers := make(http.Header)
|
|
key := "key"
|
|
value := "value1"
|
|
headers.Add(key, value)
|
|
headers.Add(strings.ToUpper(key), "value2")
|
|
|
|
result, err := GetSingleHeader(headers, key)
|
|
require.Empty(t, result)
|
|
require.Error(t, err)
|
|
})
|
|
}
|
|
|
|
func TestChainHTTPMiddlewares(t *testing.T) {
|
|
baseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Write([]byte("baseHandler"))
|
|
})
|
|
|
|
middleware2 := func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Write([]byte("middleware2->"))
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
middleware4 := func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Write([]byte("middleware4->"))
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
handler := ChainHTTPMiddlewares(
|
|
baseHandler,
|
|
nil,
|
|
middleware2,
|
|
NoopHTTPMiddleware,
|
|
middleware4,
|
|
)
|
|
|
|
w := httptest.NewRecorder()
|
|
r := httptest.NewRequest("", "/", nil)
|
|
handler.ServeHTTP(w, r)
|
|
|
|
resp := w.Result()
|
|
defer resp.Body.Close()
|
|
|
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
|
body, err := io.ReadAll(resp.Body)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "middleware4->middleware2->baseHandler", string(body))
|
|
}
|