feat(client): improve Client with various helper request methods

This commit is contained in:
2022-12-11 20:41:32 +00:00
parent 5d89204e21
commit ae6683db15
6 changed files with 225 additions and 92 deletions

View File

@@ -2,9 +2,6 @@ package midjourney
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"time"
)
@@ -13,34 +10,11 @@ func (c *Client) ArchiveDay(
ctx context.Context,
date time.Time,
) (jobIDs []string, err error) {
u := &url.URL{
Path: "app/archive/day/",
RawQuery: url.Values{
"day": []string{date.Format("2")},
"month": []string{date.Format("1")},
"year": []string{date.Format("2006")},
}.Encode(),
}
err = c.Get(ctx, "app/archive/day", url.Values{
"day": []string{date.Format("2")},
"month": []string{date.Format("1")},
"year": []string{date.Format("2006")},
}, &jobIDs)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
if err != nil {
return nil, err
}
resp, err := c.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("%w: %s", ErrResponseStatus, resp.Status)
}
err = json.NewDecoder(resp.Body).Decode(&jobIDs)
if err != nil {
return nil, err
}
return jobIDs, nil
return jobIDs, err
}

142
client.go
View File

@@ -1,6 +1,9 @@
package midjourney
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
@@ -13,9 +16,11 @@ import (
var (
Err = errors.New("midjourney")
ErrNoAuthToken = fmt.Errorf("%w: no auth token", Err)
ErrInvalidAuthToken = fmt.Errorf("%w: invalid auth token", Err)
ErrInvalidAPIURL = fmt.Errorf("%w: invalid API URL", Err)
ErrInvalidHTTPClient = fmt.Errorf("%w: invalid HTTP client", Err)
ErrResponseStatus = fmt.Errorf("%w: response status", Err)
ErrResponse = fmt.Errorf("%w: response", Err)
ErrResponseStatus = fmt.Errorf("%w: response status", ErrResponse)
DefaultAPIURL = url.URL{
Scheme: "https",
@@ -35,6 +40,9 @@ func (fn optionFunc) apply(o *Client) error {
return fn(o)
}
// WithAuthToken returns a new Option type which sets the auth token that the
// client will use. The authToken value can be fetched from the
// "__Secure-next-auth.session-token" cookie on the midjourney.com website.
func WithAuthToken(authToken string) Option {
return optionFunc(func(c *Client) error {
c.AuthToken = authToken
@@ -120,9 +128,6 @@ func (c *Client) Set(options ...Option) error {
}
func (c *Client) Do(req *http.Request) (*http.Response, error) {
req.URL = c.APIURL.ResolveReference(req.URL)
c.Logger.Debug().Str("url", req.URL.String()).Msg("request")
req.Header.Set("Accept", "application/json")
if c.AuthToken != "" {
req.Header.Set(
@@ -135,3 +140,132 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
return c.HTTPClient.Do(req)
}
func (c *Client) Request(
ctx context.Context,
method string,
path string,
params url.Values,
body any,
result any,
) error {
u := &url.URL{Path: path}
if params != nil {
u.RawQuery = params.Encode()
}
u = c.APIURL.ResolveReference(u)
c.Logger.Debug().Str("method", method).Str("url", u.String()).Msg("request")
var req *http.Request
if body != nil {
b, err := json.Marshal(body)
if err != nil {
return err
}
c.Logger.Trace().RawJSON("body", b).Msg("request")
buf := bytes.NewBuffer(b)
req, err = http.NewRequestWithContext(ctx, method, u.String(), buf)
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
} else {
var err error
req, err = http.NewRequestWithContext(ctx, method, u.String(), nil)
if err != nil {
return err
}
}
resp, err := c.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("%w: %s", ErrResponseStatus, resp.Status)
}
// When token is invalid, a HTTP 200 response with content type text/html is
// returned. Hence we treat non-JSON responses as an invalid auth token
// error.
contentType := resp.Header.Get("Content-Type")
if !strings.Contains(contentType, "application/json") {
return ErrInvalidAuthToken
}
var buf bytes.Buffer
_, err = buf.ReadFrom(resp.Body)
if err != nil {
return err
}
c.Logger.Trace().RawJSON("body", buf.Bytes()).Msg("response")
err = json.Unmarshal(buf.Bytes(), result)
if err != nil {
respErr := &ResponseError{}
unmarshalErr := json.Unmarshal(buf.Bytes(), respErr)
if unmarshalErr != nil {
return err
}
return respErr
}
return nil
}
func (c *Client) Get(
ctx context.Context,
path string,
params url.Values,
x any,
) error {
return c.Request(ctx, http.MethodGet, path, params, nil, x)
}
func (c *Client) Put(
ctx context.Context,
path string,
params url.Values,
body any,
x any,
) error {
return c.Request(ctx, http.MethodPut, path, params, body, x)
}
func (c *Client) Post(
ctx context.Context,
path string,
params url.Values,
body any,
x any,
) error {
return c.Request(ctx, http.MethodPost, path, params, body, x)
}
func (c *Client) Patch(
ctx context.Context,
path string,
params url.Values,
body any,
x any,
) error {
return c.Request(ctx, http.MethodPatch, path, params, body, x)
}
func (c *Client) Delete(
ctx context.Context,
path string,
params url.Values,
body any,
x any,
) error {
return c.Request(ctx, http.MethodDelete, path, params, body, x)
}

View File

@@ -2,9 +2,7 @@ package midjourney
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strconv"
"time"
@@ -42,7 +40,7 @@ type RecentJobsQuery struct {
RefreshAPI int
}
func (rjq *RecentJobsQuery) Values() url.Values {
func (rjq *RecentJobsQuery) URLValues() url.Values {
v := url.Values{}
if rjq.Amount != 0 {
v.Set("amount", strconv.Itoa(rjq.Amount))
@@ -105,35 +103,15 @@ func (c *Client) RecentJobs(
ctx context.Context,
q *RecentJobsQuery,
) (*RecentJobs, error) {
u := &url.URL{
Path: "app/recent-jobs/",
RawQuery: q.Values().Encode(),
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
if err != nil {
return nil, err
}
now := time.Now().UTC()
resp, err := c.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("%w: %s", ErrResponseStatus, resp.Status)
}
rj := &RecentJobs{
Query: *q,
Jobs: []*Job{},
Page: q.Page,
}
err = json.NewDecoder(resp.Body).Decode(&rj.Jobs)
err := c.Get(ctx, "app/recent-jobs", q.URLValues(), &rj.Jobs)
if err != nil {
return nil, err
}

23
response_error.go Normal file
View File

@@ -0,0 +1,23 @@
package midjourney
import (
"fmt"
)
type ResponseError struct {
Message string `json:"error,omitempty"`
message string
}
func (re *ResponseError) Error() string {
if re.message != "" {
return re.message
}
re.message = fmt.Errorf("%w: %s", ErrResponse, re.Message).Error()
return re.message
}
func (re *ResponseError) Unwrap() error {
return ErrResponse
}

46
response_error_test.go Normal file
View File

@@ -0,0 +1,46 @@
package midjourney
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
func TestResponseError_Is(t *testing.T) {
tests := []struct {
name string
is error
want bool
}{
{
name: "Err",
is: Err,
want: true,
},
{
name: "ErrResponse",
is: ErrResponse,
want: true,
},
{
name: "ErrResponse",
is: ErrResponse,
want: true,
},
{
name: "ErrInvalidAPIURL",
is: ErrInvalidAPIURL,
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
respErr := &ResponseError{Message: "foo"}
got := errors.Is(respErr, tt.is)
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -3,10 +3,8 @@ package midjourney
import (
"context"
"crypto/rand"
"encoding/json"
"fmt"
"math/big"
"net/http"
"net/url"
"strconv"
)
@@ -30,19 +28,19 @@ type WordsQuery struct {
RandomSeed bool
}
func (rjq *WordsQuery) Values() url.Values {
func (wq *WordsQuery) URLValues() url.Values {
v := url.Values{}
if rjq.Query != "" {
v.Set("query", rjq.Query)
if wq.Query != "" {
v.Set("query", wq.Query)
}
if rjq.Amount != 0 {
v.Set("amount", strconv.Itoa(rjq.Amount))
if wq.Amount != 0 {
v.Set("amount", strconv.Itoa(wq.Amount))
}
v.Set("page", strconv.Itoa(rjq.Page))
if rjq.RandomSeed {
v.Set("page", strconv.Itoa(wq.Page))
if wq.RandomSeed {
v.Set("seed", strconv.Itoa(randInt(9999)))
} else if rjq.Seed != 0 {
v.Set("seed", strconv.Itoa(rjq.Seed))
} else if wq.Seed != 0 {
v.Set("seed", strconv.Itoa(wq.Seed))
}
return v
@@ -62,28 +60,8 @@ func randInt(max int) int {
}
func (c *Client) Words(ctx context.Context, q *WordsQuery) ([]*Word, error) {
u := &url.URL{
Path: "app/words/",
RawQuery: q.Values().Encode(),
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
if err != nil {
return nil, err
}
resp, err := c.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("%w: %s", ErrResponseStatus, resp.Status)
}
w := map[string]string{}
err = json.NewDecoder(resp.Body).Decode(&w)
err := c.Get(ctx, "app/words/", q.URLValues(), &w)
if err != nil {
return nil, err
}