diff --git a/archive.go b/archive.go index 2d3913c..5e571c2 100644 --- a/archive.go +++ b/archive.go @@ -2,9 +2,6 @@ package midjourney import ( "context" - "encoding/json" - "fmt" - "net/http" "net/url" "time" ) @@ -13,34 +10,11 @@ func (c *Client) ArchiveDay( ctx context.Context, date time.Time, ) (jobIDs []string, err error) { - u := &url.URL{ - Path: "app/archive/day/", - RawQuery: url.Values{ - "day": []string{date.Format("2")}, - "month": []string{date.Format("1")}, - "year": []string{date.Format("2006")}, - }.Encode(), - } + err = c.Get(ctx, "app/archive/day", url.Values{ + "day": []string{date.Format("2")}, + "month": []string{date.Format("1")}, + "year": []string{date.Format("2006")}, + }, &jobIDs) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) - if err != nil { - return nil, err - } - - resp, err := c.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("%w: %s", ErrResponseStatus, resp.Status) - } - - err = json.NewDecoder(resp.Body).Decode(&jobIDs) - if err != nil { - return nil, err - } - - return jobIDs, nil + return jobIDs, err } diff --git a/client.go b/client.go index afc4f4c..5f2fc97 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,9 @@ package midjourney import ( + "bytes" + "context" + "encoding/json" "errors" "fmt" "net/http" @@ -13,9 +16,11 @@ import ( var ( Err = errors.New("midjourney") ErrNoAuthToken = fmt.Errorf("%w: no auth token", Err) + ErrInvalidAuthToken = fmt.Errorf("%w: invalid auth token", Err) ErrInvalidAPIURL = fmt.Errorf("%w: invalid API URL", Err) ErrInvalidHTTPClient = fmt.Errorf("%w: invalid HTTP client", Err) - ErrResponseStatus = fmt.Errorf("%w: response status", Err) + ErrResponse = fmt.Errorf("%w: response", Err) + ErrResponseStatus = fmt.Errorf("%w: response status", ErrResponse) DefaultAPIURL = url.URL{ Scheme: "https", @@ -35,6 +40,9 @@ func (fn optionFunc) apply(o *Client) error { return fn(o) } +// WithAuthToken returns a new Option type which sets the auth token that the +// client will use. The authToken value can be fetched from the +// "__Secure-next-auth.session-token" cookie on the midjourney.com website. func WithAuthToken(authToken string) Option { return optionFunc(func(c *Client) error { c.AuthToken = authToken @@ -120,9 +128,6 @@ func (c *Client) Set(options ...Option) error { } func (c *Client) Do(req *http.Request) (*http.Response, error) { - req.URL = c.APIURL.ResolveReference(req.URL) - c.Logger.Debug().Str("url", req.URL.String()).Msg("request") - req.Header.Set("Accept", "application/json") if c.AuthToken != "" { req.Header.Set( @@ -135,3 +140,132 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) { return c.HTTPClient.Do(req) } + +func (c *Client) Request( + ctx context.Context, + method string, + path string, + params url.Values, + body any, + result any, +) error { + u := &url.URL{Path: path} + if params != nil { + u.RawQuery = params.Encode() + } + u = c.APIURL.ResolveReference(u) + + c.Logger.Debug().Str("method", method).Str("url", u.String()).Msg("request") + + var req *http.Request + if body != nil { + b, err := json.Marshal(body) + if err != nil { + return err + } + + c.Logger.Trace().RawJSON("body", b).Msg("request") + + buf := bytes.NewBuffer(b) + req, err = http.NewRequestWithContext(ctx, method, u.String(), buf) + if err != nil { + return err + } + + req.Header.Set("Content-Type", "application/json") + } else { + var err error + req, err = http.NewRequestWithContext(ctx, method, u.String(), nil) + if err != nil { + return err + } + } + + resp, err := c.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("%w: %s", ErrResponseStatus, resp.Status) + } + + // When token is invalid, a HTTP 200 response with content type text/html is + // returned. Hence we treat non-JSON responses as an invalid auth token + // error. + contentType := resp.Header.Get("Content-Type") + if !strings.Contains(contentType, "application/json") { + return ErrInvalidAuthToken + } + + var buf bytes.Buffer + _, err = buf.ReadFrom(resp.Body) + if err != nil { + return err + } + + c.Logger.Trace().RawJSON("body", buf.Bytes()).Msg("response") + + err = json.Unmarshal(buf.Bytes(), result) + if err != nil { + respErr := &ResponseError{} + unmarshalErr := json.Unmarshal(buf.Bytes(), respErr) + if unmarshalErr != nil { + return err + } + + return respErr + } + + return nil +} + +func (c *Client) Get( + ctx context.Context, + path string, + params url.Values, + x any, +) error { + return c.Request(ctx, http.MethodGet, path, params, nil, x) +} + +func (c *Client) Put( + ctx context.Context, + path string, + params url.Values, + body any, + x any, +) error { + return c.Request(ctx, http.MethodPut, path, params, body, x) +} + +func (c *Client) Post( + ctx context.Context, + path string, + params url.Values, + body any, + x any, +) error { + return c.Request(ctx, http.MethodPost, path, params, body, x) +} + +func (c *Client) Patch( + ctx context.Context, + path string, + params url.Values, + body any, + x any, +) error { + return c.Request(ctx, http.MethodPatch, path, params, body, x) +} + +func (c *Client) Delete( + ctx context.Context, + path string, + params url.Values, + body any, + x any, +) error { + return c.Request(ctx, http.MethodDelete, path, params, body, x) +} diff --git a/recent_jobs.go b/recent_jobs.go index ff865ad..50daaa9 100644 --- a/recent_jobs.go +++ b/recent_jobs.go @@ -2,9 +2,7 @@ package midjourney import ( "context" - "encoding/json" "fmt" - "net/http" "net/url" "strconv" "time" @@ -42,7 +40,7 @@ type RecentJobsQuery struct { RefreshAPI int } -func (rjq *RecentJobsQuery) Values() url.Values { +func (rjq *RecentJobsQuery) URLValues() url.Values { v := url.Values{} if rjq.Amount != 0 { v.Set("amount", strconv.Itoa(rjq.Amount)) @@ -105,35 +103,15 @@ func (c *Client) RecentJobs( ctx context.Context, q *RecentJobsQuery, ) (*RecentJobs, error) { - u := &url.URL{ - Path: "app/recent-jobs/", - RawQuery: q.Values().Encode(), - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) - if err != nil { - return nil, err - } - now := time.Now().UTC() - resp, err := c.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("%w: %s", ErrResponseStatus, resp.Status) - } - rj := &RecentJobs{ Query: *q, Jobs: []*Job{}, Page: q.Page, } - err = json.NewDecoder(resp.Body).Decode(&rj.Jobs) + err := c.Get(ctx, "app/recent-jobs", q.URLValues(), &rj.Jobs) if err != nil { return nil, err } diff --git a/response_error.go b/response_error.go new file mode 100644 index 0000000..ede4bd4 --- /dev/null +++ b/response_error.go @@ -0,0 +1,23 @@ +package midjourney + +import ( + "fmt" +) + +type ResponseError struct { + Message string `json:"error,omitempty"` + message string +} + +func (re *ResponseError) Error() string { + if re.message != "" { + return re.message + } + re.message = fmt.Errorf("%w: %s", ErrResponse, re.Message).Error() + + return re.message +} + +func (re *ResponseError) Unwrap() error { + return ErrResponse +} diff --git a/response_error_test.go b/response_error_test.go new file mode 100644 index 0000000..85cd987 --- /dev/null +++ b/response_error_test.go @@ -0,0 +1,46 @@ +package midjourney + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestResponseError_Is(t *testing.T) { + tests := []struct { + name string + is error + want bool + }{ + { + name: "Err", + is: Err, + want: true, + }, + { + name: "ErrResponse", + is: ErrResponse, + want: true, + }, + { + name: "ErrResponse", + is: ErrResponse, + want: true, + }, + { + name: "ErrInvalidAPIURL", + is: ErrInvalidAPIURL, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + respErr := &ResponseError{Message: "foo"} + + got := errors.Is(respErr, tt.is) + + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/words.go b/words.go index 3989f92..8f9ae98 100644 --- a/words.go +++ b/words.go @@ -3,10 +3,8 @@ package midjourney import ( "context" "crypto/rand" - "encoding/json" "fmt" "math/big" - "net/http" "net/url" "strconv" ) @@ -30,19 +28,19 @@ type WordsQuery struct { RandomSeed bool } -func (rjq *WordsQuery) Values() url.Values { +func (wq *WordsQuery) URLValues() url.Values { v := url.Values{} - if rjq.Query != "" { - v.Set("query", rjq.Query) + if wq.Query != "" { + v.Set("query", wq.Query) } - if rjq.Amount != 0 { - v.Set("amount", strconv.Itoa(rjq.Amount)) + if wq.Amount != 0 { + v.Set("amount", strconv.Itoa(wq.Amount)) } - v.Set("page", strconv.Itoa(rjq.Page)) - if rjq.RandomSeed { + v.Set("page", strconv.Itoa(wq.Page)) + if wq.RandomSeed { v.Set("seed", strconv.Itoa(randInt(9999))) - } else if rjq.Seed != 0 { - v.Set("seed", strconv.Itoa(rjq.Seed)) + } else if wq.Seed != 0 { + v.Set("seed", strconv.Itoa(wq.Seed)) } return v @@ -62,28 +60,8 @@ func randInt(max int) int { } func (c *Client) Words(ctx context.Context, q *WordsQuery) ([]*Word, error) { - u := &url.URL{ - Path: "app/words/", - RawQuery: q.Values().Encode(), - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) - if err != nil { - return nil, err - } - - resp, err := c.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("%w: %s", ErrResponseStatus, resp.Status) - } - w := map[string]string{} - err = json.NewDecoder(resp.Body).Decode(&w) + err := c.Get(ctx, "app/words/", q.URLValues(), &w) if err != nil { return nil, err }