aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFelix Hanley <felix@userspace.com.au>2020-03-04 01:28:17 +0000
committerFelix Hanley <felix@userspace.com.au>2020-03-04 01:28:17 +0000
commit9d687611493d2af60f6aedb3503dd85f6b3d49d6 (patch)
tree214dad63ca36551dd2ba0077ee8335602fd4e423
parent1fbf1e1e1edc16e9a08fdb0d2012ca10f695f7ca (diff)
downloadsws-9d687611493d2af60f6aedb3503dd85f6b3d49d6.tar.gz
sws-9d687611493d2af60f6aedb3503dd85f6b3d49d6.tar.bz2
Refactor auth, template payloads and validation
-rw-r--r--browser_set.go3
-rw-r--r--cmd/server/auth.go122
-rw-r--r--cmd/server/handlers.go34
-rw-r--r--cmd/server/helpers.go4
-rw-r--r--cmd/server/hits.go16
-rw-r--r--cmd/server/main.go146
-rw-r--r--cmd/server/routes.go186
-rw-r--r--cmd/server/sites.go76
-rw-r--r--cmd/server/users.go62
-rw-r--r--counter/sws.js2
-rw-r--r--page_set.go3
-rw-r--r--site.go19
-rw-r--r--sql/sqlite3/01_sites.sql8
-rw-r--r--static/default.css2
-rw-r--r--store.go3
-rw-r--r--store/sqlite3.go25
-rw-r--r--templates/flash.tmpl8
-rw-r--r--templates/login.tmpl4
-rw-r--r--templates/site.tmpl128
-rw-r--r--templates/user.tmpl39
-rw-r--r--user.go52
21 files changed, 645 insertions, 297 deletions
diff --git a/browser_set.go b/browser_set.go
index c316480..a46ccd2 100644
--- a/browser_set.go
+++ b/browser_set.go
@@ -43,6 +43,9 @@ func NewBrowserSet(hs *HitSet) BrowserSet {
//b.hitSet.Add(h)
tmp[browser] = b
}
+ if len(tmp) < 1 {
+ return nil
+ }
out := make([]*Browser, len(tmp))
i := 0
for _, b := range tmp {
diff --git a/cmd/server/auth.go b/cmd/server/auth.go
index ed0c75f..e9ded0d 100644
--- a/cmd/server/auth.go
+++ b/cmd/server/auth.go
@@ -16,71 +16,48 @@ const (
func handleLogin(db sws.UserStore, rndr Renderer) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
- email := r.PostFormValue("email")
- password := r.PostFormValue("password")
+ if r.Method == "POST" {
+ var user *sws.User
+ r, user = authUser(db, r)
+ if user != nil {
+ expiry := time.Now().Add(time.Hour)
- if email == "" || password == "" {
- //httpError(w, 406, "bad auth")
- r = flashSet(r, flashError, "invalid credentials")
- http.Redirect(w, flashQuery(r), loginURL, http.StatusSeeOther)
- return
- }
-
- debug("authing email", email)
-
- user, err := db.GetUserByEmail(email)
- if err != nil || user == nil {
- //httpError(w, 404, err.Error())
- r = flashSet(r, flashError, "invalid user")
- http.Redirect(w, flashQuery(r), loginURL, http.StatusSeeOther)
- return
- }
+ _, t, err := tokenAuth.Encode(jwt.MapClaims{
+ "user_id": *user.ID,
+ "exp": expiry.Unix(),
+ })
+ if err != nil {
+ log("failed to encode claims:", err)
+ r = flashSet(r, flashError, "internal error")
+ http.Redirect(w, r, flashURL(r, "/"), http.StatusSeeOther)
+ return
+ }
- if !user.Enabled {
- debug("user", email, "is disabled")
- //httpError(w, 403, "forbidden")
- r = flashSet(r, flashError, "access denied")
- http.Redirect(w, flashQuery(r), loginURL, http.StatusSeeOther)
- return
+ http.SetCookie(w, &http.Cookie{
+ Name: "jwt",
+ Value: t,
+ HttpOnly: true,
+ Path: "/",
+ //Secure: true,
+ Expires: expiry,
+ })
+ r = r.WithContext(context.WithValue(r.Context(), "user", user))
+ r = flashSet(r, flashSuccess, "authenticated successfully")
+ qs := r.URL.Query()
+ if returnPath := qs.Get("return_to"); returnPath != "" {
+ qs.Del("return_to")
+ r.URL.RawQuery = qs.Encode()
+ debug("redirecting to", returnPath)
+ http.Redirect(w, r, flashURL(r, returnPath), http.StatusSeeOther)
+ }
+ http.Redirect(w, r, flashURL(r, "/sites"), http.StatusSeeOther)
+ }
}
- if err := user.ValidPassword(password); err != nil {
- //httpError(w, 401, err.Error())
- r = flashSet(r, flashError, "authentication failed")
- http.Redirect(w, flashQuery(r), loginURL, http.StatusSeeOther)
- return
- }
- debug("user", email, "is authed")
-
- expiry := time.Now().Add(time.Hour)
-
- _, t, err := tokenAuth.Encode(jwt.MapClaims{
- "user_id": *user.ID,
- "exp": expiry.Unix(),
- })
- if err != nil {
+ payload := newTemplateData(r)
+ if err := rndr.Render(w, "login", payload); err != nil {
httpError(w, 500, err.Error())
- return
}
-
- http.SetCookie(w, &http.Cookie{
- Name: "jwt",
- Value: t,
- HttpOnly: true,
- Path: "/",
- //Secure: true,
- Expires: expiry,
- })
- r = r.WithContext(context.WithValue(r.Context(), "user", user))
- r = flashSet(r, flashSuccess, "authenticated successfully")
- qs := r.URL.Query()
- if returnPath := qs.Get("return_to"); returnPath != "" {
- qs.Del("return_to")
- r.URL.RawQuery = qs.Encode()
- debug("redirecting to", returnPath)
- http.Redirect(w, r, flashURL(r, returnPath), http.StatusSeeOther)
- }
- http.Redirect(w, r, flashURL(r, "/sites"), http.StatusSeeOther)
}
}
@@ -107,3 +84,30 @@ func authRedirect(w http.ResponseWriter, r *http.Request, msg string) {
r.URL.RawQuery = qs.Encode()
http.Redirect(w, r, flashURL(r, loginURL), http.StatusSeeOther)
}
+
+func authUser(db sws.UserStore, r *http.Request) (*http.Request, *sws.User) {
+ email := r.PostFormValue("email")
+ password := r.PostFormValue("password")
+
+ if email == "" || password == "" {
+ return flashSet(r, flashError, "invalid credentials"), nil
+ }
+
+ debug("authing email", email)
+
+ user, err := db.GetUserByEmail(email)
+ if err != nil || user == nil {
+ return flashSet(r, flashError, "invalid user"), nil
+ }
+
+ if !user.Enabled {
+ debug("user", email, "is disabled")
+ return flashSet(r, flashError, "access denied"), nil
+ }
+
+ if err := user.ValidPassword(password); err != nil {
+ return flashSet(r, flashError, "authentication failed"), nil
+ }
+ debug("user", email, "is authed")
+ return r, user
+}
diff --git a/cmd/server/handlers.go b/cmd/server/handlers.go
index e5ab385..18ee0fa 100644
--- a/cmd/server/handlers.go
+++ b/cmd/server/handlers.go
@@ -1,28 +1,50 @@
package main
import (
+ "html/template"
"net/http"
+ "strings"
"time"
"src.userspace.com.au/sws"
)
type templateData struct {
+ Payload string
+ Endpoint string
User *sws.User
- Flashes []flashMsg
+ Flash template.HTML
Begin *time.Time
End *time.Time
Site *sws.Site
Sites []*sws.Site
- PageSet *sws.PageSet
- Browsers *sws.BrowserSet
+ PageSet sws.PageSet
+ Browsers sws.BrowserSet
Hits *sws.HitSet
}
func newTemplateData(r *http.Request) *templateData {
- out := &templateData{Flashes: flashGet(r)}
- if user := r.Context().Value("user"); user != nil {
- out.User = user.(*sws.User)
+ out := &templateData{
+ Payload: "//" + *domain + "/sws.js",
+ Endpoint: "//" + *domain + "/sws.gif",
+ }
+ if r != nil {
+ flashes := flashGet(r)
+ var flash strings.Builder
+ for _, f := range flashes {
+ flash.WriteString(`<span class="`)
+ flash.WriteString(string(f.Level))
+ flash.WriteString(`">`)
+ flash.WriteString(f.Message)
+ flash.WriteString("</span>")
+ }
+ if len(flashes) > 0 {
+ out.Flash = template.HTML(flash.String())
+ }
+
+ if user := r.Context().Value("user"); user != nil {
+ out.User = user.(*sws.User)
+ }
}
return out
}
diff --git a/cmd/server/helpers.go b/cmd/server/helpers.go
index d17fa43..32e5cf3 100644
--- a/cmd/server/helpers.go
+++ b/cmd/server/helpers.go
@@ -60,6 +60,10 @@ func extractTimeRange(r *http.Request) (*time.Time, *time.Time) {
return begin, end
}
+func stringPtr(s string) *string {
+ return &s
+}
+
func timePtr(t time.Time) *time.Time {
return &t
}
diff --git a/cmd/server/hits.go b/cmd/server/hits.go
index af06757..9adb455 100644
--- a/cmd/server/hits.go
+++ b/cmd/server/hits.go
@@ -14,15 +14,14 @@ import (
)
const (
- endpoint = "//stats.userspace.com.au/sws.gif"
- gif = "R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7"
+ gif = "R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7"
)
-func handleHits(db sws.HitStore) http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
- return
- }
-}
+// func handleHits(db sws.HitStore) http.HandlerFunc {
+// return func(w http.ResponseWriter, r *http.Request) {
+// return
+// }
+// }
func handleHitCounter(db sws.CounterStore) http.HandlerFunc {
gifBytes, err := base64.StdEncoding.DecodeString(gif)
@@ -65,9 +64,8 @@ func handleCounter(addr string) http.HandlerFunc {
if err != nil || tmpl == nil {
panic(err)
}
- data := map[string]string{"endpoint": endpoint}
var buf bytes.Buffer
- if err := tmpl.Execute(&buf, data); err != nil {
+ if err := tmpl.Execute(&buf, newTemplateData(nil)); err != nil {
panic(err)
}
etag := fmt.Sprintf(`"%x"`, sha1.Sum(buf.Bytes()))
diff --git a/cmd/server/main.go b/cmd/server/main.go
index fc9515c..4e98b3b 100644
--- a/cmd/server/main.go
+++ b/cmd/server/main.go
@@ -1,32 +1,23 @@
package main
import (
- "bytes"
- "context"
"flag"
"fmt"
"net/http"
"os"
- "path/filepath"
- "strconv"
"strings"
"time"
- "github.com/go-chi/chi"
- "github.com/go-chi/chi/middleware"
- "github.com/go-chi/jwtauth"
_ "github.com/jackc/pgx/stdlib"
"github.com/jmoiron/sqlx"
_ "github.com/mattn/go-sqlite3"
"src.userspace.com.au/sws"
"src.userspace.com.au/sws/store"
- "src.userspace.com.au/templates"
)
var (
Version string
log, debug sws.Logger
- tokenAuth *jwtauth.JWTAuth
)
// Flags
@@ -34,6 +25,7 @@ var (
verbose *bool
addr *string
dsn *string
+ domain *string
noMigrate *bool
)
@@ -41,13 +33,13 @@ func init() {
verbose = boolFlag("verbose", "v", false, "VERBOSE", "enable verbose output")
addr = stringFlag("listen", "l", "localhost:5000", "LISTEN", "listen address")
dsn = stringFlag("dsn", "", "file:sws.db?cache=shared", "DSN", "database password")
+ domain = stringFlag("domain", "", "stats.userspace.com.au", "DOMAIN", "stats domain")
noMigrate = boolFlag("no-migrate", "m", false, "NOMIGRATE", "disable migrations")
// Default to no log
log = func(v ...interface{}) {}
debug = func(v ...interface{}) {}
- tokenAuth = jwtauth.New("HS256", []byte("lkjasd0f9u203ijsldkfj"), nil)
}
type Renderer interface {
@@ -102,144 +94,12 @@ func main() {
st = store.NewSqlite3Store(db)
}
- tmplsCommon := []string{"flash.tmpl", "navbar.tmpl"}
- tmplsAuthed := append(tmplsCommon, []string{"layouts/base.tmpl", "charts.tmpl"}...)
- tmplsPublic := append(tmplsCommon, "layouts/public.tmpl")
-
- tmpls, err := LoadHTMLTemplateMap(map[string][]string{
- "sites": append([]string{"sites.tmpl"}, tmplsAuthed...),
- "site": append([]string{"site.tmpl"}, tmplsAuthed...),
- "home": append([]string{"home.tmpl"}, tmplsPublic...),
- "login": append([]string{"login.tmpl"}, tmplsPublic...),
- "example": []string{"example.tmpl"},
- }, funcMap)
+ r, err := createRouter(st)
if err != nil {
log(err)
os.Exit(1)
}
- debug(tmpls["login"].DefinedTemplates())
- debug(tmpls["home"].DefinedTemplates())
- renderer := templates.NewRenderer(tmpls)
-
- r := chi.NewRouter()
- r.Use(middleware.RealIP)
- r.Use(middleware.RequestID)
- r.Use(middleware.RequestID)
- compressor := middleware.NewCompressor(5, "text/html", "text/css")
- r.Use(compressor.Handler())
- if *verbose {
- r.Use(middleware.Logger)
- }
- r.Use(middleware.Recoverer)
-
- siteCtx := getSiteCtx(st)
- userCtx := getUserCtx(st)
-
- // For counter
- r.Get("/sws.js", handleCounter(*addr))
- r.Get("/sws.gif", handleHitCounter(st))
-
- // For UI
- r.Get("/hits", handleHits(st))
-
- // Public routes
- r.Get("/", handleIndex(renderer))
- r.Get(loginURL, func(w http.ResponseWriter, r *http.Request) {
- payload := newTemplateData(r)
- if err := renderer.Render(w, "login", payload); err != nil {
- httpError(w, 500, err.Error())
- return
- }
- return
- })
-
- r.Post(loginURL, handleLogin(st, renderer))
-
- r.Get("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- p := strings.TrimPrefix(r.URL.Path, "/")
- if b, err := StaticLoadTemplate(p); err == nil {
- name := filepath.Base(p)
- http.ServeContent(w, r, name, time.Now(), bytes.NewReader(b))
- }
- }))
-
- // Authed routes
- r.Group(func(r chi.Router) {
- r.Use(jwtauth.Verifier(tokenAuth))
- r.Use(userCtx)
- r.Get(logoutURL, handleLogout(renderer))
- r.Route("/sites", func(r chi.Router) {
- r.Get("/", handleSites(st, renderer))
- r.Route("/{siteID}", func(r chi.Router) {
- r.Use(siteCtx)
- r.Get("/", handleSite(st, renderer))
- r.Route("/sparklines", func(r chi.Router) {
- r.Get("/{b:\\d+}-{e:\\d+}.svg", sparklineHandler(st))
- })
- r.Route("/charts", func(r chi.Router) {
- r.Get("/{b:\\d+}-{e:\\d+}.svg", svgChartHandler(st))
- r.Get("/{b:\\d+}-{e:\\d+}.png", svgChartHandler(st))
- })
- })
- })
- })
-
- // Example
- r.Get("/test.html", handleExample(renderer))
log("listening at", *addr)
http.ListenAndServe(*addr, r)
}
-
-func getUserCtx(db sws.UserStore) func(http.Handler) http.Handler {
- return func(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- token, claims, err := jwtauth.FromContext(r.Context())
-
- if err != nil {
- authRedirect(w, r, "token error")
- return
- }
-
- if token == nil || !token.Valid {
- authRedirect(w, r, "invalid token")
- return
- }
-
- // Token is authenticated, get claims
-
- id, ok := claims["user_id"]
- if !ok {
- authRedirect(w, r, "missing user ID")
- return
- }
-
- user, err := db.GetUserByID(int(id.(float64)))
- if err != nil {
- authRedirect(w, r, "missing user")
- return
- }
- debug("found user, adding to context")
- ctx := context.WithValue(r.Context(), "user", user)
- next.ServeHTTP(w, r.WithContext(ctx))
- })
- }
-}
-
-func getSiteCtx(db sws.SiteStore) func(http.Handler) http.Handler {
- return func(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- id, err := strconv.Atoi(chi.URLParam(r, "siteID"))
- if err != nil {
- panic(err)
- }
- site, err := db.GetSiteByID(id)
- if err != nil {
- http.Error(w, http.StatusText(404), 404)
- return
- }
- ctx := context.WithValue(r.Context(), "site", site)
- next.ServeHTTP(w, r.WithContext(ctx))
- })
- }
-}
diff --git a/cmd/server/routes.go b/cmd/server/routes.go
new file mode 100644
index 0000000..7d7ab32
--- /dev/null
+++ b/cmd/server/routes.go
@@ -0,0 +1,186 @@
+package main
+
+import (
+ "bytes"
+ "context"
+ "net/http"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/go-chi/chi"
+ "github.com/go-chi/chi/middleware"
+ "github.com/go-chi/jwtauth"
+ "src.userspace.com.au/sws"
+ "src.userspace.com.au/templates"
+)
+
+var tokenAuth *jwtauth.JWTAuth
+
+func init() {
+ tokenAuth = jwtauth.New("HS256", []byte("lkjasd0f9u203ijsldkfj"), nil)
+}
+
+func createRouter(db sws.Store) (chi.Router, error) {
+ tmplsCommon := []string{"flash.tmpl", "navbar.tmpl"}
+ tmplsAuthed := append(tmplsCommon, []string{"layouts/base.tmpl", "charts.tmpl"}...)
+ tmplsPublic := append(tmplsCommon, "layouts/public.tmpl")
+
+ tmpls, err := LoadHTMLTemplateMap(map[string][]string{
+ "sites": append([]string{"sites.tmpl"}, tmplsAuthed...),
+ "site": append([]string{"site.tmpl"}, tmplsAuthed...),
+ "home": append([]string{"home.tmpl"}, tmplsPublic...),
+ "login": append([]string{"login.tmpl"}, tmplsPublic...),
+ "user": append([]string{"user.tmpl"}, tmplsAuthed...),
+ "example": []string{"example.tmpl"},
+ }, funcMap)
+ if err != nil {
+ return nil, err
+ }
+ debug(tmpls["login"].DefinedTemplates())
+ debug(tmpls["home"].DefinedTemplates())
+
+ rndr := templates.NewRenderer(tmpls)
+
+ r := chi.NewRouter()
+ r.Use(middleware.RealIP)
+ r.Use(middleware.RequestID)
+ r.Use(middleware.RequestID)
+ compressor := middleware.NewCompressor(5, "text/html", "text/css")
+ r.Use(compressor.Handler())
+ if *verbose {
+ r.Use(middleware.Logger)
+ }
+ r.Use(middleware.Recoverer)
+
+ // For counter
+ r.Get("/sws.js", handleCounter(*addr))
+ r.Get("/sws.gif", handleHitCounter(db))
+ //r.Get("/hits", handleHits(db))
+
+ r.Group(func(r chi.Router) {
+ r.Use(jwtauth.Verifier(tokenAuth))
+ // Populate contect with user if present
+ r.Use(getUserCtx(db))
+
+ // Public routes
+ r.Group(func(r chi.Router) {
+ r.Get("/", handleIndex(rndr))
+ r.Get(loginURL, handleLogin(db, rndr))
+ })
+
+ r.Post(loginURL, handleLogin(db, rndr))
+
+ r.Get("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ p := strings.TrimPrefix(r.URL.Path, "/")
+ if b, err := StaticLoadTemplate(p); err == nil {
+ name := filepath.Base(p)
+ http.ServeContent(w, r, name, time.Now(), bytes.NewReader(b))
+ }
+ }))
+
+ // Authed routes
+ r.Group(func(r chi.Router) {
+ // Ensure we have a user in context
+ r.Use(func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if user := r.Context().Value("user"); user == nil {
+ authRedirect(w, r, "authentication required")
+ }
+ next.ServeHTTP(w, r)
+ })
+ })
+
+ r.Get(logoutURL, handleLogout(rndr))
+ r.Route("/sites", func(r chi.Router) {
+ r.Get("/", handleSites(db, rndr))
+ r.Post("/", handleSites(db, rndr))
+ r.Get("/new", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ payload := newTemplateData(r)
+ payload.Site = &sws.Site{}
+ if err := rndr.Render(w, "site", payload); err != nil {
+ httpError(w, 500, err.Error())
+ return
+ }
+ }))
+ r.Route("/{siteID}", func(r chi.Router) {
+ // Populate contect with site if present
+ r.Use(getSiteCtx(db))
+ r.Get("/", handleSite(db, rndr))
+ r.Post("/", handleSite(db, rndr))
+ r.Get("/edit", handleSiteEdit(db, rndr))
+ r.Route("/sparklines", func(r chi.Router) {
+ r.Get("/{b:\\d+}-{e:\\d+}.svg", sparklineHandler(db))
+ })
+ r.Route("/charts", func(r chi.Router) {
+ r.Get("/{b:\\d+}-{e:\\d+}.svg", svgChartHandler(db))
+ //r.Get("/{b:\\d+}-{e:\\d+}.png", svgChartHandler(db))
+ })
+ })
+ })
+ r.Route("/users", func(r chi.Router) {
+ userH := handleUsers(db, rndr)
+ r.Route("/{email}", func(r chi.Router) {
+ r.Get("/", userH)
+ r.Post("/", userH)
+ })
+ })
+ })
+ })
+
+ // Example
+ r.Get("/test.html", handleExample(rndr))
+ return r, nil
+}
+
+func getUserCtx(db sws.UserStore) func(http.Handler) http.Handler {
+ return func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ defer func() {
+ next.ServeHTTP(w, r)
+ }()
+
+ _, claims, err := jwtauth.FromContext(r.Context())
+ if err != nil {
+ log("failed to extract user from context", err)
+ return
+ }
+
+ // Token is authenticated, get claims
+
+ id, ok := claims["user_id"]
+ if !ok {
+ log("missing user ID")
+ return
+ }
+
+ user, err := db.GetUserByID(int(id.(float64)))
+ if err != nil {
+ log("missing user")
+ return
+ }
+ debug("found user, adding to context")
+ ctx := context.WithValue(r.Context(), "user", user)
+ r = r.WithContext(ctx)
+ })
+ }
+}
+
+func getSiteCtx(db sws.SiteStore) func(http.Handler) http.Handler {
+ return func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ id, err := strconv.Atoi(chi.URLParam(r, "siteID"))
+ if err != nil {
+ panic(err)
+ }
+ site, err := db.GetSiteByID(id)
+ if err != nil {
+ http.Error(w, http.StatusText(404), 404)
+ return
+ }
+ ctx := context.WithValue(r.Context(), "site", site)
+ next.ServeHTTP(w, r.WithContext(ctx))
+ })
+ }
+}
diff --git a/cmd/server/sites.go b/cmd/server/sites.go
index f7e7793..46bcdf3 100644
--- a/cmd/server/sites.go
+++ b/cmd/server/sites.go
@@ -2,12 +2,29 @@ package main
import (
"net/http"
+ "strings"
"src.userspace.com.au/sws"
)
func handleSites(db sws.SiteStore, rndr Renderer) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
+ if r.Method == "POST" {
+ site := &sws.Site{
+ Name: r.FormValue("name"),
+ Description: r.FormValue("description"),
+ Aliases: r.FormValue("aliases"),
+ }
+ if errs := site.Validate(); len(errs) > 0 {
+ log("invalid site:", errs)
+ r = flashSet(r, flashError, strings.Join(errs, "<br>"))
+ } else if err := db.SaveSite(site); err != nil {
+ httpError(w, 500, err.Error())
+ return
+ }
+ r = flashSet(r, flashSuccess, "site created")
+ }
+
sites, err := db.GetSites()
if err != nil {
httpError(w, 500, err.Error())
@@ -32,6 +49,10 @@ func handleSite(db sws.SiteStore, rndr Renderer) http.HandlerFunc {
httpError(w, 422, "no site in context")
return
}
+
+ payload := newTemplateData(r)
+ payload.Site = site
+
begin, end := extractTimeRange(r)
if begin == nil || end == nil {
httpError(w, 406, "invalid time range")
@@ -52,23 +73,58 @@ func handleSite(db sws.SiteStore, rndr Renderer) http.HandlerFunc {
httpError(w, 406, err.Error())
return
}
- hitSet.Fill(begin, end)
- hitSet.SortByDate()
+ if hitSet != nil {
+ hitSet.Fill(begin, end)
+ hitSet.SortByDate()
+ payload.Hits = hitSet
- pageSet, err := sws.NewPageSet(hitSet)
- if err != nil {
- httpError(w, 406, err.Error())
+ pageSet, err := sws.NewPageSet(hitSet)
+ if err != nil {
+ httpError(w, 406, err.Error())
+ return
+ }
+
+ if pageSet != nil {
+ pageSet.SortByHits()
+ payload.PageSet = pageSet
+ }
+ browserSet := sws.NewBrowserSet(hitSet)
+ payload.Browsers = browserSet
+ }
+
+ if err := rndr.Render(w, "site", payload); err != nil {
+ httpError(w, 500, err.Error())
return
}
- pageSet.SortByHits()
+ }
+}
+
+func handleSiteEdit(db sws.SiteStore, rndr Renderer) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ ctx := r.Context()
+ site, ok := ctx.Value("site").(*sws.Site)
+ if !ok {
+ httpError(w, 422, "no site in context")
+ return
+ }
+
+ if r.Method == "POST" {
+ site.Name = r.FormValue("name")
+ site.Description = r.FormValue("description")
+ site.Aliases = r.FormValue("aliases")
- browserSet := sws.NewBrowserSet(hitSet)
+ if errs := site.Validate(); len(errs) > 0 {
+ log("invalid site:", errs)
+ r = flashSet(r, flashError, strings.Join(errs, "<br>"))
+ } else if err := db.SaveSite(site); err != nil {
+ httpError(w, 500, err.Error())
+ return
+ }
+ r = flashSet(r, flashSuccess, "site updated")
+ }
payload := newTemplateData(r)
payload.Site = site
- payload.PageSet = &pageSet
- payload.Browsers = &browserSet
- payload.Hits = hitSet
if err := rndr.Render(w, "site", payload); err != nil {
httpError(w, 500, err.Error())
diff --git a/cmd/server/users.go b/cmd/server/users.go
new file mode 100644
index 0000000..4218742
--- /dev/null
+++ b/cmd/server/users.go
@@ -0,0 +1,62 @@
+package main
+
+import (
+ "net/http"
+ "strings"
+
+ "github.com/go-chi/chi"
+ "src.userspace.com.au/sws"
+)
+
+func handleUsers(db sws.UserStore, rndr Renderer) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ ctx := r.Context()
+ authedUser, ok := ctx.Value("user").(*sws.User)
+ if !ok {
+ httpError(w, 422, "no user in context")
+ return
+ }
+ email := chi.URLParam(r, "email")
+ if email == "" {
+ httpError(w, http.StatusBadRequest, http.StatusText(http.StatusBadRequest))
+ return
+ }
+ user, err := db.GetUserByEmail(email)
+ if err != nil {
+ http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
+ return
+ }
+ if *authedUser.ID != *user.ID && !authedUser.Admin {
+ log("failed attempt to edit user", *user.ID, *authedUser.ID)
+ http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
+ return
+ }
+
+ if r.Method == "POST" {
+ user.FirstName = stringPtr(r.FormValue("first_name"))
+ user.LastName = stringPtr(r.FormValue("last_name"))
+ user.Email = stringPtr(r.FormValue("email"))
+ user.PasswordConfirm = r.FormValue("password_confirmation")
+ user.Password = r.FormValue("password")
+ if errs := user.Validate(); len(errs) > 0 {
+ log("invalid user:", errs)
+ r = flashSet(r, flashError, strings.Join(errs, "<br>"))
+ } else {
+ if err := db.SaveUser(user); err != nil {
+ log("failed to update user:", err)
+ r = flashSet(r, flashError, err.Error())
+ } else {
+ r = flashSet(r, flashSuccess, "successfully updated")
+ }
+ }
+ }
+
+ payload := newTemplateData(r)
+ payload.User = user
+
+ if err := rndr.Render(w, "user", payload); err != nil {
+ httpError(w, 500, err.Error())
+ return
+ }
+ }
+}
diff --git a/counter/sws.js b/counter/sws.js
index 8c4223d..5145cee 100644
--- a/counter/sws.js
+++ b/counter/sws.js
@@ -46,7 +46,7 @@ var viewPort = (w.innerWidth || de.clientWidth || d.body.clientWidth) + 'x' +
ready(function () {
if (!_sws.noauto) {
var ep = new URL(_sws.d)
- count('{{ .endpoint }}', {
+ count('{{ .Endpoint }}', {
i: _sws.site,
s: l.protocol,
h: l.host,
diff --git a/page_set.go b/page_set.go
index d97c963..9077731 100644
--- a/page_set.go
+++ b/page_set.go
@@ -28,6 +28,9 @@ func NewPageSet(hs *HitSet) (PageSet, error) {
//p.hitSet.Add(h)
tmp[h.Path] = p
}
+ if len(tmp) < 1 {
+ return nil, nil
+ }
out := make([]*Page, len(tmp))
i := 0
for _, p := range tmp {
diff --git a/site.go b/site.go
index 892a6e5..bb23f8a 100644
--- a/site.go
+++ b/site.go
@@ -1,27 +1,26 @@
package sws
import (
- "fmt"
"time"
)
const slugSalt = "saltyslugs"
type Site struct {
- ID *int `json:"id,omitempty"`
- Name *string `json:"name,omitempty"`
- Description string `json:"description,omitempty"`
- Aliases *string `json:"aliases,omitempty"`
- Enabled bool `json:"enabled"`
+ ID *int `json:"id,omitempty"`
+ Name string `json:"name,omitempty"`
+ Description string `json:"description,omitempty"`
+ Aliases string `json:"aliases,omitempty"`
+ Enabled bool `json:"enabled"`
//ExcludePaths []string
CreatedAt *time.Time `json:"created_at,omitempty" db:"created_at"`
UpdatedAt *time.Time `json:"updated_at,omitempty" db:"updated_at"`
}
-func (d *Site) Validate() []error {
- var out []error
- if d.Name == nil {
- out = append(out, fmt.Errorf("missing name"))
+func (d *Site) Validate() []string {
+ var out []string
+ if d.Name == "" {
+ out = append(out, "missing name")
}
return out
}
diff --git a/sql/sqlite3/01_sites.sql b/sql/sqlite3/01_sites.sql
index 6422e94..59d336a 100644
--- a/sql/sqlite3/01_sites.sql
+++ b/sql/sqlite3/01_sites.sql
@@ -1,12 +1,12 @@
create table sites (
id integer primary key autoincrement,
name varchar not null check(length(name) >= 4 and length(name) <= 255),
- description varchar null,
- aliases varchar null,
+ description varchar not null,
+ aliases varchar not null,
enabled integer not null default 0,
created_at timestamp not null,
updated_at timestamp not null
);
-insert into sites (name, description, enabled, created_at, updated_at)
-values ('localhost', 'Example site', 1, date('now'), date('now'));
+insert into sites (name, description, aliases, enabled, created_at, updated_at)
+values ('localhost', 'Example site', '', 1, date('now'), date('now'));
diff --git a/static/default.css b/static/default.css
index 8080298..ffe6d8c 100644
--- a/static/default.css
+++ b/static/default.css
@@ -51,7 +51,7 @@ header.site {
height: 50px;
}
-.flashes {
+.flash {
display: block;
left: 0;
margin-left: auto;
diff --git a/store.go b/store.go
index e4808a4..a24f228 100644
--- a/store.go
+++ b/store.go
@@ -13,7 +13,7 @@ type SiteStore interface {
GetSites() ([]*Site, error)
GetSiteByID(int) (*Site, error)
GetHits(Site, map[string]interface{}) ([]*Hit, error)
- //SaveSite(*Site) error
+ SaveSite(*Site) error
}
type HitStore interface {
@@ -27,4 +27,5 @@ type CounterStore interface {
type UserStore interface {
GetUserByID(int) (*User, error)
GetUserByEmail(string) (*User, error)
+ SaveUser(*User) error
}
diff --git a/store/sqlite3.go b/store/sqlite3.go
index 33590fb..631fde2 100644
--- a/store/sqlite3.go
+++ b/store/sqlite3.go
@@ -113,6 +113,13 @@ func (s *Sqlite3) GetUserByID(id int) (*sws.User, error) {
return &u, nil
}
+func (s *Sqlite3) SaveUser(u *sws.User) error {
+ if _, err := s.db.NamedExec(stmts["saveUser"], u); err != nil {
+ return err
+ }
+ return nil
+}
+
func processFilter(sql *string, filter map[string]interface{}) {
if sql == nil {
panic("empty sql")
@@ -147,7 +154,12 @@ where id = $1 limit 1`,
"saveSite": `insert into sites (
name, description, aliases, enabled, created_at, updated_at)
-values (:name, :description, :aliases, :enabled, :created_at, :updated_at)`,
+values (:name, :description, :aliases, :enabled, date('now'), date('now'))
+on conflict(id) do update set
+name = :name,
+description = :description,
+aliases = :aliases,
+updated_at = date('now')`,
"userAgentByHash": `select id, hash, name, last_seen_at from sites
where hash = $1 limit 1`,
@@ -178,4 +190,15 @@ where email = $1`,
created_at, updated_at, last_login_at
from users
where id = $1`,
+
+ "saveUser": `insert into users
+(id, first_name, last_name, email, pw_hash, pw_salt, created_at, updated_at)
+values (:id, :first_name, :last_name, :email, :pw_hash, :pw_salt, date('now'), date('now'))
+on conflict(id) do update set
+first_name = :first_name,
+last_name = :last_name,
+email = :email,
+pw_hash = :pw_hash,
+pw_salt = :pw_salt,
+updated_at = date('now')`,
}
diff --git a/templates/flash.tmpl b/templates/flash.tmpl
index 826320a..2a5a7cf 100644
--- a/templates/flash.tmpl
+++ b/templates/flash.tmpl
@@ -1,9 +1,7 @@
{{ define "flash" }}
- {{ if .Flashes }}
- <div class="flashes">
- {{ range .Flashes }}
- <div class="flash {{ .Level }}">{{ .Message }}</div>
- {{ end }}
+ {{ if .Flash }}
+ <div class="flash">
+ {{ .Flash }}
</div>
{{ end }}
{{ end }}
diff --git a/templates/login.tmpl b/templates/login.tmpl
index f4bf395..8547d2c 100644
--- a/templates/login.tmpl
+++ b/templates/login.tmpl
@@ -4,10 +4,10 @@
<h2>Login</h2>
<form method="post" action="/login">
<div class="field">
- <input type="email" name="email" placeholder="your email" />
+ <input type="email" name="email" placeholder="your email" required />
</div>
<div class="field">
- <input type="password" name="password" placeholder="your password" />
+ <input type="password" name="password" placeholder="your password" required />
</div>
<div class="field">
<input type="submit" />
diff --git a/templates/site.tmpl b/templates/site.tmpl
index 3a2b198..fc74180 100644
--- a/templates/site.tmpl
+++ b/templates/site.tmpl
@@ -1,30 +1,81 @@
{{ define "content" }}
<main>
<header>
- {{ with .Site }}
- <h1>{{ .Name }}</h1>
- <span>{{ .Description }}</span>
+ {{ if .Site.ID }}
+ {{ with .Site }}
+ <h1>{{ .Name }}</h1>
+ <span>{{ .Description }}</span>
+ {{ end }}
+ {{ else }}
+ <h1>New Site</h1>
{{ end }}
</header>
- <fig>
- {{ template "timeBarChart" .Hits }}
- </fig>
+ {{ if .Hits }}
+ {{ template "siteView" . }}
+ {{ else }}
+ <form method="post" action="/sites{{ if .Site.ID }}/{{ .Site.ID }}{{ end }}">
+ {{ template "siteEdit" .Site }}
+ <div class="field">
+ <input type="submit" />
+ </div>
+ </form>
+ {{ template "siteConfig" . }}
+ {{ end }}
+ </main>
+{{ end }}
+
+{{ define "pageForList" }}
+ <li>
+ <h4 class="path">{{ .Path }}</h4>
+ <span class="title">{{ .Title }}</span>
+ <span class="last-visit">{{ .LastVisitedAt|timeLong }}</span>
+ <span class="count">{{ .Count }}</span>
+ </li>
+{{ end }}
+
+{{ define "browserForList" }}
+ <li>
+ <h4 class="name">{{ .Name }}</h4>
+ <span class="last-seen">{{ .LastSeenAt|timeLong }}</span>
+ <span class="count">{{ .Count }}</span>
+ </li>
+{{ end }}
+
+{{ define "siteView" }}
+ <div class="panel panel-wide">
+ <h2>Hits</h2>
+ {{ if .Hits }}
+ <fig>
+ {{ template "timeBarChart" .Hits }}
+ </fig>
+ {{ else }}
+ <p>No hits yet</p>
+ {{ end }}
+ </div>
+ <div class="panel">
<h2>Popular pages</h2>
- <fig>
- {{ template "barChartHorizontal" .PageSet }}
- </fig>
-
- <ul class="pages">
- {{ $pages := .PageSet }}
- {{ range .PageSet }}
- {{ template "pageForList" . }}
- <fig>
- {{ $pathHits := $pages.Page .Path }}
- {{ template "timeBarChart" $pathHits }}
- </fig>
- {{ end }}
- </ul>
+ {{ if .PageSet }}
+ <fig>
+ {{ template "barChartHorizontal" .PageSet }}
+ </fig>
+
+ <ul class="pages">
+ {{ $pages := .PageSet }}
+ {{ range .PageSet }}
+ {{ template "pageForList" . }}
+ <fig>
+ {{ $pathHits := $pages.Page .Path }}
+ {{ template "timeBarChart" $pathHits }}
+ </fig>
+ {{ end }}
+ </ul>
+ {{ else }}
+ <p>No page views yet</p>
+ {{ end }}
+ </div>
+
+ <div class="panel">
<h2>User Agents</h2>
{{ if .Browsers }}
<fig>
@@ -36,24 +87,31 @@
{{ end }}
</ul>
{{ else }}
- <p>No user agents</p>
+ <p>No browsers visits yet</p>
{{ end }}
- </main>
+ </div>
{{ end }}
-{{ define "pageForList" }}
-<li>
- <h4 class="path">{{ .Path }}</h4>
- <span class="title">{{ .Title }}</span>
- <span class="last-visit">{{ .LastVisitedAt|timeLong }}</span>
- <span class="count">{{ .Count }}</span>
-</li>
+{{ define "siteEdit" }}
+ <div class="field">
+ <label>Name</label>
+ <input type="text" name="name" value="{{ .Name }}" placeholder="example.com" required />
+ </div>
+ <div class="field">
+ <label>Description</label>
+ <input type="text" name="description" value="{{ .Description }}" placeholder="site description" />
+ </div>
+ <div class="field">
+ <label>Aliases</label>
+ <input type="text" name="aliases" value="{{ .Aliases }}" placeholder="www.example.com" />
+ </div>
{{ end }}
-{{ define "browserForList" }}
-<li>
- <h4 class="name">{{ .Name }}</h4>
- <span class="last-seen">{{ .LastSeenAt|timeLong }}</span>
- <span class="count">{{ .Count }}</span>
-</li>
+{{ define "siteConfig" }}
+ <p>Add the following HTML snippet to your website to collect data:</p>
+ <pre><code>&lt;!-- start SWS snippet --&gt;
+&lt;script async src="{{ .Payload }}" data-site="{{ .Site.ID }}"&gt;&lt;/script&gt;
+&lt;noscript&gt;&lt;img src="{{ .Endpoint }}" /&gt;&lt;/noscript&gt;
+&lt;!-- end SWS snippet --&gt;
+ </code></pre>
{{ end }}
diff --git a/templates/user.tmpl b/templates/user.tmpl
new file mode 100644
index 0000000..eaa1733
--- /dev/null
+++ b/templates/user.tmpl
@@ -0,0 +1,39 @@
+{{ define "content" }}
+ <main>
+ <header>
+ {{ with .User }}
+ <h1>Profile</h1>
+ <span>{{ .Email }}</span>
+ {{ end }}
+ </header>
+ <form method="post">
+ {{ template "userEdit" .User }}
+ <div class="field">
+ <input type="submit" />
+ </div>
+ </form>
+ </main>
+{{ end }}
+
+{{ define "userEdit" }}
+ <div class="field">
+ <label>Email</label>
+ <input type="email" name="email" value="{{ .Email }}" required />
+ </div>
+ <div class="field">
+ <label>First name</label>
+ <input type="text" name="first_name" value="{{ .FirstName }}" />
+ </div>
+ <div class="field">
+ <label>Last name</label>
+ <input type="text" name="last_name" value="{{ .LastName }}" />
+ </div>
+ <div class="field">
+ <label>Password</label>
+ <input type="password" name="password" />
+ </div>
+ <div class="field">
+ <label>Password confirmation</label>
+ <input type="password" name="password_confirmation" />
+ </div>
+{{ end }}
diff --git a/user.go b/user.go
index 78e3497..25b979b 100644
--- a/user.go
+++ b/user.go
@@ -11,16 +11,19 @@ import (
)
type User struct {
- ID *int `json:"id,omitempty"`
- Email *string `json:"email,omitempty"`
- FirstName *string `json:"first_name,omitempty" db:"first_name"`
- LastName *string `json:"last_name,omitempty" db:"last_name"`
- Enabled bool `json:"enabled"`
- PwHash *string `json:"pw_hash" db:"pw_hash"`
- PwSalt *string `json:"pw_salt" db:"pw_salt"`
- LastLoginAt *time.Time `json:"last_login_at" db:"last_login_at"`
- CreatedAt *time.Time `json:"created_at,omitempty" db:"created_at"`
- UpdatedAt *time.Time `json:"updated_at,omitempty" db:"updated_at"`
+ ID *int `json:"id,omitempty"`
+ Email *string `json:"email,omitempty"`
+ FirstName *string `json:"first_name,omitempty" db:"first_name"`
+ LastName *string `json:"last_name,omitempty" db:"last_name"`
+ Enabled bool `json:"enabled"`
+ PwHash *string `json:"-" db:"pw_hash"`
+ PwSalt *string `json:"-" db:"pw_salt"`
+ Password string `json:"-" db:"-"`
+ PasswordConfirm string `json:"-" db:"-"`
+ Admin bool
+ LastLoginAt *time.Time `json:"last_login_at" db:"last_login_at"`
+ CreatedAt *time.Time `json:"created_at,omitempty" db:"created_at"`
+ UpdatedAt *time.Time `json:"updated_at,omitempty" db:"updated_at"`
}
const (
@@ -30,6 +33,35 @@ const (
pwLength = 32
)
+func (u *User) Validate() []string {
+ var out []string
+
+ if u.FirstName == nil || *u.FirstName == "" {
+ out = append(out, fmt.Sprintf("invalid first name"))
+ }
+ if u.LastName == nil || *u.LastName == "" {
+ out = append(out, fmt.Sprintf("invalid last name"))
+ }
+ if u.Email == nil || *u.Email == "" {
+ out = append(out, fmt.Sprintf("invalid email"))
+ }
+ if u.PwHash == nil || *u.PwHash == "" {
+ out = append(out, fmt.Sprint("invalid password"))
+ }
+ if u.PasswordConfirm != "" {
+ if u.Password != u.PasswordConfirm {
+ out = append(out, fmt.Sprintf("password confirmation mismatch"))
+ } else {
+ if err := u.SetPassword(u.Password); err != nil {
+ out = append(out, fmt.Sprintf("failed to update password: %s", err))
+ }
+ u.Password = ""
+ u.PasswordConfirm = ""
+ }
+ }
+ return out
+}
+
func (u *User) SetPassword(pw string) error {
// Generate a Salt
saltB := make([]byte, 16)