aboutsummaryrefslogtreecommitdiff
path: root/cmd
diff options
context:
space:
mode:
authorFelix Hanley <felix@userspace.com.au>2020-03-04 01:28:17 +0000
committerFelix Hanley <felix@userspace.com.au>2020-03-04 01:28:17 +0000
commit9d687611493d2af60f6aedb3503dd85f6b3d49d6 (patch)
tree214dad63ca36551dd2ba0077ee8335602fd4e423 /cmd
parent1fbf1e1e1edc16e9a08fdb0d2012ca10f695f7ca (diff)
downloadsws-9d687611493d2af60f6aedb3503dd85f6b3d49d6.tar.gz
sws-9d687611493d2af60f6aedb3503dd85f6b3d49d6.tar.bz2
Refactor auth, template payloads and validation
Diffstat (limited to 'cmd')
-rw-r--r--cmd/server/auth.go122
-rw-r--r--cmd/server/handlers.go34
-rw-r--r--cmd/server/helpers.go4
-rw-r--r--cmd/server/hits.go16
-rw-r--r--cmd/server/main.go146
-rw-r--r--cmd/server/routes.go186
-rw-r--r--cmd/server/sites.go76
-rw-r--r--cmd/server/users.go62
8 files changed, 419 insertions, 227 deletions
diff --git a/cmd/server/auth.go b/cmd/server/auth.go
index ed0c75f..e9ded0d 100644
--- a/cmd/server/auth.go
+++ b/cmd/server/auth.go
@@ -16,71 +16,48 @@ const (
func handleLogin(db sws.UserStore, rndr Renderer) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
- email := r.PostFormValue("email")
- password := r.PostFormValue("password")
+ if r.Method == "POST" {
+ var user *sws.User
+ r, user = authUser(db, r)
+ if user != nil {
+ expiry := time.Now().Add(time.Hour)
- if email == "" || password == "" {
- //httpError(w, 406, "bad auth")
- r = flashSet(r, flashError, "invalid credentials")
- http.Redirect(w, flashQuery(r), loginURL, http.StatusSeeOther)
- return
- }
-
- debug("authing email", email)
-
- user, err := db.GetUserByEmail(email)
- if err != nil || user == nil {
- //httpError(w, 404, err.Error())
- r = flashSet(r, flashError, "invalid user")
- http.Redirect(w, flashQuery(r), loginURL, http.StatusSeeOther)
- return
- }
+ _, t, err := tokenAuth.Encode(jwt.MapClaims{
+ "user_id": *user.ID,
+ "exp": expiry.Unix(),
+ })
+ if err != nil {
+ log("failed to encode claims:", err)
+ r = flashSet(r, flashError, "internal error")
+ http.Redirect(w, r, flashURL(r, "/"), http.StatusSeeOther)
+ return
+ }
- if !user.Enabled {
- debug("user", email, "is disabled")
- //httpError(w, 403, "forbidden")
- r = flashSet(r, flashError, "access denied")
- http.Redirect(w, flashQuery(r), loginURL, http.StatusSeeOther)
- return
+ http.SetCookie(w, &http.Cookie{
+ Name: "jwt",
+ Value: t,
+ HttpOnly: true,
+ Path: "/",
+ //Secure: true,
+ Expires: expiry,
+ })
+ r = r.WithContext(context.WithValue(r.Context(), "user", user))
+ r = flashSet(r, flashSuccess, "authenticated successfully")
+ qs := r.URL.Query()
+ if returnPath := qs.Get("return_to"); returnPath != "" {
+ qs.Del("return_to")
+ r.URL.RawQuery = qs.Encode()
+ debug("redirecting to", returnPath)
+ http.Redirect(w, r, flashURL(r, returnPath), http.StatusSeeOther)
+ }
+ http.Redirect(w, r, flashURL(r, "/sites"), http.StatusSeeOther)
+ }
}
- if err := user.ValidPassword(password); err != nil {
- //httpError(w, 401, err.Error())
- r = flashSet(r, flashError, "authentication failed")
- http.Redirect(w, flashQuery(r), loginURL, http.StatusSeeOther)
- return
- }
- debug("user", email, "is authed")
-
- expiry := time.Now().Add(time.Hour)
-
- _, t, err := tokenAuth.Encode(jwt.MapClaims{
- "user_id": *user.ID,
- "exp": expiry.Unix(),
- })
- if err != nil {
+ payload := newTemplateData(r)
+ if err := rndr.Render(w, "login", payload); err != nil {
httpError(w, 500, err.Error())
- return
}
-
- http.SetCookie(w, &http.Cookie{
- Name: "jwt",
- Value: t,
- HttpOnly: true,
- Path: "/",
- //Secure: true,
- Expires: expiry,
- })
- r = r.WithContext(context.WithValue(r.Context(), "user", user))
- r = flashSet(r, flashSuccess, "authenticated successfully")
- qs := r.URL.Query()
- if returnPath := qs.Get("return_to"); returnPath != "" {
- qs.Del("return_to")
- r.URL.RawQuery = qs.Encode()
- debug("redirecting to", returnPath)
- http.Redirect(w, r, flashURL(r, returnPath), http.StatusSeeOther)
- }
- http.Redirect(w, r, flashURL(r, "/sites"), http.StatusSeeOther)
}
}
@@ -107,3 +84,30 @@ func authRedirect(w http.ResponseWriter, r *http.Request, msg string) {
r.URL.RawQuery = qs.Encode()
http.Redirect(w, r, flashURL(r, loginURL), http.StatusSeeOther)
}
+
+func authUser(db sws.UserStore, r *http.Request) (*http.Request, *sws.User) {
+ email := r.PostFormValue("email")
+ password := r.PostFormValue("password")
+
+ if email == "" || password == "" {
+ return flashSet(r, flashError, "invalid credentials"), nil
+ }
+
+ debug("authing email", email)
+
+ user, err := db.GetUserByEmail(email)
+ if err != nil || user == nil {
+ return flashSet(r, flashError, "invalid user"), nil
+ }
+
+ if !user.Enabled {
+ debug("user", email, "is disabled")
+ return flashSet(r, flashError, "access denied"), nil
+ }
+
+ if err := user.ValidPassword(password); err != nil {
+ return flashSet(r, flashError, "authentication failed"), nil
+ }
+ debug("user", email, "is authed")
+ return r, user
+}
diff --git a/cmd/server/handlers.go b/cmd/server/handlers.go
index e5ab385..18ee0fa 100644
--- a/cmd/server/handlers.go
+++ b/cmd/server/handlers.go
@@ -1,28 +1,50 @@
package main
import (
+ "html/template"
"net/http"
+ "strings"
"time"
"src.userspace.com.au/sws"
)
type templateData struct {
+ Payload string
+ Endpoint string
User *sws.User
- Flashes []flashMsg
+ Flash template.HTML
Begin *time.Time
End *time.Time
Site *sws.Site
Sites []*sws.Site
- PageSet *sws.PageSet
- Browsers *sws.BrowserSet
+ PageSet sws.PageSet
+ Browsers sws.BrowserSet
Hits *sws.HitSet
}
func newTemplateData(r *http.Request) *templateData {
- out := &templateData{Flashes: flashGet(r)}
- if user := r.Context().Value("user"); user != nil {
- out.User = user.(*sws.User)
+ out := &templateData{
+ Payload: "//" + *domain + "/sws.js",
+ Endpoint: "//" + *domain + "/sws.gif",
+ }
+ if r != nil {
+ flashes := flashGet(r)
+ var flash strings.Builder
+ for _, f := range flashes {
+ flash.WriteString(`<span class="`)
+ flash.WriteString(string(f.Level))
+ flash.WriteString(`">`)
+ flash.WriteString(f.Message)
+ flash.WriteString("</span>")
+ }
+ if len(flashes) > 0 {
+ out.Flash = template.HTML(flash.String())
+ }
+
+ if user := r.Context().Value("user"); user != nil {
+ out.User = user.(*sws.User)
+ }
}
return out
}
diff --git a/cmd/server/helpers.go b/cmd/server/helpers.go
index d17fa43..32e5cf3 100644
--- a/cmd/server/helpers.go
+++ b/cmd/server/helpers.go
@@ -60,6 +60,10 @@ func extractTimeRange(r *http.Request) (*time.Time, *time.Time) {
return begin, end
}
+func stringPtr(s string) *string {
+ return &s
+}
+
func timePtr(t time.Time) *time.Time {
return &t
}
diff --git a/cmd/server/hits.go b/cmd/server/hits.go
index af06757..9adb455 100644
--- a/cmd/server/hits.go
+++ b/cmd/server/hits.go
@@ -14,15 +14,14 @@ import (
)
const (
- endpoint = "//stats.userspace.com.au/sws.gif"
- gif = "R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7"
+ gif = "R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7"
)
-func handleHits(db sws.HitStore) http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
- return
- }
-}
+// func handleHits(db sws.HitStore) http.HandlerFunc {
+// return func(w http.ResponseWriter, r *http.Request) {
+// return
+// }
+// }
func handleHitCounter(db sws.CounterStore) http.HandlerFunc {
gifBytes, err := base64.StdEncoding.DecodeString(gif)
@@ -65,9 +64,8 @@ func handleCounter(addr string) http.HandlerFunc {
if err != nil || tmpl == nil {
panic(err)
}
- data := map[string]string{"endpoint": endpoint}
var buf bytes.Buffer
- if err := tmpl.Execute(&buf, data); err != nil {
+ if err := tmpl.Execute(&buf, newTemplateData(nil)); err != nil {
panic(err)
}
etag := fmt.Sprintf(`"%x"`, sha1.Sum(buf.Bytes()))
diff --git a/cmd/server/main.go b/cmd/server/main.go
index fc9515c..4e98b3b 100644
--- a/cmd/server/main.go
+++ b/cmd/server/main.go
@@ -1,32 +1,23 @@
package main
import (
- "bytes"
- "context"
"flag"
"fmt"
"net/http"
"os"
- "path/filepath"
- "strconv"
"strings"
"time"
- "github.com/go-chi/chi"
- "github.com/go-chi/chi/middleware"
- "github.com/go-chi/jwtauth"
_ "github.com/jackc/pgx/stdlib"
"github.com/jmoiron/sqlx"
_ "github.com/mattn/go-sqlite3"
"src.userspace.com.au/sws"
"src.userspace.com.au/sws/store"
- "src.userspace.com.au/templates"
)
var (
Version string
log, debug sws.Logger
- tokenAuth *jwtauth.JWTAuth
)
// Flags
@@ -34,6 +25,7 @@ var (
verbose *bool
addr *string
dsn *string
+ domain *string
noMigrate *bool
)
@@ -41,13 +33,13 @@ func init() {
verbose = boolFlag("verbose", "v", false, "VERBOSE", "enable verbose output")
addr = stringFlag("listen", "l", "localhost:5000", "LISTEN", "listen address")
dsn = stringFlag("dsn", "", "file:sws.db?cache=shared", "DSN", "database password")
+ domain = stringFlag("domain", "", "stats.userspace.com.au", "DOMAIN", "stats domain")
noMigrate = boolFlag("no-migrate", "m", false, "NOMIGRATE", "disable migrations")
// Default to no log
log = func(v ...interface{}) {}
debug = func(v ...interface{}) {}
- tokenAuth = jwtauth.New("HS256", []byte("lkjasd0f9u203ijsldkfj"), nil)
}
type Renderer interface {
@@ -102,144 +94,12 @@ func main() {
st = store.NewSqlite3Store(db)
}
- tmplsCommon := []string{"flash.tmpl", "navbar.tmpl"}
- tmplsAuthed := append(tmplsCommon, []string{"layouts/base.tmpl", "charts.tmpl"}...)
- tmplsPublic := append(tmplsCommon, "layouts/public.tmpl")
-
- tmpls, err := LoadHTMLTemplateMap(map[string][]string{
- "sites": append([]string{"sites.tmpl"}, tmplsAuthed...),
- "site": append([]string{"site.tmpl"}, tmplsAuthed...),
- "home": append([]string{"home.tmpl"}, tmplsPublic...),
- "login": append([]string{"login.tmpl"}, tmplsPublic...),
- "example": []string{"example.tmpl"},
- }, funcMap)
+ r, err := createRouter(st)
if err != nil {
log(err)
os.Exit(1)
}
- debug(tmpls["login"].DefinedTemplates())
- debug(tmpls["home"].DefinedTemplates())
- renderer := templates.NewRenderer(tmpls)
-
- r := chi.NewRouter()
- r.Use(middleware.RealIP)
- r.Use(middleware.RequestID)
- r.Use(middleware.RequestID)
- compressor := middleware.NewCompressor(5, "text/html", "text/css")
- r.Use(compressor.Handler())
- if *verbose {
- r.Use(middleware.Logger)
- }
- r.Use(middleware.Recoverer)
-
- siteCtx := getSiteCtx(st)
- userCtx := getUserCtx(st)
-
- // For counter
- r.Get("/sws.js", handleCounter(*addr))
- r.Get("/sws.gif", handleHitCounter(st))
-
- // For UI
- r.Get("/hits", handleHits(st))
-
- // Public routes
- r.Get("/", handleIndex(renderer))
- r.Get(loginURL, func(w http.ResponseWriter, r *http.Request) {
- payload := newTemplateData(r)
- if err := renderer.Render(w, "login", payload); err != nil {
- httpError(w, 500, err.Error())
- return
- }
- return
- })
-
- r.Post(loginURL, handleLogin(st, renderer))
-
- r.Get("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- p := strings.TrimPrefix(r.URL.Path, "/")
- if b, err := StaticLoadTemplate(p); err == nil {
- name := filepath.Base(p)
- http.ServeContent(w, r, name, time.Now(), bytes.NewReader(b))
- }
- }))
-
- // Authed routes
- r.Group(func(r chi.Router) {
- r.Use(jwtauth.Verifier(tokenAuth))
- r.Use(userCtx)
- r.Get(logoutURL, handleLogout(renderer))
- r.Route("/sites", func(r chi.Router) {
- r.Get("/", handleSites(st, renderer))
- r.Route("/{siteID}", func(r chi.Router) {
- r.Use(siteCtx)
- r.Get("/", handleSite(st, renderer))
- r.Route("/sparklines", func(r chi.Router) {
- r.Get("/{b:\\d+}-{e:\\d+}.svg", sparklineHandler(st))
- })
- r.Route("/charts", func(r chi.Router) {
- r.Get("/{b:\\d+}-{e:\\d+}.svg", svgChartHandler(st))
- r.Get("/{b:\\d+}-{e:\\d+}.png", svgChartHandler(st))
- })
- })
- })
- })
-
- // Example
- r.Get("/test.html", handleExample(renderer))
log("listening at", *addr)
http.ListenAndServe(*addr, r)
}
-
-func getUserCtx(db sws.UserStore) func(http.Handler) http.Handler {
- return func(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- token, claims, err := jwtauth.FromContext(r.Context())
-
- if err != nil {
- authRedirect(w, r, "token error")
- return
- }
-
- if token == nil || !token.Valid {
- authRedirect(w, r, "invalid token")
- return
- }
-
- // Token is authenticated, get claims
-
- id, ok := claims["user_id"]
- if !ok {
- authRedirect(w, r, "missing user ID")
- return
- }
-
- user, err := db.GetUserByID(int(id.(float64)))
- if err != nil {
- authRedirect(w, r, "missing user")
- return
- }
- debug("found user, adding to context")
- ctx := context.WithValue(r.Context(), "user", user)
- next.ServeHTTP(w, r.WithContext(ctx))
- })
- }
-}
-
-func getSiteCtx(db sws.SiteStore) func(http.Handler) http.Handler {
- return func(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- id, err := strconv.Atoi(chi.URLParam(r, "siteID"))
- if err != nil {
- panic(err)
- }
- site, err := db.GetSiteByID(id)
- if err != nil {
- http.Error(w, http.StatusText(404), 404)
- return
- }
- ctx := context.WithValue(r.Context(), "site", site)
- next.ServeHTTP(w, r.WithContext(ctx))
- })
- }
-}
diff --git a/cmd/server/routes.go b/cmd/server/routes.go
new file mode 100644
index 0000000..7d7ab32
--- /dev/null
+++ b/cmd/server/routes.go
@@ -0,0 +1,186 @@
+package main
+
+import (
+ "bytes"
+ "context"
+ "net/http"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/go-chi/chi"
+ "github.com/go-chi/chi/middleware"
+ "github.com/go-chi/jwtauth"
+ "src.userspace.com.au/sws"
+ "src.userspace.com.au/templates"
+)
+
+var tokenAuth *jwtauth.JWTAuth
+
+func init() {
+ tokenAuth = jwtauth.New("HS256", []byte("lkjasd0f9u203ijsldkfj"), nil)
+}
+
+func createRouter(db sws.Store) (chi.Router, error) {
+ tmplsCommon := []string{"flash.tmpl", "navbar.tmpl"}
+ tmplsAuthed := append(tmplsCommon, []string{"layouts/base.tmpl", "charts.tmpl"}...)
+ tmplsPublic := append(tmplsCommon, "layouts/public.tmpl")
+
+ tmpls, err := LoadHTMLTemplateMap(map[string][]string{
+ "sites": append([]string{"sites.tmpl"}, tmplsAuthed...),
+ "site": append([]string{"site.tmpl"}, tmplsAuthed...),
+ "home": append([]string{"home.tmpl"}, tmplsPublic...),
+ "login": append([]string{"login.tmpl"}, tmplsPublic...),
+ "user": append([]string{"user.tmpl"}, tmplsAuthed...),
+ "example": []string{"example.tmpl"},
+ }, funcMap)
+ if err != nil {
+ return nil, err
+ }
+ debug(tmpls["login"].DefinedTemplates())
+ debug(tmpls["home"].DefinedTemplates())
+
+ rndr := templates.NewRenderer(tmpls)
+
+ r := chi.NewRouter()
+ r.Use(middleware.RealIP)
+ r.Use(middleware.RequestID)
+ r.Use(middleware.RequestID)
+ compressor := middleware.NewCompressor(5, "text/html", "text/css")
+ r.Use(compressor.Handler())
+ if *verbose {
+ r.Use(middleware.Logger)
+ }
+ r.Use(middleware.Recoverer)
+
+ // For counter
+ r.Get("/sws.js", handleCounter(*addr))
+ r.Get("/sws.gif", handleHitCounter(db))
+ //r.Get("/hits", handleHits(db))
+
+ r.Group(func(r chi.Router) {
+ r.Use(jwtauth.Verifier(tokenAuth))
+ // Populate contect with user if present
+ r.Use(getUserCtx(db))
+
+ // Public routes
+ r.Group(func(r chi.Router) {
+ r.Get("/", handleIndex(rndr))
+ r.Get(loginURL, handleLogin(db, rndr))
+ })
+
+ r.Post(loginURL, handleLogin(db, rndr))
+
+ r.Get("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ p := strings.TrimPrefix(r.URL.Path, "/")
+ if b, err := StaticLoadTemplate(p); err == nil {
+ name := filepath.Base(p)
+ http.ServeContent(w, r, name, time.Now(), bytes.NewReader(b))
+ }
+ }))
+
+ // Authed routes
+ r.Group(func(r chi.Router) {
+ // Ensure we have a user in context
+ r.Use(func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if user := r.Context().Value("user"); user == nil {
+ authRedirect(w, r, "authentication required")
+ }
+ next.ServeHTTP(w, r)
+ })
+ })
+
+ r.Get(logoutURL, handleLogout(rndr))
+ r.Route("/sites", func(r chi.Router) {
+ r.Get("/", handleSites(db, rndr))
+ r.Post("/", handleSites(db, rndr))
+ r.Get("/new", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ payload := newTemplateData(r)
+ payload.Site = &sws.Site{}
+ if err := rndr.Render(w, "site", payload); err != nil {
+ httpError(w, 500, err.Error())
+ return
+ }
+ }))
+ r.Route("/{siteID}", func(r chi.Router) {
+ // Populate contect with site if present
+ r.Use(getSiteCtx(db))
+ r.Get("/", handleSite(db, rndr))
+ r.Post("/", handleSite(db, rndr))
+ r.Get("/edit", handleSiteEdit(db, rndr))
+ r.Route("/sparklines", func(r chi.Router) {
+ r.Get("/{b:\\d+}-{e:\\d+}.svg", sparklineHandler(db))
+ })
+ r.Route("/charts", func(r chi.Router) {
+ r.Get("/{b:\\d+}-{e:\\d+}.svg", svgChartHandler(db))
+ //r.Get("/{b:\\d+}-{e:\\d+}.png", svgChartHandler(db))
+ })
+ })
+ })
+ r.Route("/users", func(r chi.Router) {
+ userH := handleUsers(db, rndr)
+ r.Route("/{email}", func(r chi.Router) {
+ r.Get("/", userH)
+ r.Post("/", userH)
+ })
+ })
+ })
+ })
+
+ // Example
+ r.Get("/test.html", handleExample(rndr))
+ return r, nil
+}
+
+func getUserCtx(db sws.UserStore) func(http.Handler) http.Handler {
+ return func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ defer func() {
+ next.ServeHTTP(w, r)
+ }()
+
+ _, claims, err := jwtauth.FromContext(r.Context())
+ if err != nil {
+ log("failed to extract user from context", err)
+ return
+ }
+
+ // Token is authenticated, get claims
+
+ id, ok := claims["user_id"]
+ if !ok {
+ log("missing user ID")
+ return
+ }
+
+ user, err := db.GetUserByID(int(id.(float64)))
+ if err != nil {
+ log("missing user")
+ return
+ }
+ debug("found user, adding to context")
+ ctx := context.WithValue(r.Context(), "user", user)
+ r = r.WithContext(ctx)
+ })
+ }
+}
+
+func getSiteCtx(db sws.SiteStore) func(http.Handler) http.Handler {
+ return func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ id, err := strconv.Atoi(chi.URLParam(r, "siteID"))
+ if err != nil {
+ panic(err)
+ }
+ site, err := db.GetSiteByID(id)
+ if err != nil {
+ http.Error(w, http.StatusText(404), 404)
+ return
+ }
+ ctx := context.WithValue(r.Context(), "site", site)
+ next.ServeHTTP(w, r.WithContext(ctx))
+ })
+ }
+}
diff --git a/cmd/server/sites.go b/cmd/server/sites.go
index f7e7793..46bcdf3 100644
--- a/cmd/server/sites.go
+++ b/cmd/server/sites.go
@@ -2,12 +2,29 @@ package main
import (
"net/http"
+ "strings"
"src.userspace.com.au/sws"
)
func handleSites(db sws.SiteStore, rndr Renderer) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
+ if r.Method == "POST" {
+ site := &sws.Site{
+ Name: r.FormValue("name"),
+ Description: r.FormValue("description"),
+ Aliases: r.FormValue("aliases"),
+ }
+ if errs := site.Validate(); len(errs) > 0 {
+ log("invalid site:", errs)
+ r = flashSet(r, flashError, strings.Join(errs, "<br>"))
+ } else if err := db.SaveSite(site); err != nil {
+ httpError(w, 500, err.Error())
+ return
+ }
+ r = flashSet(r, flashSuccess, "site created")
+ }
+
sites, err := db.GetSites()
if err != nil {
httpError(w, 500, err.Error())
@@ -32,6 +49,10 @@ func handleSite(db sws.SiteStore, rndr Renderer) http.HandlerFunc {
httpError(w, 422, "no site in context")
return
}
+
+ payload := newTemplateData(r)
+ payload.Site = site
+
begin, end := extractTimeRange(r)
if begin == nil || end == nil {
httpError(w, 406, "invalid time range")
@@ -52,23 +73,58 @@ func handleSite(db sws.SiteStore, rndr Renderer) http.HandlerFunc {
httpError(w, 406, err.Error())
return
}
- hitSet.Fill(begin, end)
- hitSet.SortByDate()
+ if hitSet != nil {
+ hitSet.Fill(begin, end)
+ hitSet.SortByDate()
+ payload.Hits = hitSet
- pageSet, err := sws.NewPageSet(hitSet)
- if err != nil {
- httpError(w, 406, err.Error())
+ pageSet, err := sws.NewPageSet(hitSet)
+ if err != nil {
+ httpError(w, 406, err.Error())
+ return
+ }
+
+ if pageSet != nil {
+ pageSet.SortByHits()
+ payload.PageSet = pageSet
+ }
+ browserSet := sws.NewBrowserSet(hitSet)
+ payload.Browsers = browserSet
+ }
+
+ if err := rndr.Render(w, "site", payload); err != nil {
+ httpError(w, 500, err.Error())
return
}
- pageSet.SortByHits()
+ }
+}
+
+func handleSiteEdit(db sws.SiteStore, rndr Renderer) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ ctx := r.Context()
+ site, ok := ctx.Value("site").(*sws.Site)
+ if !ok {
+ httpError(w, 422, "no site in context")
+ return
+ }
+
+ if r.Method == "POST" {
+ site.Name = r.FormValue("name")
+ site.Description = r.FormValue("description")
+ site.Aliases = r.FormValue("aliases")
- browserSet := sws.NewBrowserSet(hitSet)
+ if errs := site.Validate(); len(errs) > 0 {
+ log("invalid site:", errs)
+ r = flashSet(r, flashError, strings.Join(errs, "<br>"))
+ } else if err := db.SaveSite(site); err != nil {
+ httpError(w, 500, err.Error())
+ return
+ }
+ r = flashSet(r, flashSuccess, "site updated")
+ }
payload := newTemplateData(r)
payload.Site = site
- payload.PageSet = &pageSet
- payload.Browsers = &browserSet
- payload.Hits = hitSet
if err := rndr.Render(w, "site", payload); err != nil {
httpError(w, 500, err.Error())
diff --git a/cmd/server/users.go b/cmd/server/users.go
new file mode 100644
index 0000000..4218742
--- /dev/null
+++ b/cmd/server/users.go
@@ -0,0 +1,62 @@
+package main
+
+import (
+ "net/http"
+ "strings"
+
+ "github.com/go-chi/chi"
+ "src.userspace.com.au/sws"
+)
+
+func handleUsers(db sws.UserStore, rndr Renderer) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ ctx := r.Context()
+ authedUser, ok := ctx.Value("user").(*sws.User)
+ if !ok {
+ httpError(w, 422, "no user in context")
+ return
+ }
+ email := chi.URLParam(r, "email")
+ if email == "" {
+ httpError(w, http.StatusBadRequest, http.StatusText(http.StatusBadRequest))
+ return
+ }
+ user, err := db.GetUserByEmail(email)
+ if err != nil {
+ http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
+ return
+ }
+ if *authedUser.ID != *user.ID && !authedUser.Admin {
+ log("failed attempt to edit user", *user.ID, *authedUser.ID)
+ http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
+ return
+ }
+
+ if r.Method == "POST" {
+ user.FirstName = stringPtr(r.FormValue("first_name"))
+ user.LastName = stringPtr(r.FormValue("last_name"))
+ user.Email = stringPtr(r.FormValue("email"))
+ user.PasswordConfirm = r.FormValue("password_confirmation")
+ user.Password = r.FormValue("password")
+ if errs := user.Validate(); len(errs) > 0 {
+ log("invalid user:", errs)
+ r = flashSet(r, flashError, strings.Join(errs, "<br>"))
+ } else {
+ if err := db.SaveUser(user); err != nil {
+ log("failed to update user:", err)
+ r = flashSet(r, flashError, err.Error())
+ } else {
+ r = flashSet(r, flashSuccess, "successfully updated")
+ }
+ }
+ }
+
+ payload := newTemplateData(r)
+ payload.User = user
+
+ if err := rndr.Render(w, "user", payload); err != nil {
+ httpError(w, 500, err.Error())
+ return
+ }
+ }
+}