From e2500cdc9cf43f7968023ea24d58f974f6fe249c Mon Sep 17 00:00:00 2001 From: Christian Boitel Date: Thu, 6 Feb 2020 10:05:27 +0100 Subject: [PATCH 1/4] Use custom headers when reloading --- apisprout.go | 183 +++++++++++++++++++++++++++------------------------ 1 file changed, 96 insertions(+), 87 deletions(-) diff --git a/apisprout.go b/apisprout.go index 44fc633..506b364 100644 --- a/apisprout.go +++ b/apisprout.go @@ -597,6 +597,48 @@ var handler = func(rr *RefreshableRouter) http.Handler { }) } +// +func loadSwaggerFromUri(uri string) (data []byte, err error) { + if strings.HasPrefix(uri, "http") { + req, httpErr := http.NewRequest("GET", uri, nil) + if httpErr != nil { + err = httpErr + return + } + if customHeader := viper.GetString("header"); customHeader != "" { + header := strings.Split(customHeader, ":") + if len(header) != 2 { + err = errors.New("Header format is invalid") + } else { + req.Header.Add(strings.TrimSpace(header[0]), strings.TrimSpace(header[1])) + } + } + if err != nil { + return + } + + client := &http.Client{} + resp, httpErr := client.Do(req) + if httpErr != nil { + err = httpErr + return + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + err = fmt.Errorf("Server at %s reported %d status code", uri, resp.StatusCode) + return + } + data, err = ioutil.ReadAll(resp.Body) + if err != nil { + return + } + } else { + data, err = ioutil.ReadFile(uri) + } + + return data, err +} + // server loads an OpenAPI file and runs a mock server using the paths and // examples defined in the file. func server(cmd *cobra.Command, args []string) { @@ -611,83 +653,58 @@ func server(cmd *cobra.Command, args []string) { // Load either from an HTTP URL or from a local file depending on the passed // in value. - if strings.HasPrefix(uri, "http") { - req, err := http.NewRequest("GET", uri, nil) - if err != nil { - log.Fatal(err) - } - if customHeader := viper.GetString("header"); customHeader != "" { - header := strings.Split(customHeader, ":") - if len(header) != 2 { - log.Fatal("Header format is invalid.") - } - req.Header.Add(strings.TrimSpace(header[0]), strings.TrimSpace(header[1])) - } - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - log.Fatal(err) - } + data, err = loadSwaggerFromUri(uri) + if err != nil { + log.Fatal(err) + } - data, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - log.Fatal(err) + if viper.GetBool("watch") { + if strings.HasPrefix(uri, "http") { + log.Fatal(errors.New("Watching a URL is not supported.")) } - if viper.GetBool("watch") { - log.Fatal("Watching a URL is not supported.") - } - } else { - data, err = ioutil.ReadFile(uri) + // Set up a new filesystem watcher and reload the router every time + // the file has changed on disk. + watcher, err := fsnotify.NewWatcher() if err != nil { log.Fatal(err) } - - if viper.GetBool("watch") { - // Set up a new filesystem watcher and reload the router every time - // the file has changed on disk. - watcher, err := fsnotify.NewWatcher() - if err != nil { - log.Fatal(err) - } - defer watcher.Close() - - go func() { - // Since waiting for events or errors is blocking, we do this in a - // goroutine. It loops forever here but will exit when the process - // is finished, e.g. when you `ctrl+c` to exit. - for { - select { - case event, ok := <-watcher.Events: - if !ok { - return + defer watcher.Close() + + go func() { + // Since waiting for events or errors is blocking, we do this in a + // goroutine. It loops forever here but will exit when the process + // is finished, e.g. when you `ctrl+c` to exit. + for { + select { + case event, ok := <-watcher.Events: + if !ok { + return + } + if event.Op&fsnotify.Write == fsnotify.Write { + fmt.Printf("🌙 Reloading %s\n", uri) + data, err = ioutil.ReadFile(uri) + if err != nil { + log.Fatal(err) } - if event.Op&fsnotify.Write == fsnotify.Write { - fmt.Printf("🌙 Reloading %s\n", uri) - data, err = ioutil.ReadFile(uri) - if err != nil { - log.Fatal(err) - } - if s, r, err := load(uri, data); err == nil { - swagger = s - rr.Set(r) - } else { - log.Printf("ERROR: Unable to load OpenAPI document: %s", err) - } - } - case err, ok := <-watcher.Errors: - if !ok { - return + if s, r, err := load(uri, data); err == nil { + swagger = s + rr.Set(r) + } else { + log.Printf("ERROR: Unable to load OpenAPI document: %s", err) } - fmt.Println("error:", err) } + case err, ok := <-watcher.Errors: + if !ok { + return + } + fmt.Println("error:", err) } - }() + } + }() - watcher.Add(uri) - } + watcher.Add(uri) } swagger, router, err := load(uri, data) @@ -699,31 +716,23 @@ func server(cmd *cobra.Command, args []string) { if strings.HasPrefix(uri, "http") { http.HandleFunc("/__reload", func(w http.ResponseWriter, r *http.Request) { - resp, err := http.Get(uri) - if err != nil { - log.Printf("ERROR: %v", err) - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("error while reloading")) - return + log.Printf("🌙 Reloading %s\n", uri) + data, err = loadSwaggerFromUri(uri) + if err == nil { + if s, r, err := load(uri, data); err == nil { + swagger = s + rr.Set(r) + } } - - data, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - log.Printf("ERROR: %v", err) + if err == nil { + log.Printf("Reloaded from %s", uri) + w.WriteHeader(200) + w.Write([]byte("reloaded")) + } else { + log.Printf("ERROR: %s", err) w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("error while parsing")) - return - } - - if s, r, err := load(uri, data); err == nil { - swagger = s - rr.Set(r) + w.Write([]byte("error while reloading")) } - - w.WriteHeader(200) - w.Write([]byte("reloaded")) - log.Printf("Reloaded from %s", uri) }) } From 4cb20a260115f79f743c20c0f16038265ce56060 Mon Sep 17 00:00:00 2001 From: Christian Boitel Date: Thu, 6 Feb 2020 10:25:31 +0100 Subject: [PATCH 2/4] Align watch/__reload to report report errors --- apisprout.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apisprout.go b/apisprout.go index 506b364..436e86b 100644 --- a/apisprout.go +++ b/apisprout.go @@ -683,9 +683,9 @@ func server(cmd *cobra.Command, args []string) { } if event.Op&fsnotify.Write == fsnotify.Write { fmt.Printf("🌙 Reloading %s\n", uri) - data, err = ioutil.ReadFile(uri) + data, err = loadSwaggerFromUri(uri) if err != nil { - log.Fatal(err) + log.Printf("ERROR: %s", err) } if s, r, err := load(uri, data); err == nil { From 7f560ea4ef31c629bbb62058093b9b690b1302b4 Mon Sep 17 00:00:00 2001 From: Christian Boitel Date: Thu, 6 Feb 2020 10:46:34 +0100 Subject: [PATCH 3/4] Keep original api if reload failed in watch --- apisprout.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/apisprout.go b/apisprout.go index 436e86b..3f266d4 100644 --- a/apisprout.go +++ b/apisprout.go @@ -686,20 +686,20 @@ func server(cmd *cobra.Command, args []string) { data, err = loadSwaggerFromUri(uri) if err != nil { log.Printf("ERROR: %s", err) - } - - if s, r, err := load(uri, data); err == nil { - swagger = s - rr.Set(r) } else { - log.Printf("ERROR: Unable to load OpenAPI document: %s", err) + if s, r, err := load(uri, data); err == nil { + swagger = s + rr.Set(r) + } else { + log.Printf("ERROR: Unable to load OpenAPI document: %s", err) + } } } case err, ok := <-watcher.Errors: if !ok { return } - fmt.Println("error:", err) + log.Printf("ERROR: %s", err) } } }() From 2cb323768c6c2c5939c7f5338454cf1f468457a3 Mon Sep 17 00:00:00 2001 From: Christian Boitel Date: Thu, 6 Feb 2020 10:58:36 +0100 Subject: [PATCH 4/4] Allow __reload usage for file as well --- apisprout.go | 38 ++++++++++++++++++-------------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/apisprout.go b/apisprout.go index 3f266d4..dfcda34 100644 --- a/apisprout.go +++ b/apisprout.go @@ -714,27 +714,25 @@ func server(cmd *cobra.Command, args []string) { rr.Set(router) - if strings.HasPrefix(uri, "http") { - http.HandleFunc("/__reload", func(w http.ResponseWriter, r *http.Request) { - log.Printf("🌙 Reloading %s\n", uri) - data, err = loadSwaggerFromUri(uri) - if err == nil { - if s, r, err := load(uri, data); err == nil { - swagger = s - rr.Set(r) - } - } - if err == nil { - log.Printf("Reloaded from %s", uri) - w.WriteHeader(200) - w.Write([]byte("reloaded")) - } else { - log.Printf("ERROR: %s", err) - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("error while reloading")) + http.HandleFunc("/__reload", func(w http.ResponseWriter, r *http.Request) { + log.Printf("🌙 Reloading %s\n", uri) + data, err = loadSwaggerFromUri(uri) + if err == nil { + if s, r, err := load(uri, data); err == nil { + swagger = s + rr.Set(r) } - }) - } + } + if err == nil { + log.Printf("Reloaded from %s", uri) + w.WriteHeader(200) + w.Write([]byte("reloaded")) + } else { + log.Printf("ERROR: %s", err) + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("error while reloading")) + } + }) // Add a health check route which returns 200 http.HandleFunc("/__health", func(w http.ResponseWriter, r *http.Request) {