feat: compare on GET requests with non-empty query string

Also add a /about page so there's a non-root page that does not return
404.
This commit is contained in:
2022-05-24 17:39:12 +01:00
parent 51717ebcd1
commit cdaaec6b0b

50
main.go
View File

@@ -7,6 +7,7 @@ import (
"net/http" "net/http"
"os" "os"
"strings" "strings"
"time"
"gopkg.in/alecthomas/kingpin.v2" "gopkg.in/alecthomas/kingpin.v2"
) )
@@ -44,13 +45,26 @@ Case-insensitive string comparison, as an API. Because ¯\_(ツ)_/¯
Example usage: Example usage:
curl -X POST -F "a=Foo Bar" -F "b=FOO BAR" %s://%s/ curl -X POST -F "a=Foo Bar" -F "b=FOO BAR" %s://%s/
curl -X POST "%s://%s/?a=Foo+Bar&b=FOO+BAR"`, curl -X GET "%s://%s/?a=Foo+Bar&b=FOO+BAR"
`,
name, version, scheme, r.Host, scheme, r.Host) name, version, scheme, r.Host, scheme, r.Host)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
} }
func aboutHandler(w http.ResponseWriter, _ *http.Request) {
_, err := fmt.Fprintf(w,
`%s %s
https://github.com/jimeh/casecmp
`,
name, version)
if err != nil {
log.Fatal(err)
}
}
func casecmpHandler(w http.ResponseWriter, r *http.Request) { func casecmpHandler(w http.ResponseWriter, r *http.Request) {
a := r.FormValue("a") a := r.FormValue("a")
b := r.FormValue("b") b := r.FormValue("b")
@@ -59,22 +73,24 @@ func casecmpHandler(w http.ResponseWriter, r *http.Request) {
if strings.EqualFold(string(a), string(b)) { if strings.EqualFold(string(a), string(b)) {
resp = "1" resp = "1"
} }
_, err := fmt.Fprintf(w, resp) _, err := fmt.Fprint(w, resp)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
} }
func rootHandler(w http.ResponseWriter, r *http.Request) { func handler(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" { switch r.URL.Path {
http.NotFound(w, r) case "/":
return if r.Method != "GET" || r.URL.RawQuery != "" {
} casecmpHandler(w, r)
return
if r.Method == "GET" { }
indexHandler(w, r) indexHandler(w, r)
} else { case "/about":
casecmpHandler(w, r) aboutHandler(w, r)
default:
http.NotFound(w, r)
} }
} }
@@ -90,8 +106,6 @@ func printVersion() {
} }
func startServer() { func startServer() {
http.HandleFunc("/", rootHandler)
if *portFlag == "" { if *portFlag == "" {
envPort := os.Getenv("PORT") envPort := os.Getenv("PORT")
if envPort != "" { if envPort != "" {
@@ -107,7 +121,15 @@ func startServer() {
address := *bindFlag + ":" + *portFlag address := *bindFlag + ":" + *portFlag
fmt.Printf("Listening on %s\n", address) fmt.Printf("Listening on %s\n", address)
log.Fatal(http.ListenAndServe(address, nil)) srv := &http.Server{
ReadTimeout: 5 * time.Second,
WriteTimeout: 5 * time.Second,
IdleTimeout: 30 * time.Second,
Handler: http.HandlerFunc(handler),
Addr: address,
}
log.Fatal(srv.ListenAndServe())
} }
func main() { func main() {