From cdaaec6b0b763141476562047578844e6105ec7a Mon Sep 17 00:00:00 2001 From: Jim Myhrberg Date: Tue, 24 May 2022 17:39:12 +0100 Subject: [PATCH] feat: compare on GET requests with non-empty query string Also add a /about page so there's a non-root page that does not return 404. --- main.go | 50 ++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 36 insertions(+), 14 deletions(-) diff --git a/main.go b/main.go index 901fad4..fd9d02f 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "net/http" "os" "strings" + "time" "gopkg.in/alecthomas/kingpin.v2" ) @@ -44,13 +45,26 @@ Case-insensitive string comparison, as an API. Because ¯\_(ツ)_/¯ Example usage: curl -X POST -F "a=Foo Bar" -F "b=FOO BAR" %s://%s/ -curl -X POST "%s://%s/?a=Foo+Bar&b=FOO+BAR"`, +curl -X GET "%s://%s/?a=Foo+Bar&b=FOO+BAR" +`, name, version, scheme, r.Host, scheme, r.Host) if err != nil { log.Fatal(err) } } +func aboutHandler(w http.ResponseWriter, _ *http.Request) { + _, err := fmt.Fprintf(w, + `%s %s + +https://github.com/jimeh/casecmp +`, + name, version) + if err != nil { + log.Fatal(err) + } +} + func casecmpHandler(w http.ResponseWriter, r *http.Request) { a := r.FormValue("a") b := r.FormValue("b") @@ -59,22 +73,24 @@ func casecmpHandler(w http.ResponseWriter, r *http.Request) { if strings.EqualFold(string(a), string(b)) { resp = "1" } - _, err := fmt.Fprintf(w, resp) + _, err := fmt.Fprint(w, resp) if err != nil { log.Fatal(err) } } -func rootHandler(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/" { - http.NotFound(w, r) - return - } - - if r.Method == "GET" { +func handler(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/": + if r.Method != "GET" || r.URL.RawQuery != "" { + casecmpHandler(w, r) + return + } indexHandler(w, r) - } else { - casecmpHandler(w, r) + case "/about": + aboutHandler(w, r) + default: + http.NotFound(w, r) } } @@ -90,8 +106,6 @@ func printVersion() { } func startServer() { - http.HandleFunc("/", rootHandler) - if *portFlag == "" { envPort := os.Getenv("PORT") if envPort != "" { @@ -107,7 +121,15 @@ func startServer() { address := *bindFlag + ":" + *portFlag fmt.Printf("Listening on %s\n", address) - log.Fatal(http.ListenAndServe(address, nil)) + srv := &http.Server{ + ReadTimeout: 5 * time.Second, + WriteTimeout: 5 * time.Second, + IdleTimeout: 30 * time.Second, + Handler: http.HandlerFunc(handler), + Addr: address, + } + + log.Fatal(srv.ListenAndServe()) } func main() {