diff --git a/web/auth.go b/web/auth.go index b9d905af..8af5ef7c 100644 --- a/web/auth.go +++ b/web/auth.go @@ -528,12 +528,7 @@ func (h *loginHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.FormValue("backto") != "" { backto := r.FormValue("backto") - - // to prevent redirecting to an external URL we only set the session data when: - // 1. we fail to parse backto - // 2. backto does not include a hostname - parsed, err := url.ParseRequestURI(backto) - if err != nil || parsed.Hostname() == "" { + if IsValidBacktoURL(backto) { session.Data["backto"] = backto } } diff --git a/web/backto.go b/web/backto.go new file mode 100644 index 00000000..fc96f4df --- /dev/null +++ b/web/backto.go @@ -0,0 +1,43 @@ +package web + +import ( + "net/url" + "strings" +) + +var hostnameWhitelist = map[string]bool{ + "localhost": true, + "viam.dev": true, + "viam.com": true, +} + +func isWhitelisted(hostname string) bool { + return hostnameWhitelist[hostname] +} + +// IsValidBacktoURL returns true if the passed string is a secure URL to a whitelisted +// hostname. The whitelisted hostnames are: "localhost", "viam.dev", and "viam.com". +// +// - https://example.com -> false +// - http://viam.com/path/name -> false +// - https://viam.com/path/name -> true +func IsValidBacktoURL(path string) bool { + normalized := strings.ReplaceAll(path, "\\", "/") + url, err := url.ParseRequestURI(normalized) + if err != nil { + // ignore invalid URLs/URL components + return false + } + + if url.Scheme != "" && url.Scheme != "https" { + // ignore non-secure URLs + return false + } + + if isWhitelisted(url.Hostname()) { + // ignore non-whitelisted hosts + return true + } + + return false +} diff --git a/web/backto_test.go b/web/backto_test.go new file mode 100644 index 00000000..9a796a18 --- /dev/null +++ b/web/backto_test.go @@ -0,0 +1,66 @@ +package web + +import ( + "testing" + + "go.viam.com/test" +) + +func TestIsValidBacktoURL(t *testing.T) { + t.Run("rejects external URLs", func(t *testing.T) { + test.That(t, IsValidBacktoURL("https://example.com"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("http://example.com"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("ftp://example.com"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("://example.com"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("//example.com"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("example.com"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("www.example.com"), test.ShouldBeFalse) + }) + + t.Run("rejects invalid production URLs", func(t *testing.T) { + test.That(t, IsValidBacktoURL("http://viam.com"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("ftp://viam.com"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("://viam.com"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("//viam.com"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("//viam.com/some/path"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("viam.com"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("viam.com/some/path"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("www.viam.com"), test.ShouldBeFalse) + }) + + t.Run("accepts valid production URLs", func(t *testing.T) { + test.That(t, IsValidBacktoURL("https://viam.com"), test.ShouldBeTrue) + test.That(t, IsValidBacktoURL("https://viam.com/some/path"), test.ShouldBeTrue) + }) + + t.Run("rejects invalid staging URLs", func(t *testing.T) { + test.That(t, IsValidBacktoURL("http://viam.dev"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("ftp://viam.dev"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("://viam.dev"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("//viam.dev"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("//viam.dev/some/path"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("viam.dev"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("viam.dev/some/path"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("www.viam.dev"), test.ShouldBeFalse) + }) + + t.Run("accepts valid staging URLs", func(t *testing.T) { + test.That(t, IsValidBacktoURL("https://viam.dev"), test.ShouldBeTrue) + test.That(t, IsValidBacktoURL("https://viam.dev/some/path"), test.ShouldBeTrue) + }) + + t.Run("rejects invalid local URLs", func(t *testing.T) { + test.That(t, IsValidBacktoURL("http://localhost"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("ftp://localhost"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("://localhost"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("//localhost"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("//localhost/some/path"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("localhost"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("localhost/some/path"), test.ShouldBeFalse) + }) + + t.Run("accepts valid local URLs", func(t *testing.T) { + test.That(t, IsValidBacktoURL("https://localhost"), test.ShouldBeTrue) + test.That(t, IsValidBacktoURL("https://localhost/some/path"), test.ShouldBeTrue) + }) +}