From e6b9af36de56179fce3fa2919b6ec857857a9510 Mon Sep 17 00:00:00 2001 From: Jim Myhrberg Date: Sun, 11 Dec 2022 21:37:35 +0000 Subject: [PATCH] feat(client)!: simplify Client by extracting API methods to APIClient BREAKING CHANGE: All API request related moved from Client to APIClient type. --- api_client.go | 194 +++++++++++++++++++++++++++++++++++ archive.go | 2 +- client.go | 245 ++------------------------------------------- collection_data.go | 4 +- collection_jobs.go | 2 +- collections.go | 8 +- options.go | 70 +++++++++++++ recent_jobs.go | 2 +- words.go | 2 +- 9 files changed, 282 insertions(+), 247 deletions(-) create mode 100644 api_client.go create mode 100644 options.go diff --git a/api_client.go b/api_client.go new file mode 100644 index 0000000..bba9a94 --- /dev/null +++ b/api_client.go @@ -0,0 +1,194 @@ +package midjourney + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/rs/zerolog" +) + +type HTTPClient interface { + Do(req *http.Request) (*http.Response, error) +} + +type APIClient struct { + HTTPClient HTTPClient + APIURL *url.URL + AuthToken string + UserAgent string + Logger zerolog.Logger +} + +func NewAPI(options ...Option) (*APIClient, error) { + c := &APIClient{ + HTTPClient: http.DefaultClient, + APIURL: &DefaultAPIURL, + UserAgent: DefaultUserAgent, + Logger: zerolog.Nop(), + } + err := c.Set(options...) + + return c, err +} + +func (ac *APIClient) Set(options ...Option) error { + for _, opt := range options { + err := opt.apply(ac) + if err != nil { + return err + } + } + + return nil +} + +func (ac *APIClient) Do(req *http.Request) (*http.Response, error) { + req.Header.Set("Accept", "application/json") + if ac.AuthToken != "" { + req.Header.Set( + "Cookie", "__Secure-next-auth.session-token="+ac.AuthToken, + ) + } + if ac.UserAgent != "" { + req.Header.Set("User-Agent", ac.UserAgent) + } + + return ac.HTTPClient.Do(req) +} + +func (ac *APIClient) Request( + ctx context.Context, + method string, + path string, + params url.Values, + body any, + result any, +) error { + u := &url.URL{Path: path} + if params != nil { + u.RawQuery = params.Encode() + } + u = ac.APIURL.ResolveReference(u) + + ac.Logger.Debug(). + Str("method", method). + Str("url", u.String()). + Msg("request") + + var req *http.Request + if body != nil { + b, err := json.Marshal(body) + if err != nil { + return err + } + + ac.Logger.Trace().RawJSON("body", b).Msg("request") + + buf := bytes.NewBuffer(b) + req, err = http.NewRequestWithContext(ctx, method, u.String(), buf) + if err != nil { + return err + } + + req.Header.Set("Content-Type", "application/json") + } else { + var err error + req, err = http.NewRequestWithContext(ctx, method, u.String(), nil) + if err != nil { + return err + } + } + + resp, err := ac.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("%w: %s", ErrResponseStatus, resp.Status) + } + + // When token is invalid, a HTTP 200 response with content type text/html is + // returned. Hence we treat non-JSON responses as an invalid auth token + // error. + contentType := resp.Header.Get("Content-Type") + if !strings.Contains(contentType, "application/json") { + return ErrInvalidAuthToken + } + + var buf bytes.Buffer + _, err = buf.ReadFrom(resp.Body) + if err != nil { + return err + } + + ac.Logger.Trace().RawJSON("body", buf.Bytes()).Msg("response") + + err = json.Unmarshal(buf.Bytes(), result) + if err != nil { + respErr := &ResponseError{} + unmarshalErr := json.Unmarshal(buf.Bytes(), respErr) + if unmarshalErr != nil { + return err + } + + return respErr + } + + return nil +} + +func (ac *APIClient) Get( + ctx context.Context, + path string, + params url.Values, + x any, +) error { + return ac.Request(ctx, http.MethodGet, path, params, nil, x) +} + +func (ac *APIClient) Put( + ctx context.Context, + path string, + params url.Values, + body any, + x any, +) error { + return ac.Request(ctx, http.MethodPut, path, params, body, x) +} + +func (ac *APIClient) Post( + ctx context.Context, + path string, + params url.Values, + body any, + x any, +) error { + return ac.Request(ctx, http.MethodPost, path, params, body, x) +} + +func (ac *APIClient) Patch( + ctx context.Context, + path string, + params url.Values, + body any, + x any, +) error { + return ac.Request(ctx, http.MethodPatch, path, params, body, x) +} + +func (ac *APIClient) Delete( + ctx context.Context, + path string, + params url.Values, + body any, + x any, +) error { + return ac.Request(ctx, http.MethodDelete, path, params, body, x) +} diff --git a/archive.go b/archive.go index 5e571c2..17e5e41 100644 --- a/archive.go +++ b/archive.go @@ -10,7 +10,7 @@ func (c *Client) ArchiveDay( ctx context.Context, date time.Time, ) (jobIDs []string, err error) { - err = c.Get(ctx, "app/archive/day", url.Values{ + err = c.API.Get(ctx, "app/archive/day", url.Values{ "day": []string{date.Format("2")}, "month": []string{date.Format("1")}, "year": []string{date.Format("2006")}, diff --git a/client.go b/client.go index 5be6292..949f8a5 100644 --- a/client.go +++ b/client.go @@ -1,16 +1,9 @@ package midjourney import ( - "bytes" - "context" - "encoding/json" "errors" "fmt" - "net/http" "net/url" - "strings" - - "github.com/rs/zerolog" ) var ( @@ -28,245 +21,23 @@ var ( Host: "www.midjourney.com", Path: "/api/", } - DefaultUserAgent = "go-midjourney/0.0.0-dev" + + DefaultUserAgent = "go-midjourney/0.0.1" // x-release-please-version ) -type Option interface { - apply(*Client) error -} - -type optionFunc func(*Client) error - -func (fn optionFunc) apply(o *Client) error { - return fn(o) -} - -// WithAuthToken returns a new Option type which sets the auth token that the -// client will use. The authToken value can be fetched from the -// "__Secure-next-auth.session-token" cookie on the midjourney.com website. -func WithAuthToken(authToken string) Option { - return optionFunc(func(c *Client) error { - c.AuthToken = authToken - - return nil - }) -} - -func WithAPIURL(baseURL string) Option { - return optionFunc(func(c *Client) error { - if !strings.HasSuffix(baseURL, "/") { - baseURL += "/" - } - - u, err := url.Parse(baseURL) - if err != nil { - return err - } - - c.APIURL = u - - return nil - }) -} - -func WithHTTPClient(httpClient *http.Client) Option { - return optionFunc(func(c *Client) error { - c.HTTPClient = httpClient - - return nil - }) -} - -func WithUserAgent(userAgent string) Option { - return optionFunc(func(c *Client) error { - c.UserAgent = userAgent - - return nil - }) -} - -func WithLogger(logger zerolog.Logger) Option { - return optionFunc(func(c *Client) error { - c.Logger = logger - - return nil - }) -} - -type HTTPClient interface { - Do(req *http.Request) (*http.Response, error) -} - type Client struct { - HTTPClient HTTPClient - APIURL *url.URL - AuthToken string - UserAgent string - Logger zerolog.Logger + API *APIClient } func New(options ...Option) (*Client, error) { - c := &Client{ - HTTPClient: http.DefaultClient, - APIURL: &DefaultAPIURL, - UserAgent: DefaultUserAgent, - Logger: zerolog.Nop(), - } - err := c.Set(options...) - - return c, err -} - -func (c *Client) Set(options ...Option) error { - for _, opt := range options { - err := opt.apply(c) - if err != nil { - return err - } - } - - return nil -} - -func (c *Client) Do(req *http.Request) (*http.Response, error) { - req.Header.Set("Accept", "application/json") - if c.AuthToken != "" { - req.Header.Set( - "Cookie", "__Secure-next-auth.session-token="+c.AuthToken, - ) - } - if c.UserAgent != "" { - req.Header.Set("User-Agent", c.UserAgent) - } - - return c.HTTPClient.Do(req) -} - -func (c *Client) Request( - ctx context.Context, - method string, - path string, - params url.Values, - body any, - result any, -) error { - u := &url.URL{Path: path} - if params != nil { - u.RawQuery = params.Encode() - } - u = c.APIURL.ResolveReference(u) - - c.Logger.Debug().Str("method", method).Str("url", u.String()).Msg("request") - - var req *http.Request - if body != nil { - b, err := json.Marshal(body) - if err != nil { - return err - } - - c.Logger.Trace().RawJSON("body", b).Msg("request") - - buf := bytes.NewBuffer(b) - req, err = http.NewRequestWithContext(ctx, method, u.String(), buf) - if err != nil { - return err - } - - req.Header.Set("Content-Type", "application/json") - } else { - var err error - req, err = http.NewRequestWithContext(ctx, method, u.String(), nil) - if err != nil { - return err - } - } - - resp, err := c.Do(req) + api, err := NewAPI(options...) if err != nil { - return err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("%w: %s", ErrResponseStatus, resp.Status) + return nil, err } - // When token is invalid, a HTTP 200 response with content type text/html is - // returned. Hence we treat non-JSON responses as an invalid auth token - // error. - contentType := resp.Header.Get("Content-Type") - if !strings.Contains(contentType, "application/json") { - return ErrInvalidAuthToken - } - - var buf bytes.Buffer - _, err = buf.ReadFrom(resp.Body) - if err != nil { - return err - } - - c.Logger.Trace().RawJSON("body", buf.Bytes()).Msg("response") - - err = json.Unmarshal(buf.Bytes(), result) - if err != nil { - respErr := &ResponseError{} - unmarshalErr := json.Unmarshal(buf.Bytes(), respErr) - if unmarshalErr != nil { - return err - } - - return respErr - } - - return nil + return &Client{API: api}, nil } -func (c *Client) Get( - ctx context.Context, - path string, - params url.Values, - x any, -) error { - return c.Request(ctx, http.MethodGet, path, params, nil, x) -} - -func (c *Client) Put( - ctx context.Context, - path string, - params url.Values, - body any, - x any, -) error { - return c.Request(ctx, http.MethodPut, path, params, body, x) -} - -func (c *Client) Post( - ctx context.Context, - path string, - params url.Values, - body any, - x any, -) error { - return c.Request(ctx, http.MethodPost, path, params, body, x) -} - -func (c *Client) Patch( - ctx context.Context, - path string, - params url.Values, - body any, - x any, -) error { - return c.Request(ctx, http.MethodPatch, path, params, body, x) -} - -func (c *Client) Delete( - ctx context.Context, - path string, - params url.Values, - body any, - x any, -) error { - return c.Request(ctx, http.MethodDelete, path, params, body, x) +func (ac *Client) Set(options ...Option) error { + return ac.API.Set(options...) } diff --git a/collection_data.go b/collection_data.go index b507e38..8b01567 100644 --- a/collection_data.go +++ b/collection_data.go @@ -28,7 +28,7 @@ func (c *Client) PutCollectionData( } resp := &Collection{} - err := c.Put(ctx, "app/collections/", nil, req, resp) + err := c.API.Put(ctx, "app/collections/", nil, req, resp) return resp, err } @@ -50,7 +50,7 @@ func (c *Client) PutCollectionFilters( } resp := &Collection{} - err := c.Put(ctx, "app/collections/", nil, req, resp) + err := c.API.Put(ctx, "app/collections/", nil, req, resp) return resp, err } diff --git a/collection_jobs.go b/collection_jobs.go index 91ea76b..24f74ce 100644 --- a/collection_jobs.go +++ b/collection_jobs.go @@ -50,7 +50,7 @@ func (c *Client) collectionJobs( var resp *CollectionJobsResult - err := c.Request( + err := c.API.Request( ctx, method, "app/collections-jobs/", nil, &collectionJobsRequest{CollectionID: collectionID, JobIDs: jobIDs}, resp, diff --git a/collections.go b/collections.go index c3cbc59..875339b 100644 --- a/collections.go +++ b/collections.go @@ -54,7 +54,7 @@ func (c *Client) Collections( ) ([]*Collection, error) { var collections []*Collection - err := c.Get(ctx, "app/collections/", query.URLValues(), &collections) + err := c.API.Get(ctx, "app/collections/", query.URLValues(), &collections) return collections, err } @@ -73,7 +73,7 @@ func (c *Client) GetCollection( // Deletion of a collection is strangely done by setting the hidden flag to // true. This is a bit confusing, but it's how the API works. - err := c.Get(ctx, "app/collections/", q.URLValues(), &cols) + err := c.API.Get(ctx, "app/collections/", q.URLValues(), &cols) if len(cols) == 0 { return nil, fmt.Errorf("%w: id=%s", ErrCollectionNotFound, collectionID) @@ -88,7 +88,7 @@ func (c *Client) PutCollection( ) (*Collection, error) { var col *Collection - err := c.Put(ctx, "app/collections/", nil, collection, col) + err := c.API.Put(ctx, "app/collections/", nil, collection, col) return col, err } @@ -105,7 +105,7 @@ func (c *Client) DeleteCollection( // Deletion of a collection is strangely done by setting the hidden flag to // true. This is a bit confusing, but it's how the API works. - err := c.Put( + err := c.API.Put( ctx, "app/collections/", nil, &Collection{ID: collectionID, Hidden: true}, col, ) diff --git a/options.go b/options.go new file mode 100644 index 0000000..88ed014 --- /dev/null +++ b/options.go @@ -0,0 +1,70 @@ +package midjourney + +import ( + "net/url" + "strings" + + "github.com/rs/zerolog" +) + +type Option interface { + apply(*APIClient) error +} + +type optionFunc func(*APIClient) error + +func (fn optionFunc) apply(o *APIClient) error { + return fn(o) +} + +// WithAuthToken returns a new Option type which sets the auth token that the +// client will use. The authToken value can be fetched from the +// "__Secure-next-auth.session-token" cookie on the midjourney.com website. +func WithAuthToken(authToken string) Option { + return optionFunc(func(c *APIClient) error { + c.AuthToken = authToken + + return nil + }) +} + +func WithAPIURL(baseURL string) Option { + return optionFunc(func(c *APIClient) error { + if !strings.HasSuffix(baseURL, "/") { + baseURL += "/" + } + + u, err := url.Parse(baseURL) + if err != nil { + return err + } + + c.APIURL = u + + return nil + }) +} + +func WithHTTPClient(httpClient HTTPClient) Option { + return optionFunc(func(c *APIClient) error { + c.HTTPClient = httpClient + + return nil + }) +} + +func WithUserAgent(userAgent string) Option { + return optionFunc(func(c *APIClient) error { + c.UserAgent = userAgent + + return nil + }) +} + +func WithLogger(logger zerolog.Logger) Option { + return optionFunc(func(c *APIClient) error { + c.Logger = logger + + return nil + }) +} diff --git a/recent_jobs.go b/recent_jobs.go index 52e47b3..55cc1d9 100644 --- a/recent_jobs.go +++ b/recent_jobs.go @@ -119,7 +119,7 @@ func (c *Client) RecentJobs( Page: q.Page, } - err := c.Get(ctx, "app/recent-jobs", q.URLValues(), &rj.Jobs) + err := c.API.Get(ctx, "app/recent-jobs", q.URLValues(), &rj.Jobs) if err != nil { return nil, err } diff --git a/words.go b/words.go index 8f9ae98..7602ee9 100644 --- a/words.go +++ b/words.go @@ -61,7 +61,7 @@ func randInt(max int) int { func (c *Client) Words(ctx context.Context, q *WordsQuery) ([]*Word, error) { w := map[string]string{} - err := c.Get(ctx, "app/words/", q.URLValues(), &w) + err := c.API.Get(ctx, "app/words/", q.URLValues(), &w) if err != nil { return nil, err }