From bee5bd5f92f05acbaa7eb97c755f4f28df9d89c0 Mon Sep 17 00:00:00 2001 From: davidnewhall2 Date: Tue, 4 Feb 2020 12:04:50 -0800 Subject: [PATCH] play the csrf game --- core/unifi/types.go | 1 + core/unifi/unifi.go | 24 +++++++++++++++--------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/core/unifi/types.go b/core/unifi/types.go index 3f682919..cb0d778d 100644 --- a/core/unifi/types.go +++ b/core/unifi/types.go @@ -87,6 +87,7 @@ type Unifi struct { *Config *server isNew bool + csrf string } // server is the /status endpoint from the Unifi controller. diff --git a/core/unifi/unifi.go b/core/unifi/unifi.go index a72b9613..7254091c 100644 --- a/core/unifi/unifi.go +++ b/core/unifi/unifi.go @@ -66,7 +66,6 @@ func NewUnifi(config *Config) (*Unifi, error) { // Login is a helper method. It can be called to grab a new authentication cookie. func (u *Unifi) Login() error { - APILoginPath := u.path(APILoginPath) start := time.Now() // magic login. @@ -83,11 +82,11 @@ func (u *Unifi) Login() error { defer resp.Body.Close() // we need no data here. _, _ = io.Copy(ioutil.Discard, resp.Body) // avoid leaking. u.DebugLog("Requested %s: elapsed %v, returned %d bytes", - APILoginPath, time.Since(start).Round(time.Millisecond), resp.ContentLength) + req.URL, time.Since(start).Round(time.Millisecond), resp.ContentLength) if resp.StatusCode != http.StatusOK { return fmt.Errorf("authentication failed (user: %s): %s (status: %s)", - u.User, u.URL+APILoginPath, resp.Status) + u.User, req.URL, resp.Status) } return nil @@ -127,7 +126,7 @@ func (u *Unifi) checkNewStyleAPI() error { if resp.StatusCode == http.StatusOK { // The new version returns a "200" for a / request. u.isNew = true - u.DebugLog("Using NEW UniFi controller API paths!") + u.DebugLog("Using NEW UniFi controller API paths for %s", req.URL) } // The old version returns a "302" (to /manage) for a / request @@ -156,7 +155,7 @@ func (u *Unifi) GetData(apiPath string, v interface{}, params ...string) error { } u.DebugLog("Requested %s: elapsed %v, returned %d bytes", - u.path(apiPath), time.Since(start).Round(time.Millisecond), len(body)) + u.URL+u.path(apiPath), time.Since(start).Round(time.Millisecond), len(body)) return json.Unmarshal(body, v) } @@ -177,14 +176,16 @@ func (u *Unifi) UniReq(apiPath string, params string) (req *http.Request, err er return } + // Add the saved CSRF header. + req.Header.Set("X-CSRF-Token", u.csrf) req.Header.Add("Accept", "application/json") req.Header.Add("Content-Type", "application/json; charset=utf-8") if u.Client.Jar != nil { - parsedURL, _ := url.Parse(u.URL + apiPath) - u.DebugLog("Requesting %s, with params: %v, cookies: %d", apiPath, params != "", len(u.Client.Jar.Cookies(parsedURL))) + parsedURL, _ := url.Parse(req.URL.String()) + u.DebugLog("Requesting %s, with params: %v, cookies: %d", req.URL, params != "", len(u.Client.Jar.Cookies(parsedURL))) } else { - u.DebugLog("Requesting %s, with params: %v,", apiPath, params != "") + u.DebugLog("Requesting %s, with params: %v,", req.URL, params != "") } return @@ -209,8 +210,13 @@ func (u *Unifi) GetJSON(apiPath string, params ...string) ([]byte, error) { return body, err } + // Save the returned CSRF header. + if csrf := resp.Header.Get("x-csrf-token"); csrf != "" { + u.csrf = resp.Header.Get("x-csrf-token") + } + if resp.StatusCode != http.StatusOK { - err = fmt.Errorf("invalid status code from server %s", resp.Status) + err = fmt.Errorf("invalid status code from server for %s: %s", req.URL, resp.Status) } return body, err