diff --git a/chttp/cookieauth.go b/chttp/cookieauth.go index 47b09a44..ebf290ed 100644 --- a/chttp/cookieauth.go +++ b/chttp/cookieauth.go @@ -31,7 +31,8 @@ type CookieAuth struct { client *Client // transport stores the original transport that is overridden by this auth // mechanism - transport http.RoundTripper + transport http.RoundTripper + authExpiry *time.Time } var _ Authenticator = &CookieAuth{} @@ -48,26 +49,6 @@ func (a *CookieAuth) Authenticate(c *Client) error { return nil } -// shouldAuth returns true if there is no cookie set, or if it has expired. -func (a *CookieAuth) shouldAuth(req *http.Request) bool { - if _, err := req.Cookie(kivik.SessionCookieName); err == nil { - return false - } - cookie := a.Cookie() - if cookie == nil { - return true - } - if !cookie.Expires.IsZero() { - return cookie.Expires.Before(time.Now().Add(time.Minute)) - } - // If we get here, it means the server did not include an expiry time in - // the session cookie. Some CouchDB configurations do this, but rather than - // re-authenticating for every request, we'll let the session expire. A - // future change might be to make a client-configurable option to set the - // re-authentication timeout. - return false -} - // Cookie returns the current session cookie if found, or nil if not. func (a *CookieAuth) Cookie() *http.Cookie { if a.client == nil { @@ -102,24 +83,41 @@ func (a *CookieAuth) RoundTrip(req *http.Request) (*http.Response, error) { // set to expire yesterday to allow us to ditch it cookie.Expires = time.Now().AddDate(0, 0, -1) a.client.Jar.SetCookies(a.client.dsn, []*http.Cookie{cookie}) + a.client.authMU.Lock() + a.authExpiry = nil + a.client.authMU.Unlock() } } return res, nil } +// shouldAuth returns true if there is no cookie set, or if it has expired. +func (a *CookieAuth) shouldAuth(req *http.Request) bool { + if _, err := req.Cookie(kivik.SessionCookieName); err == nil { + return false + } + if a.authExpiry == nil { + return true + } + if !a.authExpiry.IsZero() { + return a.authExpiry.Before(time.Now()) + } + // If we get here, it means the server did not include an expiry time in + // the session cookie. Some CouchDB configurations do this, but rather than + // re-authenticating for every request, we'll let the session expire. A + // future change might be to make a client-configurable option to set the + // re-authentication timeout. + return false +} + func (a *CookieAuth) authenticate(req *http.Request) error { ctx := req.Context() if inProg, _ := ctx.Value(authInProgress).(bool); inProg { return nil } - if !a.shouldAuth(req) { - return nil - } a.client.authMU.Lock() defer a.client.authMU.Unlock() - if c := a.Cookie(); c != nil { - // In case another simultaneous process authenticated successfully first - req.AddCookie(c) + if !a.shouldAuth(req) { return nil } ctx = context.WithValue(ctx, authInProgress, true) @@ -129,9 +127,28 @@ func (a *CookieAuth) authenticate(req *http.Request) error { HeaderIdempotencyKey: []string{}, }, } - if _, err := a.client.DoError(ctx, http.MethodPost, "/_session", opts); err != nil { + res, err := a.client.DoError(ctx, http.MethodPost, "/_session", opts) + if err != nil { return err } + for _, cookie := range res.Cookies() { + if cookie.Name == kivik.SessionCookieName { + expiry := cookie.Expires + if !expiry.IsZero() { + expiry = expiry.Add(-time.Minute) + } + a.authExpiry = &expiry + break + } + } + + cookies := req.Cookies() + req.Header.Del("Cookie") + for _, cookie := range cookies { + if cookie.Name != kivik.SessionCookieName { + req.AddCookie(cookie) + } + } if c := a.Cookie(); c != nil { req.AddCookie(c) } diff --git a/chttp/cookieauth_test.go b/chttp/cookieauth_test.go index c9c99d68..f4181bd4 100644 --- a/chttp/cookieauth_test.go +++ b/chttp/cookieauth_test.go @@ -20,7 +20,6 @@ import ( "net/url" "strings" "testing" - "time" "gitlab.com/flimzy/testy" "golang.org/x/net/publicsuffix" @@ -177,92 +176,6 @@ func (j *dummyJar) SetCookies(_ *url.URL, cookies []*http.Cookie) { *j = cookies } -func Test_shouldAuth(t *testing.T) { - type tt struct { - a *CookieAuth - req *http.Request - want bool - } - - tests := testy.NewTable() - tests.Add("no session", tt{ - a: &CookieAuth{}, - req: httptest.NewRequest("GET", "/", nil), - want: true, - }) - tests.Add("authed request", func() interface{} { - req := httptest.NewRequest("GET", "/", nil) - req.AddCookie(&http.Cookie{Name: kivik.SessionCookieName}) - return tt{ - a: &CookieAuth{}, - req: req, - want: false, - } - }) - tests.Add("valid session", func() interface{} { - c, _ := New("http://example.com/") - c.Jar = &dummyJar{&http.Cookie{ - Name: kivik.SessionCookieName, - Expires: time.Now().Add(20 * time.Minute), - }} - a := &CookieAuth{client: c} - - return tt{ - a: a, - req: httptest.NewRequest("GET", "/", nil), - want: false, - } - }) - tests.Add("expired session", func() interface{} { - c, _ := New("http://example.com/") - c.Jar = &dummyJar{&http.Cookie{ - Name: kivik.SessionCookieName, - Expires: time.Now().Add(-20 * time.Second), - }} - a := &CookieAuth{client: c} - - return tt{ - a: a, - req: httptest.NewRequest("GET", "/", nil), - want: true, - } - }) - tests.Add("no expiry time", func() interface{} { - c, _ := New("http://example.com/") - c.Jar = &dummyJar{&http.Cookie{ - Name: kivik.SessionCookieName, - }} - a := &CookieAuth{client: c} - - return tt{ - a: a, - req: httptest.NewRequest("GET", "/", nil), - want: false, - } - }) - tests.Add("about to expire", func() interface{} { - c, _ := New("http://example.com/") - c.Jar = &dummyJar{&http.Cookie{ - Name: kivik.SessionCookieName, - Expires: time.Now().Add(20 * time.Second), - }} - a := &CookieAuth{client: c} - - return tt{ - a: a, - req: httptest.NewRequest("GET", "/", nil), - want: true, - } - }) - - tests.Run(t, func(t *testing.T, tt tt) { - got := tt.a.shouldAuth(tt.req) - if got != tt.want { - t.Errorf("Want %t, got %t", tt.want, got) - } - }) -} - func Test401Response(t *testing.T) { var sessCounter, getCounter int s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -273,7 +186,7 @@ func Test401Response(t *testing.T) { if r.URL.Path == "/_session" { sessCounter++ if sessCounter > 2 { - t.Fatal("Too many calls to /_session") + t.Fatal("Too many requests to /_session") } var cookie string if sessCounter == 1 { @@ -286,26 +199,26 @@ func Test401Response(t *testing.T) { h.Add("Set-Cookie", "AuthSession="+cookie+"; Version=1; Path=/; HttpOnly") w.WriteHeader(200) _, _ = w.Write([]byte(`{"ok":true,"name":"admin","roles":["_admin"]}`)) - } else { - getCounter++ - cookie := r.Header.Get("Cookie") - if !(strings.Contains(cookie, "AuthSession=")) { - t.Errorf("Expected cookie not found: %s", cookie) - } - // because of the way the request is baked before the auth loop - // cookies other than the auth cookie set when calling _session won't - // get applied to requests until after that first request. - if getCounter > 1 && !strings.Contains(cookie, "Other=foo") { - t.Errorf("Expected cookie not found: %s", cookie) - } - if getCounter == 2 { - w.WriteHeader(401) - _, _ = w.Write([]byte(`{"error":"unauthorized","reason":"You are not authorized to access this db."}`)) - return - } - w.WriteHeader(200) - _, _ = w.Write([]byte(`{"ok":true}`)) + return + } + getCounter++ + cookie := r.Header.Get("Cookie") + if !strings.Contains(cookie, "AuthSession=") { + t.Errorf("Expected cookie not found: %s", cookie) + } + // because of the way the request is baked before the auth loop + // cookies other than the auth cookie set when calling _session won't + // get applied to requests until after that first request. + if getCounter > 1 && !strings.Contains(cookie, "Other=foo") { + t.Errorf("Expected cookie not found: %s", cookie) + } + if getCounter == 2 { + w.WriteHeader(401) + _, _ = w.Write([]byte(`{"error":"unauthorized","reason":"You are not authorized to access this db."}`)) + return } + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"ok":true}`)) })) c, err := New(s.URL)