mirror of
https://github.com/jimeh/go-midjourney.git
synced 2026-02-19 09:56:41 +00:00
feat(client): improve Client with various helper request methods
This commit is contained in:
38
archive.go
38
archive.go
@@ -2,9 +2,6 @@ package midjourney
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
@@ -13,34 +10,11 @@ func (c *Client) ArchiveDay(
|
||||
ctx context.Context,
|
||||
date time.Time,
|
||||
) (jobIDs []string, err error) {
|
||||
u := &url.URL{
|
||||
Path: "app/archive/day/",
|
||||
RawQuery: url.Values{
|
||||
"day": []string{date.Format("2")},
|
||||
"month": []string{date.Format("1")},
|
||||
"year": []string{date.Format("2006")},
|
||||
}.Encode(),
|
||||
}
|
||||
err = c.Get(ctx, "app/archive/day", url.Values{
|
||||
"day": []string{date.Format("2")},
|
||||
"month": []string{date.Format("1")},
|
||||
"year": []string{date.Format("2006")},
|
||||
}, &jobIDs)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("%w: %s", ErrResponseStatus, resp.Status)
|
||||
}
|
||||
|
||||
err = json.NewDecoder(resp.Body).Decode(&jobIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return jobIDs, nil
|
||||
return jobIDs, err
|
||||
}
|
||||
|
||||
142
client.go
142
client.go
@@ -1,6 +1,9 @@
|
||||
package midjourney
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -13,9 +16,11 @@ import (
|
||||
var (
|
||||
Err = errors.New("midjourney")
|
||||
ErrNoAuthToken = fmt.Errorf("%w: no auth token", Err)
|
||||
ErrInvalidAuthToken = fmt.Errorf("%w: invalid auth token", Err)
|
||||
ErrInvalidAPIURL = fmt.Errorf("%w: invalid API URL", Err)
|
||||
ErrInvalidHTTPClient = fmt.Errorf("%w: invalid HTTP client", Err)
|
||||
ErrResponseStatus = fmt.Errorf("%w: response status", Err)
|
||||
ErrResponse = fmt.Errorf("%w: response", Err)
|
||||
ErrResponseStatus = fmt.Errorf("%w: response status", ErrResponse)
|
||||
|
||||
DefaultAPIURL = url.URL{
|
||||
Scheme: "https",
|
||||
@@ -35,6 +40,9 @@ func (fn optionFunc) apply(o *Client) error {
|
||||
return fn(o)
|
||||
}
|
||||
|
||||
// WithAuthToken returns a new Option type which sets the auth token that the
|
||||
// client will use. The authToken value can be fetched from the
|
||||
// "__Secure-next-auth.session-token" cookie on the midjourney.com website.
|
||||
func WithAuthToken(authToken string) Option {
|
||||
return optionFunc(func(c *Client) error {
|
||||
c.AuthToken = authToken
|
||||
@@ -120,9 +128,6 @@ func (c *Client) Set(options ...Option) error {
|
||||
}
|
||||
|
||||
func (c *Client) Do(req *http.Request) (*http.Response, error) {
|
||||
req.URL = c.APIURL.ResolveReference(req.URL)
|
||||
c.Logger.Debug().Str("url", req.URL.String()).Msg("request")
|
||||
|
||||
req.Header.Set("Accept", "application/json")
|
||||
if c.AuthToken != "" {
|
||||
req.Header.Set(
|
||||
@@ -135,3 +140,132 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
|
||||
|
||||
return c.HTTPClient.Do(req)
|
||||
}
|
||||
|
||||
func (c *Client) Request(
|
||||
ctx context.Context,
|
||||
method string,
|
||||
path string,
|
||||
params url.Values,
|
||||
body any,
|
||||
result any,
|
||||
) error {
|
||||
u := &url.URL{Path: path}
|
||||
if params != nil {
|
||||
u.RawQuery = params.Encode()
|
||||
}
|
||||
u = c.APIURL.ResolveReference(u)
|
||||
|
||||
c.Logger.Debug().Str("method", method).Str("url", u.String()).Msg("request")
|
||||
|
||||
var req *http.Request
|
||||
if body != nil {
|
||||
b, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Logger.Trace().RawJSON("body", b).Msg("request")
|
||||
|
||||
buf := bytes.NewBuffer(b)
|
||||
req, err = http.NewRequestWithContext(ctx, method, u.String(), buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
} else {
|
||||
var err error
|
||||
req, err = http.NewRequestWithContext(ctx, method, u.String(), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := c.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("%w: %s", ErrResponseStatus, resp.Status)
|
||||
}
|
||||
|
||||
// When token is invalid, a HTTP 200 response with content type text/html is
|
||||
// returned. Hence we treat non-JSON responses as an invalid auth token
|
||||
// error.
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if !strings.Contains(contentType, "application/json") {
|
||||
return ErrInvalidAuthToken
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
_, err = buf.ReadFrom(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Logger.Trace().RawJSON("body", buf.Bytes()).Msg("response")
|
||||
|
||||
err = json.Unmarshal(buf.Bytes(), result)
|
||||
if err != nil {
|
||||
respErr := &ResponseError{}
|
||||
unmarshalErr := json.Unmarshal(buf.Bytes(), respErr)
|
||||
if unmarshalErr != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return respErr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) Get(
|
||||
ctx context.Context,
|
||||
path string,
|
||||
params url.Values,
|
||||
x any,
|
||||
) error {
|
||||
return c.Request(ctx, http.MethodGet, path, params, nil, x)
|
||||
}
|
||||
|
||||
func (c *Client) Put(
|
||||
ctx context.Context,
|
||||
path string,
|
||||
params url.Values,
|
||||
body any,
|
||||
x any,
|
||||
) error {
|
||||
return c.Request(ctx, http.MethodPut, path, params, body, x)
|
||||
}
|
||||
|
||||
func (c *Client) Post(
|
||||
ctx context.Context,
|
||||
path string,
|
||||
params url.Values,
|
||||
body any,
|
||||
x any,
|
||||
) error {
|
||||
return c.Request(ctx, http.MethodPost, path, params, body, x)
|
||||
}
|
||||
|
||||
func (c *Client) Patch(
|
||||
ctx context.Context,
|
||||
path string,
|
||||
params url.Values,
|
||||
body any,
|
||||
x any,
|
||||
) error {
|
||||
return c.Request(ctx, http.MethodPatch, path, params, body, x)
|
||||
}
|
||||
|
||||
func (c *Client) Delete(
|
||||
ctx context.Context,
|
||||
path string,
|
||||
params url.Values,
|
||||
body any,
|
||||
x any,
|
||||
) error {
|
||||
return c.Request(ctx, http.MethodDelete, path, params, body, x)
|
||||
}
|
||||
|
||||
@@ -2,9 +2,7 @@ package midjourney
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
@@ -42,7 +40,7 @@ type RecentJobsQuery struct {
|
||||
RefreshAPI int
|
||||
}
|
||||
|
||||
func (rjq *RecentJobsQuery) Values() url.Values {
|
||||
func (rjq *RecentJobsQuery) URLValues() url.Values {
|
||||
v := url.Values{}
|
||||
if rjq.Amount != 0 {
|
||||
v.Set("amount", strconv.Itoa(rjq.Amount))
|
||||
@@ -105,35 +103,15 @@ func (c *Client) RecentJobs(
|
||||
ctx context.Context,
|
||||
q *RecentJobsQuery,
|
||||
) (*RecentJobs, error) {
|
||||
u := &url.URL{
|
||||
Path: "app/recent-jobs/",
|
||||
RawQuery: q.Values().Encode(),
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
resp, err := c.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("%w: %s", ErrResponseStatus, resp.Status)
|
||||
}
|
||||
|
||||
rj := &RecentJobs{
|
||||
Query: *q,
|
||||
Jobs: []*Job{},
|
||||
Page: q.Page,
|
||||
}
|
||||
|
||||
err = json.NewDecoder(resp.Body).Decode(&rj.Jobs)
|
||||
err := c.Get(ctx, "app/recent-jobs", q.URLValues(), &rj.Jobs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
23
response_error.go
Normal file
23
response_error.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package midjourney
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type ResponseError struct {
|
||||
Message string `json:"error,omitempty"`
|
||||
message string
|
||||
}
|
||||
|
||||
func (re *ResponseError) Error() string {
|
||||
if re.message != "" {
|
||||
return re.message
|
||||
}
|
||||
re.message = fmt.Errorf("%w: %s", ErrResponse, re.Message).Error()
|
||||
|
||||
return re.message
|
||||
}
|
||||
|
||||
func (re *ResponseError) Unwrap() error {
|
||||
return ErrResponse
|
||||
}
|
||||
46
response_error_test.go
Normal file
46
response_error_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package midjourney
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestResponseError_Is(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
is error
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "Err",
|
||||
is: Err,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "ErrResponse",
|
||||
is: ErrResponse,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "ErrResponse",
|
||||
is: ErrResponse,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "ErrInvalidAPIURL",
|
||||
is: ErrInvalidAPIURL,
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
respErr := &ResponseError{Message: "foo"}
|
||||
|
||||
got := errors.Is(respErr, tt.is)
|
||||
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
42
words.go
42
words.go
@@ -3,10 +3,8 @@ package midjourney
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
)
|
||||
@@ -30,19 +28,19 @@ type WordsQuery struct {
|
||||
RandomSeed bool
|
||||
}
|
||||
|
||||
func (rjq *WordsQuery) Values() url.Values {
|
||||
func (wq *WordsQuery) URLValues() url.Values {
|
||||
v := url.Values{}
|
||||
if rjq.Query != "" {
|
||||
v.Set("query", rjq.Query)
|
||||
if wq.Query != "" {
|
||||
v.Set("query", wq.Query)
|
||||
}
|
||||
if rjq.Amount != 0 {
|
||||
v.Set("amount", strconv.Itoa(rjq.Amount))
|
||||
if wq.Amount != 0 {
|
||||
v.Set("amount", strconv.Itoa(wq.Amount))
|
||||
}
|
||||
v.Set("page", strconv.Itoa(rjq.Page))
|
||||
if rjq.RandomSeed {
|
||||
v.Set("page", strconv.Itoa(wq.Page))
|
||||
if wq.RandomSeed {
|
||||
v.Set("seed", strconv.Itoa(randInt(9999)))
|
||||
} else if rjq.Seed != 0 {
|
||||
v.Set("seed", strconv.Itoa(rjq.Seed))
|
||||
} else if wq.Seed != 0 {
|
||||
v.Set("seed", strconv.Itoa(wq.Seed))
|
||||
}
|
||||
|
||||
return v
|
||||
@@ -62,28 +60,8 @@ func randInt(max int) int {
|
||||
}
|
||||
|
||||
func (c *Client) Words(ctx context.Context, q *WordsQuery) ([]*Word, error) {
|
||||
u := &url.URL{
|
||||
Path: "app/words/",
|
||||
RawQuery: q.Values().Encode(),
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("%w: %s", ErrResponseStatus, resp.Status)
|
||||
}
|
||||
|
||||
w := map[string]string{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&w)
|
||||
err := c.Get(ctx, "app/words/", q.URLValues(), &w)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user