diff --git a/shortener/base58_shortener.go b/shortener/base58_shortener.go index 4179486..257b538 100644 --- a/shortener/base58_shortener.go +++ b/shortener/base58_shortener.go @@ -1,9 +1,7 @@ package shortener import ( - "crypto/sha1" "errors" - "fmt" "github.com/jimeh/go-base58" "github.com/jimeh/ozu.io/storage" @@ -24,55 +22,40 @@ type Base58Shortener struct { } // Shorten a given URL. -func (s *Base58Shortener) Shorten(rawURL []byte) (uid []byte, url []byte, err error) { - url, err = NormalizeURL(rawURL) +func (s *Base58Shortener) Shorten(rawURL []byte) (*storage.Record, error) { + url, err := NormalizeURL(rawURL) if err != nil { - return nil, nil, err + return &storage.Record{}, err } - urlKey := s.makeURLKey(url) - uid, err = s.Store.Get(urlKey) - - if uid != nil && err == nil { - return uid, url, nil - } else if err != nil && err.Error() != "not found" { - return nil, nil, err + record, err := s.Store.FindByURL(url) + if err == nil { + return record, nil + } else if err != storage.ErrNotFound { + return &storage.Record{}, err } - uid, err = s.newUID() + uid, err := s.newUID() if err != nil { - return nil, nil, err + return &storage.Record{}, err } - err = s.Store.Set(urlKey, uid) + record, err = s.Store.Create(uid, url) if err != nil { - return nil, nil, err + return &storage.Record{}, err } - uidKey := s.makeUIDKey(uid) - err = s.Store.Set(uidKey, url) - if err != nil { - return nil, nil, err - } - - return uid, url, nil + return record, nil } // Lookup the URL of a given UID. -func (s *Base58Shortener) Lookup(uid []byte) ([]byte, error) { +func (s *Base58Shortener) Lookup(uid []byte) (*storage.Record, error) { _, err := base58.Decode(uid) if err != nil { - return nil, errInvalidUID + return &storage.Record{}, errInvalidUID } - uidKey := s.makeUIDKey(uid) - - url, err := s.Store.Get(uidKey) - if err != nil { - return nil, err - } - - return url, nil + return s.Store.FindByUID(uid) } func (s *Base58Shortener) newUID() ([]byte, error) { @@ -83,12 +66,3 @@ func (s *Base58Shortener) newUID() ([]byte, error) { return base58.Encode(index), nil } - -func (s *Base58Shortener) makeUIDKey(uid []byte) []byte { - return append(uidKeyPrefix, uid...) -} - -func (s *Base58Shortener) makeURLKey(rawURL []byte) []byte { - urlSHA := fmt.Sprintf("%x", sha1.Sum(rawURL)) - return append(urlKeyPrefix, urlSHA...) -} diff --git a/shortener/base58_shortener_test.go b/shortener/base58_shortener_test.go index e537012..2e1e46a 100644 --- a/shortener/base58_shortener_test.go +++ b/shortener/base58_shortener_test.go @@ -1,13 +1,12 @@ package shortener import ( - "crypto/sha1" "errors" - "fmt" "strings" "testing" "github.com/jimeh/ozu.io/shortener/mocks" + "github.com/jimeh/ozu.io/storage" "github.com/stretchr/testify/suite" ) @@ -19,52 +18,70 @@ import ( type Base58ShortenerSuite struct { suite.Suite - store *mocks.Store - shortener *Base58Shortener - errNotFound error + store *mocks.Store + shortener *Base58Shortener } func (s *Base58ShortenerSuite) SetupTest() { s.store = new(mocks.Store) s.shortener = NewBase58(s.store) - s.errNotFound = errors.New("not found") } // Tests func (s *Base58ShortenerSuite) TestShortenExisting() { - rawURL := []byte("http://google.com/") uid := []byte("ig") - urlSHA := fmt.Sprintf("%x", sha1.Sum(rawURL)) + url := []byte("https://google.com/") + record := storage.Record{UID: uid, URL: url} - s.store.On("Get", append([]byte("url:"), urlSHA...)).Return(uid, nil) + s.store.On("FindByURL", url).Return(&record, nil) - resultUID, resultURL, err := s.shortener.Shorten(rawURL) + result, err := s.shortener.Shorten(url) s.NoError(err) - s.Equal(uid, resultUID) - s.Equal(rawURL, resultURL) + s.Equal(uid, result.UID) + s.Equal(url, result.URL) s.store.AssertExpectations(s.T()) } func (s *Base58ShortenerSuite) TestShortenNew() { - rawURL := []byte("https://google.com") - url := []byte("https://google.com/") uid := []byte("ig") - urlKey := append([]byte("url:"), fmt.Sprintf("%x", sha1.Sum(url))...) + url := []byte("https://google.com/") + record := storage.Record{UID: uid, URL: url} - s.store.On("Get", urlKey).Return(nil, s.errNotFound) + s.store.On("FindByURL", url).Return(nil, storage.ErrNotFound) s.store.On("NextSequence").Return(1001, nil) - s.store.On("Set", urlKey, uid).Return(nil) - s.store.On("Set", append([]byte("uid:"), uid...), url).Return(nil) + s.store.On("Create", uid, url).Return(&record, nil) - rUID, rURL, err := s.shortener.Shorten(rawURL) + result, err := s.shortener.Shorten(url) s.NoError(err) - s.Equal(uid, rUID) - s.Equal(url, rURL) + s.Equal(uid, result.UID) + s.Equal(url, result.URL) s.store.AssertExpectations(s.T()) } +func (s *Base58ShortenerSuite) TestShortenAndNormalizeURL() { + examples := []struct { + url []byte + normalized []byte + }{ + {[]byte("google.com"), []byte("http://google.com/")}, + {[]byte("google.com/"), []byte("http://google.com/")}, + {[]byte("http://google.com"), []byte("http://google.com/")}, + } + + for _, e := range examples { + record := storage.Record{UID: []byte("ig"), URL: e.normalized} + s.store.On("FindByURL", record.URL).Return(&record, nil) + + result, err := s.shortener.Shorten(e.url) + s.NoError(err) + s.Equal(record.UID, result.UID) + s.Equal(record.URL, result.URL) + s.store.AssertExpectations(s.T()) + } +} + func (s *Base58ShortenerSuite) TestShortenInvalidURL() { examples := []struct { url string @@ -93,9 +110,9 @@ func (s *Base58ShortenerSuite) TestShortenInvalidURL() { } for _, e := range examples { - rUID, rURL, err := s.shortener.Shorten([]byte(e.url)) - s.Nil(rUID) - s.Nil(rURL) + record, err := s.shortener.Shorten([]byte(e.url)) + s.Nil(record.UID) + s.Nil(record.URL) s.EqualError(err, e.error) } } @@ -103,48 +120,50 @@ func (s *Base58ShortenerSuite) TestShortenInvalidURL() { func (s *Base58ShortenerSuite) TestShortenStoreError() { url := []byte("https://google.com/") storeErr := errors.New("leveldb: something wrong") - urlKey := append([]byte("url:"), fmt.Sprintf("%x", sha1.Sum(url))...) - s.store.On("Get", urlKey).Return(nil, storeErr) + s.store.On("FindByURL", url).Return(nil, storeErr) - rUID, rURL, err := s.shortener.Shorten(url) - s.Nil(rUID) - s.Nil(rURL) + result, err := s.shortener.Shorten(url) + s.Nil(result.UID) + s.Nil(result.URL) s.EqualError(err, storeErr.Error()) + s.store.AssertExpectations(s.T()) } func (s *Base58ShortenerSuite) TestLookupExisting() { - url := []byte("https://google.com/") uid := []byte("ig") + url := []byte("https://google.com/") + record := storage.Record{UID: uid, URL: url} - s.store.On("Get", append([]byte("uid:"), uid...)).Return(url, nil) - - rURL, err := s.shortener.Lookup(uid) + s.store.On("FindByUID", uid).Return(&record, nil) + result, err := s.shortener.Lookup(uid) s.NoError(err) - s.Equal(url, rURL) + s.Equal(uid, result.UID) + s.Equal(url, result.URL) s.store.AssertExpectations(s.T()) } func (s *Base58ShortenerSuite) TestLookupNonExistant() { uid := []byte("ig") - s.store.On("Get", append([]byte("uid:"), uid...)).Return(nil, s.errNotFound) - - rURL, err := s.shortener.Lookup(uid) + s.store.On("FindByUID", uid).Return(&storage.Record{}, storage.ErrNotFound) + result, err := s.shortener.Lookup(uid) s.EqualError(err, "not found") - s.Nil(rURL) + s.Nil(result.UID) + s.Nil(result.URL) s.store.AssertExpectations(s.T()) } func (s *Base58ShortenerSuite) TestLookupInvalid() { uid := []byte("ig\"; drop table haha") - rURL, err := s.shortener.Lookup(uid) + result, err := s.shortener.Lookup(uid) s.EqualError(err, "invalid UID") - s.Nil(rURL) + s.Nil(result.UID) + s.Nil(result.URL) s.store.AssertExpectations(s.T()) } diff --git a/shortener/mocks/Store.go b/shortener/mocks/Store.go index feb6ab6..4504e42 100644 --- a/shortener/mocks/Store.go +++ b/shortener/mocks/Store.go @@ -22,36 +22,114 @@ func (_m *Store) Close() error { return r0 } -// Delete provides a mock function with given fields: _a0 -func (_m *Store) Delete(_a0 []byte) error { - ret := _m.Called(_a0) +// Create provides a mock function with given fields: UID, URL +func (_m *Store) Create(UID []byte, URL []byte) (*storage.Record, error) { + ret := _m.Called(UID, URL) - var r0 error - if rf, ok := ret.Get(0).(func([]byte) error); ok { - r0 = rf(_a0) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// Get provides a mock function with given fields: _a0 -func (_m *Store) Get(_a0 []byte) ([]byte, error) { - ret := _m.Called(_a0) - - var r0 []byte - if rf, ok := ret.Get(0).(func([]byte) []byte); ok { - r0 = rf(_a0) + var r0 *storage.Record + if rf, ok := ret.Get(0).(func([]byte, []byte) *storage.Record); ok { + r0 = rf(UID, URL) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]byte) + r0 = ret.Get(0).(*storage.Record) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func([]byte, []byte) error); ok { + r1 = rf(UID, URL) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// DeleteByUID provides a mock function with given fields: UID +func (_m *Store) DeleteByUID(UID []byte) (*storage.Record, error) { + ret := _m.Called(UID) + + var r0 *storage.Record + if rf, ok := ret.Get(0).(func([]byte) *storage.Record); ok { + r0 = rf(UID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*storage.Record) } } var r1 error if rf, ok := ret.Get(1).(func([]byte) error); ok { - r1 = rf(_a0) + r1 = rf(UID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// DeleteByURL provides a mock function with given fields: URL +func (_m *Store) DeleteByURL(URL []byte) (*storage.Record, error) { + ret := _m.Called(URL) + + var r0 *storage.Record + if rf, ok := ret.Get(0).(func([]byte) *storage.Record); ok { + r0 = rf(URL) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*storage.Record) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func([]byte) error); ok { + r1 = rf(URL) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// FindByUID provides a mock function with given fields: UID +func (_m *Store) FindByUID(UID []byte) (*storage.Record, error) { + ret := _m.Called(UID) + + var r0 *storage.Record + if rf, ok := ret.Get(0).(func([]byte) *storage.Record); ok { + r0 = rf(UID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*storage.Record) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func([]byte) error); ok { + r1 = rf(UID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// FindByURL provides a mock function with given fields: URL +func (_m *Store) FindByURL(URL []byte) (*storage.Record, error) { + ret := _m.Called(URL) + + var r0 *storage.Record + if rf, ok := ret.Get(0).(func([]byte) *storage.Record); ok { + r0 = rf(URL) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*storage.Record) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func([]byte) error); ok { + r1 = rf(URL) } else { r1 = ret.Error(1) } @@ -80,18 +158,4 @@ func (_m *Store) NextSequence() (int, error) { return r0, r1 } -// Set provides a mock function with given fields: _a0, _a1 -func (_m *Store) Set(_a0 []byte, _a1 []byte) error { - ret := _m.Called(_a0, _a1) - - var r0 error - if rf, ok := ret.Get(0).(func([]byte, []byte) error); ok { - r0 = rf(_a0, _a1) - } else { - r0 = ret.Error(0) - } - - return r0 -} - var _ storage.Store = (*Store)(nil) diff --git a/shortener/shortener.go b/shortener/shortener.go index 4679e54..96484c0 100644 --- a/shortener/shortener.go +++ b/shortener/shortener.go @@ -1,7 +1,9 @@ package shortener +import "github.com/jimeh/ozu.io/storage" + // Shortener defines a shortener interface for shortening URLs. type Shortener interface { - Shorten([]byte) ([]byte, []byte, error) - Lookup([]byte) ([]byte, error) + Shorten([]byte) (*storage.Record, error) + Lookup([]byte) (*storage.Record, error) } diff --git a/web/api_handler.go b/web/api_handler.go index 3b6940a..9326a7b 100644 --- a/web/api_handler.go +++ b/web/api_handler.go @@ -20,24 +20,24 @@ type APIHandler struct { // Shorten shortens given URL. func (h *APIHandler) Shorten(c *routing.Context) error { - uid, url, err := h.shortener.Shorten(c.FormValue("url")) + record, err := h.shortener.Shorten(c.FormValue("url")) if err != nil { return h.respondWithError(c, err) } - r := makeResponse(c, uid, url) + r := makeResponse(c, record) return h.respond(c, &r) } // Lookup shortened UID. func (h *APIHandler) Lookup(c *routing.Context) error { uid := c.FormValue("uid") - url, err := h.shortener.Lookup(uid) + record, err := h.shortener.Lookup(uid) if err != nil { return h.respondWithError(c, err) } - r := makeResponse(c, uid, url) + r := makeResponse(c, record) return h.respond(c, &r) } diff --git a/web/handler.go b/web/handler.go index 0480f4b..95aaebd 100644 --- a/web/handler.go +++ b/web/handler.go @@ -71,12 +71,12 @@ func (h *Handler) Index(c *routing.Context) error { rawURL := c.FormValue("url") if len(rawURL) > 0 { - uid, url, err := h.shortener.Shorten(rawURL) + record, err := h.shortener.Shorten(rawURL) if err != nil { return h.respond(c, template, makeErrResponse(err)) } - r := makeResponse(c, uid, url) + r := makeResponse(c, record) return h.respond(c, template, r) } @@ -106,19 +106,19 @@ func (h *Handler) Static(c *routing.Context) error { func (h *Handler) LookupAndRedirect(c *routing.Context) error { uid := []byte(c.Param("uid")) - url, err := h.shortener.Lookup(uid) + record, err := h.shortener.Lookup(uid) if err != nil { return h.NotFound(c) } - r := makeResponse(c, uid, url) + r := makeResponse(c, record) c.Response.Header.Set("Pragma", "no-cache") c.Response.Header.Set("Expires", "Mon, 01 Jan 1990 00:00:00 GMT") c.Response.Header.Set("X-XSS-Protection", "1; mode=block") c.Response.Header.Set("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate") - c.Redirect(string(url), fasthttp.StatusMovedPermanently) + c.Redirect(string(record.URL), fasthttp.StatusMovedPermanently) c.Response.Header.Set("Connection", "close") c.Response.Header.Set("X-Content-Type-Options", "nosniff") c.Response.Header.Set("Accept-Ranges", "none") diff --git a/web/handler_helpers.go b/web/handler_helpers.go index e1a7198..3155b7f 100644 --- a/web/handler_helpers.go +++ b/web/handler_helpers.go @@ -3,14 +3,15 @@ package web import ( "net/url" + "github.com/jimeh/ozu.io/storage" "github.com/qiangxue/fasthttp-routing" ) -func makeResponse(c *routing.Context, uid []byte, url []byte) Response { +func makeResponse(c *routing.Context, r *storage.Record) Response { return Response{ - UID: string(uid), - URL: makeShortURL(c, uid), - Target: string(url), + UID: string(r.UID), + URL: makeShortURL(c, r.UID), + Target: string(r.URL), } }