diff --git a/ranked_score.go b/ranked_score.go new file mode 100644 index 0000000..7113b57 --- /dev/null +++ b/ranked_score.go @@ -0,0 +1,81 @@ +package midjourney + +import ( + "encoding/json" + "strconv" + "strings" +) + +type RankedScores []RankedScore + +// URIParam returns a string representation of the RankedScores suitable for URI +// parameters. +func (rs RankedScores) URIParam() string { + vals := make([]string, 0, len(rs)) + + for _, v := range rs { + vals = append(vals, strconv.Itoa(int(v))) + } + + return strings.Join(vals, ",") +} + +func (rs RankedScores) MarshalJSON() ([]byte, error) { + return json.Marshal(rs.URIParam()) +} + +func (rs *RankedScores) UnmarshalJSON(data []byte) error { + if len(data) == 0 { + return nil + } + + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + + scores := strings.Split(s, ",") + + for _, score := range scores { + if score == "" { + continue + } + + val, err := strconv.Atoi(score) + if err != nil { + return err + } + + *rs = append(*rs, RankedScore(val)) + } + + return nil +} + +type RankedScore int + +const ( + Unranked RankedScore = 0 + Mehd RankedScore = 2 + Liked RankedScore = 4 + Loved RankedScore = 5 +) + +func (rs RankedScore) String() string { + switch rs { + case Mehd: + return "meh" + case Liked: + return "liked" + case Loved: + return "loved" + case Unranked: + return "unranked" + default: + return "unknown" + } +} + +func (rs RankedScore) URIParam() string { + return strconv.Itoa(int(rs)) +} diff --git a/ranked_score_test.go b/ranked_score_test.go new file mode 100644 index 0000000..e6c1f8f --- /dev/null +++ b/ranked_score_test.go @@ -0,0 +1,84 @@ +package midjourney + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRankedScores_MarshalJSON(t *testing.T) { + tests := []struct { + name string + rs RankedScores + want string + }{ + { + name: "empty", + rs: RankedScores{}, + want: `""`, + }, + { + name: "one score", + rs: RankedScores{Mehd}, + want: `"2"`, + }, + { + name: "multiple scores", + rs: RankedScores{Unranked, Loved}, + want: `"0,5"`, + }, + { + name: "all scores", + rs: RankedScores{Unranked, Mehd, Liked, Loved}, + want: `"0,2,4,5"`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := json.Marshal(tt.rs) + require.NoError(t, err) + + assert.Equal(t, tt.want, string(got)) + }) + } +} + +func TestRankedScores_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + json string + want RankedScores + }{ + { + name: "empty", + json: `""`, + want: nil, + }, + { + name: "one score", + json: `"2"`, + want: RankedScores{Mehd}, + }, + { + name: "multiple scores", + json: `"0,5"`, + want: RankedScores{Unranked, Loved}, + }, + { + name: "all scores", + json: `"0,2,4,5"`, + want: RankedScores{Unranked, Mehd, Liked, Loved}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got RankedScores + err := json.Unmarshal([]byte(tt.json), &got) + require.NoError(t, err) + + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/recent_jobs.go b/recent_jobs.go index 50daaa9..104a081 100644 --- a/recent_jobs.go +++ b/recent_jobs.go @@ -26,18 +26,19 @@ const ( ) type RecentJobsQuery struct { - Amount int - JobType JobType - OrderBy Order - JobStatus JobStatus - UserID string - UserIDLiked string - FromDate time.Time - Page int - Prompt string - Personal bool - Dedupe bool - RefreshAPI int + Amount int + JobType JobType + OrderBy Order + UserIDRankedScore RankedScores + JobStatus JobStatus + UserID string + UserIDLiked string + FromDate time.Time + Page int + Prompt string + Personal bool + Dedupe bool + RefreshAPI int } func (rjq *RecentJobsQuery) URLValues() url.Values { @@ -51,6 +52,9 @@ func (rjq *RecentJobsQuery) URLValues() url.Values { if rjq.OrderBy != "" { v.Set("orderBy", string(rjq.OrderBy)) } + if len(rjq.UserIDRankedScore) > 0 { + v.Set("user_id_ranked_score", rjq.UserIDRankedScore.URIParam()) + } if rjq.JobStatus != "" { v.Set("jobStatus", string(rjq.JobStatus)) }