diff options
| author | Felix Hanley <felix@userspace.com.au> | 2020-03-04 01:28:17 +0000 |
|---|---|---|
| committer | Felix Hanley <felix@userspace.com.au> | 2020-03-04 01:28:17 +0000 |
| commit | 9d687611493d2af60f6aedb3503dd85f6b3d49d6 (patch) | |
| tree | 214dad63ca36551dd2ba0077ee8335602fd4e423 /cmd | |
| parent | 1fbf1e1e1edc16e9a08fdb0d2012ca10f695f7ca (diff) | |
| download | sws-9d687611493d2af60f6aedb3503dd85f6b3d49d6.tar.gz sws-9d687611493d2af60f6aedb3503dd85f6b3d49d6.tar.bz2 | |
Refactor auth, template payloads and validation
Diffstat (limited to 'cmd')
| -rw-r--r-- | cmd/server/auth.go | 122 | ||||
| -rw-r--r-- | cmd/server/handlers.go | 34 | ||||
| -rw-r--r-- | cmd/server/helpers.go | 4 | ||||
| -rw-r--r-- | cmd/server/hits.go | 16 | ||||
| -rw-r--r-- | cmd/server/main.go | 146 | ||||
| -rw-r--r-- | cmd/server/routes.go | 186 | ||||
| -rw-r--r-- | cmd/server/sites.go | 76 | ||||
| -rw-r--r-- | cmd/server/users.go | 62 |
8 files changed, 419 insertions, 227 deletions
diff --git a/cmd/server/auth.go b/cmd/server/auth.go index ed0c75f..e9ded0d 100644 --- a/cmd/server/auth.go +++ b/cmd/server/auth.go @@ -16,71 +16,48 @@ const ( func handleLogin(db sws.UserStore, rndr Renderer) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - email := r.PostFormValue("email") - password := r.PostFormValue("password") + if r.Method == "POST" { + var user *sws.User + r, user = authUser(db, r) + if user != nil { + expiry := time.Now().Add(time.Hour) - if email == "" || password == "" { - //httpError(w, 406, "bad auth") - r = flashSet(r, flashError, "invalid credentials") - http.Redirect(w, flashQuery(r), loginURL, http.StatusSeeOther) - return - } - - debug("authing email", email) - - user, err := db.GetUserByEmail(email) - if err != nil || user == nil { - //httpError(w, 404, err.Error()) - r = flashSet(r, flashError, "invalid user") - http.Redirect(w, flashQuery(r), loginURL, http.StatusSeeOther) - return - } + _, t, err := tokenAuth.Encode(jwt.MapClaims{ + "user_id": *user.ID, + "exp": expiry.Unix(), + }) + if err != nil { + log("failed to encode claims:", err) + r = flashSet(r, flashError, "internal error") + http.Redirect(w, r, flashURL(r, "/"), http.StatusSeeOther) + return + } - if !user.Enabled { - debug("user", email, "is disabled") - //httpError(w, 403, "forbidden") - r = flashSet(r, flashError, "access denied") - http.Redirect(w, flashQuery(r), loginURL, http.StatusSeeOther) - return + http.SetCookie(w, &http.Cookie{ + Name: "jwt", + Value: t, + HttpOnly: true, + Path: "/", + //Secure: true, + Expires: expiry, + }) + r = r.WithContext(context.WithValue(r.Context(), "user", user)) + r = flashSet(r, flashSuccess, "authenticated successfully") + qs := r.URL.Query() + if returnPath := qs.Get("return_to"); returnPath != "" { + qs.Del("return_to") + r.URL.RawQuery = qs.Encode() + debug("redirecting to", returnPath) + http.Redirect(w, r, flashURL(r, returnPath), http.StatusSeeOther) + } + http.Redirect(w, r, flashURL(r, "/sites"), http.StatusSeeOther) + } } - if err := user.ValidPassword(password); err != nil { - //httpError(w, 401, err.Error()) - r = flashSet(r, flashError, "authentication failed") - http.Redirect(w, flashQuery(r), loginURL, http.StatusSeeOther) - return - } - debug("user", email, "is authed") - - expiry := time.Now().Add(time.Hour) - - _, t, err := tokenAuth.Encode(jwt.MapClaims{ - "user_id": *user.ID, - "exp": expiry.Unix(), - }) - if err != nil { + payload := newTemplateData(r) + if err := rndr.Render(w, "login", payload); err != nil { httpError(w, 500, err.Error()) - return } - - http.SetCookie(w, &http.Cookie{ - Name: "jwt", - Value: t, - HttpOnly: true, - Path: "/", - //Secure: true, - Expires: expiry, - }) - r = r.WithContext(context.WithValue(r.Context(), "user", user)) - r = flashSet(r, flashSuccess, "authenticated successfully") - qs := r.URL.Query() - if returnPath := qs.Get("return_to"); returnPath != "" { - qs.Del("return_to") - r.URL.RawQuery = qs.Encode() - debug("redirecting to", returnPath) - http.Redirect(w, r, flashURL(r, returnPath), http.StatusSeeOther) - } - http.Redirect(w, r, flashURL(r, "/sites"), http.StatusSeeOther) } } @@ -107,3 +84,30 @@ func authRedirect(w http.ResponseWriter, r *http.Request, msg string) { r.URL.RawQuery = qs.Encode() http.Redirect(w, r, flashURL(r, loginURL), http.StatusSeeOther) } + +func authUser(db sws.UserStore, r *http.Request) (*http.Request, *sws.User) { + email := r.PostFormValue("email") + password := r.PostFormValue("password") + + if email == "" || password == "" { + return flashSet(r, flashError, "invalid credentials"), nil + } + + debug("authing email", email) + + user, err := db.GetUserByEmail(email) + if err != nil || user == nil { + return flashSet(r, flashError, "invalid user"), nil + } + + if !user.Enabled { + debug("user", email, "is disabled") + return flashSet(r, flashError, "access denied"), nil + } + + if err := user.ValidPassword(password); err != nil { + return flashSet(r, flashError, "authentication failed"), nil + } + debug("user", email, "is authed") + return r, user +} diff --git a/cmd/server/handlers.go b/cmd/server/handlers.go index e5ab385..18ee0fa 100644 --- a/cmd/server/handlers.go +++ b/cmd/server/handlers.go @@ -1,28 +1,50 @@ package main import ( + "html/template" "net/http" + "strings" "time" "src.userspace.com.au/sws" ) type templateData struct { + Payload string + Endpoint string User *sws.User - Flashes []flashMsg + Flash template.HTML Begin *time.Time End *time.Time Site *sws.Site Sites []*sws.Site - PageSet *sws.PageSet - Browsers *sws.BrowserSet + PageSet sws.PageSet + Browsers sws.BrowserSet Hits *sws.HitSet } func newTemplateData(r *http.Request) *templateData { - out := &templateData{Flashes: flashGet(r)} - if user := r.Context().Value("user"); user != nil { - out.User = user.(*sws.User) + out := &templateData{ + Payload: "//" + *domain + "/sws.js", + Endpoint: "//" + *domain + "/sws.gif", + } + if r != nil { + flashes := flashGet(r) + var flash strings.Builder + for _, f := range flashes { + flash.WriteString(`<span class="`) + flash.WriteString(string(f.Level)) + flash.WriteString(`">`) + flash.WriteString(f.Message) + flash.WriteString("</span>") + } + if len(flashes) > 0 { + out.Flash = template.HTML(flash.String()) + } + + if user := r.Context().Value("user"); user != nil { + out.User = user.(*sws.User) + } } return out } diff --git a/cmd/server/helpers.go b/cmd/server/helpers.go index d17fa43..32e5cf3 100644 --- a/cmd/server/helpers.go +++ b/cmd/server/helpers.go @@ -60,6 +60,10 @@ func extractTimeRange(r *http.Request) (*time.Time, *time.Time) { return begin, end } +func stringPtr(s string) *string { + return &s +} + func timePtr(t time.Time) *time.Time { return &t } diff --git a/cmd/server/hits.go b/cmd/server/hits.go index af06757..9adb455 100644 --- a/cmd/server/hits.go +++ b/cmd/server/hits.go @@ -14,15 +14,14 @@ import ( ) const ( - endpoint = "//stats.userspace.com.au/sws.gif" - gif = "R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7" + gif = "R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7" ) -func handleHits(db sws.HitStore) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - return - } -} +// func handleHits(db sws.HitStore) http.HandlerFunc { +// return func(w http.ResponseWriter, r *http.Request) { +// return +// } +// } func handleHitCounter(db sws.CounterStore) http.HandlerFunc { gifBytes, err := base64.StdEncoding.DecodeString(gif) @@ -65,9 +64,8 @@ func handleCounter(addr string) http.HandlerFunc { if err != nil || tmpl == nil { panic(err) } - data := map[string]string{"endpoint": endpoint} var buf bytes.Buffer - if err := tmpl.Execute(&buf, data); err != nil { + if err := tmpl.Execute(&buf, newTemplateData(nil)); err != nil { panic(err) } etag := fmt.Sprintf(`"%x"`, sha1.Sum(buf.Bytes())) diff --git a/cmd/server/main.go b/cmd/server/main.go index fc9515c..4e98b3b 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -1,32 +1,23 @@ package main import ( - "bytes" - "context" "flag" "fmt" "net/http" "os" - "path/filepath" - "strconv" "strings" "time" - "github.com/go-chi/chi" - "github.com/go-chi/chi/middleware" - "github.com/go-chi/jwtauth" _ "github.com/jackc/pgx/stdlib" "github.com/jmoiron/sqlx" _ "github.com/mattn/go-sqlite3" "src.userspace.com.au/sws" "src.userspace.com.au/sws/store" - "src.userspace.com.au/templates" ) var ( Version string log, debug sws.Logger - tokenAuth *jwtauth.JWTAuth ) // Flags @@ -34,6 +25,7 @@ var ( verbose *bool addr *string dsn *string + domain *string noMigrate *bool ) @@ -41,13 +33,13 @@ func init() { verbose = boolFlag("verbose", "v", false, "VERBOSE", "enable verbose output") addr = stringFlag("listen", "l", "localhost:5000", "LISTEN", "listen address") dsn = stringFlag("dsn", "", "file:sws.db?cache=shared", "DSN", "database password") + domain = stringFlag("domain", "", "stats.userspace.com.au", "DOMAIN", "stats domain") noMigrate = boolFlag("no-migrate", "m", false, "NOMIGRATE", "disable migrations") // Default to no log log = func(v ...interface{}) {} debug = func(v ...interface{}) {} - tokenAuth = jwtauth.New("HS256", []byte("lkjasd0f9u203ijsldkfj"), nil) } type Renderer interface { @@ -102,144 +94,12 @@ func main() { st = store.NewSqlite3Store(db) } - tmplsCommon := []string{"flash.tmpl", "navbar.tmpl"} - tmplsAuthed := append(tmplsCommon, []string{"layouts/base.tmpl", "charts.tmpl"}...) - tmplsPublic := append(tmplsCommon, "layouts/public.tmpl") - - tmpls, err := LoadHTMLTemplateMap(map[string][]string{ - "sites": append([]string{"sites.tmpl"}, tmplsAuthed...), - "site": append([]string{"site.tmpl"}, tmplsAuthed...), - "home": append([]string{"home.tmpl"}, tmplsPublic...), - "login": append([]string{"login.tmpl"}, tmplsPublic...), - "example": []string{"example.tmpl"}, - }, funcMap) + r, err := createRouter(st) if err != nil { log(err) os.Exit(1) } - debug(tmpls["login"].DefinedTemplates()) - debug(tmpls["home"].DefinedTemplates()) - renderer := templates.NewRenderer(tmpls) - - r := chi.NewRouter() - r.Use(middleware.RealIP) - r.Use(middleware.RequestID) - r.Use(middleware.RequestID) - compressor := middleware.NewCompressor(5, "text/html", "text/css") - r.Use(compressor.Handler()) - if *verbose { - r.Use(middleware.Logger) - } - r.Use(middleware.Recoverer) - - siteCtx := getSiteCtx(st) - userCtx := getUserCtx(st) - - // For counter - r.Get("/sws.js", handleCounter(*addr)) - r.Get("/sws.gif", handleHitCounter(st)) - - // For UI - r.Get("/hits", handleHits(st)) - - // Public routes - r.Get("/", handleIndex(renderer)) - r.Get(loginURL, func(w http.ResponseWriter, r *http.Request) { - payload := newTemplateData(r) - if err := renderer.Render(w, "login", payload); err != nil { - httpError(w, 500, err.Error()) - return - } - return - }) - - r.Post(loginURL, handleLogin(st, renderer)) - - r.Get("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - p := strings.TrimPrefix(r.URL.Path, "/") - if b, err := StaticLoadTemplate(p); err == nil { - name := filepath.Base(p) - http.ServeContent(w, r, name, time.Now(), bytes.NewReader(b)) - } - })) - - // Authed routes - r.Group(func(r chi.Router) { - r.Use(jwtauth.Verifier(tokenAuth)) - r.Use(userCtx) - r.Get(logoutURL, handleLogout(renderer)) - r.Route("/sites", func(r chi.Router) { - r.Get("/", handleSites(st, renderer)) - r.Route("/{siteID}", func(r chi.Router) { - r.Use(siteCtx) - r.Get("/", handleSite(st, renderer)) - r.Route("/sparklines", func(r chi.Router) { - r.Get("/{b:\\d+}-{e:\\d+}.svg", sparklineHandler(st)) - }) - r.Route("/charts", func(r chi.Router) { - r.Get("/{b:\\d+}-{e:\\d+}.svg", svgChartHandler(st)) - r.Get("/{b:\\d+}-{e:\\d+}.png", svgChartHandler(st)) - }) - }) - }) - }) - - // Example - r.Get("/test.html", handleExample(renderer)) log("listening at", *addr) http.ListenAndServe(*addr, r) } - -func getUserCtx(db sws.UserStore) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - token, claims, err := jwtauth.FromContext(r.Context()) - - if err != nil { - authRedirect(w, r, "token error") - return - } - - if token == nil || !token.Valid { - authRedirect(w, r, "invalid token") - return - } - - // Token is authenticated, get claims - - id, ok := claims["user_id"] - if !ok { - authRedirect(w, r, "missing user ID") - return - } - - user, err := db.GetUserByID(int(id.(float64))) - if err != nil { - authRedirect(w, r, "missing user") - return - } - debug("found user, adding to context") - ctx := context.WithValue(r.Context(), "user", user) - next.ServeHTTP(w, r.WithContext(ctx)) - }) - } -} - -func getSiteCtx(db sws.SiteStore) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - id, err := strconv.Atoi(chi.URLParam(r, "siteID")) - if err != nil { - panic(err) - } - site, err := db.GetSiteByID(id) - if err != nil { - http.Error(w, http.StatusText(404), 404) - return - } - ctx := context.WithValue(r.Context(), "site", site) - next.ServeHTTP(w, r.WithContext(ctx)) - }) - } -} diff --git a/cmd/server/routes.go b/cmd/server/routes.go new file mode 100644 index 0000000..7d7ab32 --- /dev/null +++ b/cmd/server/routes.go @@ -0,0 +1,186 @@ +package main + +import ( + "bytes" + "context" + "net/http" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/go-chi/chi" + "github.com/go-chi/chi/middleware" + "github.com/go-chi/jwtauth" + "src.userspace.com.au/sws" + "src.userspace.com.au/templates" +) + +var tokenAuth *jwtauth.JWTAuth + +func init() { + tokenAuth = jwtauth.New("HS256", []byte("lkjasd0f9u203ijsldkfj"), nil) +} + +func createRouter(db sws.Store) (chi.Router, error) { + tmplsCommon := []string{"flash.tmpl", "navbar.tmpl"} + tmplsAuthed := append(tmplsCommon, []string{"layouts/base.tmpl", "charts.tmpl"}...) + tmplsPublic := append(tmplsCommon, "layouts/public.tmpl") + + tmpls, err := LoadHTMLTemplateMap(map[string][]string{ + "sites": append([]string{"sites.tmpl"}, tmplsAuthed...), + "site": append([]string{"site.tmpl"}, tmplsAuthed...), + "home": append([]string{"home.tmpl"}, tmplsPublic...), + "login": append([]string{"login.tmpl"}, tmplsPublic...), + "user": append([]string{"user.tmpl"}, tmplsAuthed...), + "example": []string{"example.tmpl"}, + }, funcMap) + if err != nil { + return nil, err + } + debug(tmpls["login"].DefinedTemplates()) + debug(tmpls["home"].DefinedTemplates()) + + rndr := templates.NewRenderer(tmpls) + + r := chi.NewRouter() + r.Use(middleware.RealIP) + r.Use(middleware.RequestID) + r.Use(middleware.RequestID) + compressor := middleware.NewCompressor(5, "text/html", "text/css") + r.Use(compressor.Handler()) + if *verbose { + r.Use(middleware.Logger) + } + r.Use(middleware.Recoverer) + + // For counter + r.Get("/sws.js", handleCounter(*addr)) + r.Get("/sws.gif", handleHitCounter(db)) + //r.Get("/hits", handleHits(db)) + + r.Group(func(r chi.Router) { + r.Use(jwtauth.Verifier(tokenAuth)) + // Populate contect with user if present + r.Use(getUserCtx(db)) + + // Public routes + r.Group(func(r chi.Router) { + r.Get("/", handleIndex(rndr)) + r.Get(loginURL, handleLogin(db, rndr)) + }) + + r.Post(loginURL, handleLogin(db, rndr)) + + r.Get("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p := strings.TrimPrefix(r.URL.Path, "/") + if b, err := StaticLoadTemplate(p); err == nil { + name := filepath.Base(p) + http.ServeContent(w, r, name, time.Now(), bytes.NewReader(b)) + } + })) + + // Authed routes + r.Group(func(r chi.Router) { + // Ensure we have a user in context + r.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if user := r.Context().Value("user"); user == nil { + authRedirect(w, r, "authentication required") + } + next.ServeHTTP(w, r) + }) + }) + + r.Get(logoutURL, handleLogout(rndr)) + r.Route("/sites", func(r chi.Router) { + r.Get("/", handleSites(db, rndr)) + r.Post("/", handleSites(db, rndr)) + r.Get("/new", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + payload := newTemplateData(r) + payload.Site = &sws.Site{} + if err := rndr.Render(w, "site", payload); err != nil { + httpError(w, 500, err.Error()) + return + } + })) + r.Route("/{siteID}", func(r chi.Router) { + // Populate contect with site if present + r.Use(getSiteCtx(db)) + r.Get("/", handleSite(db, rndr)) + r.Post("/", handleSite(db, rndr)) + r.Get("/edit", handleSiteEdit(db, rndr)) + r.Route("/sparklines", func(r chi.Router) { + r.Get("/{b:\\d+}-{e:\\d+}.svg", sparklineHandler(db)) + }) + r.Route("/charts", func(r chi.Router) { + r.Get("/{b:\\d+}-{e:\\d+}.svg", svgChartHandler(db)) + //r.Get("/{b:\\d+}-{e:\\d+}.png", svgChartHandler(db)) + }) + }) + }) + r.Route("/users", func(r chi.Router) { + userH := handleUsers(db, rndr) + r.Route("/{email}", func(r chi.Router) { + r.Get("/", userH) + r.Post("/", userH) + }) + }) + }) + }) + + // Example + r.Get("/test.html", handleExample(rndr)) + return r, nil +} + +func getUserCtx(db sws.UserStore) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + next.ServeHTTP(w, r) + }() + + _, claims, err := jwtauth.FromContext(r.Context()) + if err != nil { + log("failed to extract user from context", err) + return + } + + // Token is authenticated, get claims + + id, ok := claims["user_id"] + if !ok { + log("missing user ID") + return + } + + user, err := db.GetUserByID(int(id.(float64))) + if err != nil { + log("missing user") + return + } + debug("found user, adding to context") + ctx := context.WithValue(r.Context(), "user", user) + r = r.WithContext(ctx) + }) + } +} + +func getSiteCtx(db sws.SiteStore) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + id, err := strconv.Atoi(chi.URLParam(r, "siteID")) + if err != nil { + panic(err) + } + site, err := db.GetSiteByID(id) + if err != nil { + http.Error(w, http.StatusText(404), 404) + return + } + ctx := context.WithValue(r.Context(), "site", site) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} diff --git a/cmd/server/sites.go b/cmd/server/sites.go index f7e7793..46bcdf3 100644 --- a/cmd/server/sites.go +++ b/cmd/server/sites.go @@ -2,12 +2,29 @@ package main import ( "net/http" + "strings" "src.userspace.com.au/sws" ) func handleSites(db sws.SiteStore, rndr Renderer) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + if r.Method == "POST" { + site := &sws.Site{ + Name: r.FormValue("name"), + Description: r.FormValue("description"), + Aliases: r.FormValue("aliases"), + } + if errs := site.Validate(); len(errs) > 0 { + log("invalid site:", errs) + r = flashSet(r, flashError, strings.Join(errs, "<br>")) + } else if err := db.SaveSite(site); err != nil { + httpError(w, 500, err.Error()) + return + } + r = flashSet(r, flashSuccess, "site created") + } + sites, err := db.GetSites() if err != nil { httpError(w, 500, err.Error()) @@ -32,6 +49,10 @@ func handleSite(db sws.SiteStore, rndr Renderer) http.HandlerFunc { httpError(w, 422, "no site in context") return } + + payload := newTemplateData(r) + payload.Site = site + begin, end := extractTimeRange(r) if begin == nil || end == nil { httpError(w, 406, "invalid time range") @@ -52,23 +73,58 @@ func handleSite(db sws.SiteStore, rndr Renderer) http.HandlerFunc { httpError(w, 406, err.Error()) return } - hitSet.Fill(begin, end) - hitSet.SortByDate() + if hitSet != nil { + hitSet.Fill(begin, end) + hitSet.SortByDate() + payload.Hits = hitSet - pageSet, err := sws.NewPageSet(hitSet) - if err != nil { - httpError(w, 406, err.Error()) + pageSet, err := sws.NewPageSet(hitSet) + if err != nil { + httpError(w, 406, err.Error()) + return + } + + if pageSet != nil { + pageSet.SortByHits() + payload.PageSet = pageSet + } + browserSet := sws.NewBrowserSet(hitSet) + payload.Browsers = browserSet + } + + if err := rndr.Render(w, "site", payload); err != nil { + httpError(w, 500, err.Error()) return } - pageSet.SortByHits() + } +} + +func handleSiteEdit(db sws.SiteStore, rndr Renderer) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + site, ok := ctx.Value("site").(*sws.Site) + if !ok { + httpError(w, 422, "no site in context") + return + } + + if r.Method == "POST" { + site.Name = r.FormValue("name") + site.Description = r.FormValue("description") + site.Aliases = r.FormValue("aliases") - browserSet := sws.NewBrowserSet(hitSet) + if errs := site.Validate(); len(errs) > 0 { + log("invalid site:", errs) + r = flashSet(r, flashError, strings.Join(errs, "<br>")) + } else if err := db.SaveSite(site); err != nil { + httpError(w, 500, err.Error()) + return + } + r = flashSet(r, flashSuccess, "site updated") + } payload := newTemplateData(r) payload.Site = site - payload.PageSet = &pageSet - payload.Browsers = &browserSet - payload.Hits = hitSet if err := rndr.Render(w, "site", payload); err != nil { httpError(w, 500, err.Error()) diff --git a/cmd/server/users.go b/cmd/server/users.go new file mode 100644 index 0000000..4218742 --- /dev/null +++ b/cmd/server/users.go @@ -0,0 +1,62 @@ +package main + +import ( + "net/http" + "strings" + + "github.com/go-chi/chi" + "src.userspace.com.au/sws" +) + +func handleUsers(db sws.UserStore, rndr Renderer) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + authedUser, ok := ctx.Value("user").(*sws.User) + if !ok { + httpError(w, 422, "no user in context") + return + } + email := chi.URLParam(r, "email") + if email == "" { + httpError(w, http.StatusBadRequest, http.StatusText(http.StatusBadRequest)) + return + } + user, err := db.GetUserByEmail(email) + if err != nil { + http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) + return + } + if *authedUser.ID != *user.ID && !authedUser.Admin { + log("failed attempt to edit user", *user.ID, *authedUser.ID) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + + if r.Method == "POST" { + user.FirstName = stringPtr(r.FormValue("first_name")) + user.LastName = stringPtr(r.FormValue("last_name")) + user.Email = stringPtr(r.FormValue("email")) + user.PasswordConfirm = r.FormValue("password_confirmation") + user.Password = r.FormValue("password") + if errs := user.Validate(); len(errs) > 0 { + log("invalid user:", errs) + r = flashSet(r, flashError, strings.Join(errs, "<br>")) + } else { + if err := db.SaveUser(user); err != nil { + log("failed to update user:", err) + r = flashSet(r, flashError, err.Error()) + } else { + r = flashSet(r, flashSuccess, "successfully updated") + } + } + } + + payload := newTemplateData(r) + payload.User = user + + if err := rndr.Render(w, "user", payload); err != nil { + httpError(w, 500, err.Error()) + return + } + } +} |
