diff --git a/cmd.go b/cmd.go index 30afdd41..48ffedd0 100644 --- a/cmd.go +++ b/cmd.go @@ -245,9 +245,11 @@ func handlePrint(cmdArgs *settings.Arguments, dbExecutor db.Executor) error { quiet := cmdArgs.ExistsArg("q", "quiet") return news.PrintNewsFeed(config.Runtime.HTTPClient, dbExecutor.LastBuildTime(), config.SortMode, double, quiet) case cmdArgs.ExistsDouble("c", "complete"): - return completion.Show(dbExecutor, config.AURURL, config.Runtime.CompletionPath, config.CompletionInterval, true) + return completion.Show(config.Runtime.HTTPClient, dbExecutor, + config.AURURL, config.Runtime.CompletionPath, config.CompletionInterval, true) case cmdArgs.ExistsArg("c", "complete"): - return completion.Show(dbExecutor, config.AURURL, config.Runtime.CompletionPath, config.CompletionInterval, false) + return completion.Show(config.Runtime.HTTPClient, dbExecutor, + config.AURURL, config.Runtime.CompletionPath, config.CompletionInterval, false) case cmdArgs.ExistsArg("s", "stats"): return localStatistics(dbExecutor) } diff --git a/install.go b/install.go index 2677b385..c64c92f0 100644 --- a/install.go +++ b/install.go @@ -351,7 +351,8 @@ func install(cmdArgs *settings.Arguments, dbExecutor db.Executor, ignoreProvider } go func() { - _ = completion.Update(dbExecutor, config.AURURL, config.Runtime.CompletionPath, config.CompletionInterval, false) + _ = completion.Update(config.Runtime.HTTPClient, dbExecutor, + config.AURURL, config.Runtime.CompletionPath, config.CompletionInterval, false) }() err = downloadPkgbuildsSources(do.Aur, incompatible) diff --git a/pkg/completion/completion.go b/pkg/completion/completion.go index a8006c9e..19f7e88f 100644 --- a/pkg/completion/completion.go +++ b/pkg/completion/completion.go @@ -2,6 +2,7 @@ package completion import ( "bufio" + "context" "fmt" "io" "net/http" @@ -20,8 +21,8 @@ type PkgSynchronizer interface { } // Show provides completion info for shells -func Show(dbExecutor PkgSynchronizer, aurURL, completionPath string, interval int, force bool) error { - err := Update(dbExecutor, aurURL, completionPath, interval, force) +func Show(httpClient *http.Client, dbExecutor PkgSynchronizer, aurURL, completionPath string, interval int, force bool) error { + err := Update(httpClient, dbExecutor, aurURL, completionPath, interval, force) if err != nil { return err } @@ -37,7 +38,7 @@ func Show(dbExecutor PkgSynchronizer, aurURL, completionPath string, interval in } // Update updates completion cache to be used by Complete -func Update(dbExecutor PkgSynchronizer, aurURL, completionPath string, interval int, force bool) error { +func Update(httpClient *http.Client, dbExecutor PkgSynchronizer, aurURL, completionPath string, interval int, force bool) error { info, err := os.Stat(completionPath) if os.IsNotExist(err) || (interval != -1 && time.Since(info.ModTime()).Hours() >= float64(interval*24)) || force { @@ -50,7 +51,7 @@ func Update(dbExecutor PkgSynchronizer, aurURL, completionPath string, interval return errf } - if createAURList(aurURL, out) != nil { + if createAURList(httpClient, aurURL, out) != nil { defer os.Remove(completionPath) } @@ -64,13 +65,19 @@ func Update(dbExecutor PkgSynchronizer, aurURL, completionPath string, interval } // CreateAURList creates a new completion file -func createAURList(aurURL string, out io.Writer) error { +func createAURList(client *http.Client, aurURL string, out io.Writer) error { u, err := url.Parse(aurURL) if err != nil { return err } u.Path = path.Join(u.Path, "packages.gz") - resp, err := http.Get(u.String()) + + req, err := http.NewRequestWithContext(context.Background(), "GET", u.String(), nil) + if err != nil { + return err + } + + resp, err := client.Do(req) if err != nil { return err } diff --git a/pkg/completion/completion_test.go b/pkg/completion/completion_test.go index b3adc37e..36eb3893 100644 --- a/pkg/completion/completion_test.go +++ b/pkg/completion/completion_test.go @@ -3,6 +3,7 @@ package completion import ( "bytes" "errors" + "net/http" "testing" "github.com/stretchr/testify/assert" @@ -39,7 +40,7 @@ func Test_createAURList(t *testing.T) { Reply(200). BodyString(samplePackageResp) out := &bytes.Buffer{} - err := createAURList("https://aur.archlinux.org", out) + err := createAURList(&http.Client{}, "https://aur.archlinux.org", out) assert.NoError(t, err) gotOut := out.String() assert.Equal(t, expectPackageCompletion, gotOut) @@ -52,7 +53,7 @@ func Test_createAURListHTTPError(t *testing.T) { Get("/packages.gz"). ReplyError(errors.New("Not available")) out := &bytes.Buffer{} - err := createAURList("https://aur.archlinux.org", out) + err := createAURList(&http.Client{}, "https://aur.archlinux.org", out) assert.EqualError(t, err, "Get \"https://aur.archlinux.org/packages.gz\": Not available") } @@ -64,6 +65,6 @@ func Test_createAURListStatusError(t *testing.T) { Reply(503). BodyString(samplePackageResp) out := &bytes.Buffer{} - err := createAURList("https://aur.archlinux.org", out) + err := createAURList(&http.Client{}, "https://aur.archlinux.org", out) assert.EqualError(t, err, "invalid status code: 503") }