Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions internal/api/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,72 @@ func TestGraphQLError_Code(t *testing.T) {
}

}

func TestGraphQLError_Message(t *testing.T) {
for name, tc := range map[string]struct {
in string
want string
wantErr bool
}{
"invalid message": {
in: `{
"errors": [
{
"message": 42
}
],
"data": null
}`,
wantErr: true,
},
"no message": {
in: `{
"errors": [
{
"extensions": {
"code": "ErrBatchChangesUnlicensed"
}
}
],
"data": null
}`,
want: "",
},
"valid message": {
in: `{
"errors": [
{
"message": "Cannot query field \"batchChanges\" on type \"Query\"."
}
],
"data": null
}`,
want: `Cannot query field "batchChanges" on type "Query".`,
},
} {
t.Run(name, func(t *testing.T) {
var result rawResult
if err := json.Unmarshal([]byte(tc.in), &result); err != nil {
t.Fatal(err)
}
if ne := len(result.Errors); ne != 1 {
t.Fatalf("unexpected number of GraphQL errors (this test can only handle one!): %d", ne)
}

ge := &GraphQlError{result.Errors[0]}
have, err := ge.Message()
if tc.wantErr {
if err == nil {
t.Errorf("unexpected nil error")
}
} else {
if err != nil {
t.Errorf("unexpected error: %+v", err)
}
if have != tc.want {
t.Errorf("unexpected message: have=%q want=%q", have, tc.want)
}
}
})
}
}
19 changes: 19 additions & 0 deletions internal/batches/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,22 @@ func (e IgnoredRepoSet) Append(repo *graphql.Repository) {
func (e IgnoredRepoSet) HasIgnored() bool {
return len(e) > 0
}

// BatchChangesDisabledError indicates that Batch Changes is unavailable on the
// target instance, typically because the GraphQL schema does not expose the
// relevant fields when the feature is disabled.
type BatchChangesDisabledError struct {
cause error
}

func NewBatchChangesDisabledError(cause error) *BatchChangesDisabledError {
return &BatchChangesDisabledError{cause: cause}
}

func (e *BatchChangesDisabledError) Error() string {
return "Batch Changes is disabled on this Sourcegraph instance. Ask your site admin to enable Batch Changes before running 'src batch' commands."
}

func (e *BatchChangesDisabledError) Unwrap() error {
return e.cause
}
57 changes: 56 additions & 1 deletion internal/batches/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ func (svc *Service) getSourcegraphVersionAndMaxChangesetsCount(ctx context.Conte
}

ok, err := svc.client.NewQuery(getInstanceInfo).Do(ctx, &result)
if err != nil || !ok {
if err != nil {
return "", 0, translateBatchChangesDisabledError(err)
}
if !ok {
return "", 0, err
}

Expand All @@ -79,6 +82,9 @@ func (svc *Service) getSourcegraphVersionAndMaxChangesetsCount(ctx context.Conte
func (svc *Service) DetermineLicenseAndFeatureFlags(ctx context.Context, skipErrors bool) (*batches.LicenseRestrictions, *batches.FeatureFlags, error) {
version, mc, err := svc.getSourcegraphVersionAndMaxChangesetsCount(ctx)
if err != nil {
if _, ok := err.(*batches.BatchChangesDisabledError); ok {
return nil, nil, err
}
return nil, nil, errors.Wrap(err, "failed to query Sourcegraph version and license info for instance")
}

Expand All @@ -91,6 +97,55 @@ func (svc *Service) DetermineLicenseAndFeatureFlags(ctx context.Context, skipErr

}

func translateBatchChangesDisabledError(err error) error {
gqlErrs, ok := err.(api.GraphQlErrors)
if !ok || len(gqlErrs) == 0 {
return err
}

sawBatchChangesField := false

for _, gqlErr := range gqlErrs {
message, messageErr := gqlErr.Message()
if messageErr != nil {
return err
}

field, ok := parseMissingQueryField(message)
if !ok {
return err
}

switch field {
case "batchChanges":
sawBatchChangesField = true
case "maxUnlicensedChangesets":
default:
return err
}
}

if !sawBatchChangesField {
return err
}

return batches.NewBatchChangesDisabledError(err)
}

func parseMissingQueryField(message string) (string, bool) {
const (
prefix = `Cannot query field "`
suffix = `" on type "Query".`
)

if !strings.HasPrefix(message, prefix) || !strings.HasSuffix(message, suffix) {
return "", false
}

field := strings.TrimSuffix(strings.TrimPrefix(message, prefix), suffix)
return field, field != ""
}

const applyBatchChangeMutation = `
mutation ApplyBatchChange($batchSpec: ID!) {
applyBatchChange(batchSpec: $batchSpec) {
Expand Down
59 changes: 59 additions & 0 deletions internal/batches/service/service_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package service

import (
"bytes"
"context"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"strconv"
Expand All @@ -16,11 +20,66 @@ import (

batcheslib "github.com/sourcegraph/sourcegraph/lib/batches"

"github.com/sourcegraph/src-cli/internal/api"
"github.com/sourcegraph/src-cli/internal/batches"
"github.com/sourcegraph/src-cli/internal/batches/docker"
"github.com/sourcegraph/src-cli/internal/batches/graphql"
"github.com/sourcegraph/src-cli/internal/batches/mock"
)

func TestService_DetermineLicenseAndFeatureFlags_BatchChangesDisabled(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodPost, r.Method)
require.Equal(t, "/.api/graphql", r.URL.Path)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"errors": [
{"message": "Cannot query field \"maxUnlicensedChangesets\" on type \"Query\"."},
{"message": "Cannot query field \"batchChanges\" on type \"Query\"."}
],
"data": {}
}`))
}))
defer ts.Close()

endpointURL, err := url.Parse(ts.URL)
require.NoError(t, err)

var clientOutput bytes.Buffer
svc := New(&Opts{Client: api.NewClient(api.ClientOpts{EndpointURL: endpointURL, Out: &clientOutput})})

_, _, err = svc.DetermineLicenseAndFeatureFlags(context.Background(), false)
require.Error(t, err)

var disabledErr *batches.BatchChangesDisabledError
require.ErrorAs(t, err, &disabledErr)
assert.Equal(t, "Batch Changes is disabled on this Sourcegraph instance. Ask your site admin to enable Batch Changes before running 'src batch' commands.", err.Error())
}

func TestService_DetermineLicenseAndFeatureFlags_DoesNotMisclassifySchemaErrors(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"errors": [
{"message": "Cannot query field \"maxUnlicensedChangesets\" on type \"Query\"."}
],
"data": {}
}`))
}))
defer ts.Close()

endpointURL, err := url.Parse(ts.URL)
require.NoError(t, err)

var clientOutput bytes.Buffer
svc := New(&Opts{Client: api.NewClient(api.ClientOpts{EndpointURL: endpointURL, Out: &clientOutput})})

_, _, err = svc.DetermineLicenseAndFeatureFlags(context.Background(), false)
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to query Sourcegraph version and license info for instance")
assert.NotContains(t, err.Error(), "Batch Changes is disabled on this Sourcegraph instance")
}

func TestService_ValidateChangesetSpecs(t *testing.T) {
repo1 := &graphql.Repository{ID: "repo-graphql-id-1", Name: "github.com/sourcegraph/src-cli"}
repo2 := &graphql.Repository{ID: "repo-graphql-id-2", Name: "github.com/sourcegraph/sourcegraph"}
Expand Down
Loading