aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFelix Hanley <felix@userspace.com.au>2020-05-12 07:00:36 +0000
committerFelix Hanley <felix@userspace.com.au>2020-05-12 07:00:36 +0000
commit848d1aa993e5bd25dd26eed67e6f7144f58ac198 (patch)
tree949afef2171793d06f9c96322f96bdc13f092d78
parentf39a00443e2785862ec2042e00b7bbf42044227d (diff)
downloadsws-848d1aa993e5bd25dd26eed67e6f7144f58ac198.tar.gz
sws-848d1aa993e5bd25dd26eed67e6f7144f58ac198.tar.bz2
Fix referrer handlingHEADmaster
-rw-r--r--cmd/server/hits.go16
-rw-r--r--cmd/server/site.go2
-rw-r--r--hit.go8
-rw-r--r--referrer.go45
-rw-r--r--referrer_test.go73
-rw-r--r--site.go20
6 files changed, 126 insertions, 38 deletions
diff --git a/cmd/server/hits.go b/cmd/server/hits.go
index 776e610..63b21c1 100644
--- a/cmd/server/hits.go
+++ b/cmd/server/hits.go
@@ -5,12 +5,11 @@ import (
"crypto/sha1"
"encoding/base64"
"fmt"
- "net"
"net/http"
"strings"
"text/template"
- "github.com/hashicorp/golang-lru"
+ lru "github.com/hashicorp/golang-lru"
"src.userspace.com.au/sws"
)
@@ -47,11 +46,6 @@ func handleHitCounter(db sws.CounterStore, mmdbPath string) http.HandlerFunc {
return
}
- hit.Addr = r.RemoteAddr
- if strings.Contains(r.RemoteAddr, ":") {
- hit.Addr, _, err = net.SplitHostPort(r.RemoteAddr)
- }
-
if r.Header.Get("X-Moz") == "prefetch" || r.Header.Get("X-Purpose") == "preview" {
w.Header().Set("Content-Type", "image/gif")
w.Write(gifBytes)
@@ -105,12 +99,8 @@ func verifyHit(db sws.SiteGetter, h *sws.Hit) (*sws.Site, error) {
debug(h.Host, "equals site name:", site.Name)
return site, nil
}
- if strings.Contains(site.Aliases, h.Host) {
- debug(h.Host, "equals site alias:", site.Name)
- return site, nil
- }
- if site.AcceptSubdomains && strings.HasSuffix(h.Host, site.Name) {
- debug(h.Host, "is subdomain:", site.Name)
+ if site.IncludesDomain(h.Host) {
+ debug(h.Host, "includes:", site.Name)
return site, nil
}
return nil, fmt.Errorf("invalid host")
diff --git a/cmd/server/site.go b/cmd/server/site.go
index 9156bc6..728d571 100644
--- a/cmd/server/site.go
+++ b/cmd/server/site.go
@@ -54,7 +54,7 @@ func handleSite(db sws.SiteStore, rndr Renderer) http.HandlerFunc {
}
}
if _, ok := filter["referrer"]; !ok {
- if rs := sws.NewReferrerSet(hitSet); rs != nil {
+ if rs := sws.NewReferrerSet(hitSet, *site); rs != nil {
rs.SortByHits()
payload.ReferrerSet = rs
}
diff --git a/hit.go b/hit.go
index 97dffef..b9e9ce2 100644
--- a/hit.go
+++ b/hit.go
@@ -2,6 +2,7 @@ package sws
import (
"fmt"
+ "net"
"net/http"
"net/url"
"sort"
@@ -61,9 +62,14 @@ func SortHits(hits []*Hit) {
}
func HitFromRequest(r *http.Request) (*Hit, error) {
+ // Strip port from remote address
+ addr := r.RemoteAddr
+ if strings.Contains(r.RemoteAddr, ":") {
+ addr, _, _ = net.SplitHostPort(r.RemoteAddr)
+ }
out := &Hit{
CreatedAt: time.Now(),
- Addr: r.RemoteAddr,
+ Addr: addr,
}
q := r.URL.Query()
diff --git a/referrer.go b/referrer.go
index cbb1198..521e7e3 100644
--- a/referrer.go
+++ b/referrer.go
@@ -3,48 +3,51 @@ package sws
import (
"net/url"
"sort"
- "strings"
"time"
)
type Referrer struct {
Name string `json:"name"`
+ URL string `json:"url"`
LastSeenAt time.Time `json:"last_seen_at" db:"last_seen_at"`
hitSet *HitSet
}
type ReferrerSet []*Referrer
-func NewReferrerSet(hs *HitSet) *ReferrerSet {
+func NewReferrerSet(hs *HitSet, site Site) *ReferrerSet {
tmp := make(map[string]*Referrer)
for _, h := range hs.Hits() {
- if h.Referrer == nil {
- continue
- }
+ host := "direct"
+ u := ""
- u, err := url.Parse(*h.Referrer)
- if err != nil || h.Host == u.Host {
- continue
+ if h.Referrer != nil {
+ if r, err := url.Parse(*h.Referrer); err == nil {
+ host = r.Host
+ }
+ u = *h.Referrer
}
- host := u.Host
- if u.Host == "" {
- host = "direct"
+ // Check for internal referrer
+ if site.IncludesDomain(host) {
+ //host = "internal"
+ continue
}
- r := &Referrer{
- Name: host,
- LastSeenAt: h.CreatedAt,
+ tmp[host] = &Referrer{
+ Name: host,
+ URL: u,
hitSet: hs.Filter(func(t *Hit) bool {
- if t.Referrer == nil {
+ if h.Referrer == nil && t.Referrer == nil {
+ return true
+ }
+ if h.Referrer == nil && t.Referrer != nil {
+ return false
+ }
+ if h.Referrer != nil && t.Referrer == nil {
return false
}
- return strings.Contains(*t.Referrer, u.Host)
+ return *t.Referrer == *t.Referrer
}),
}
- // if b.LastSeenAt.Before(h.CreatedAt) {
- // b.LastSeenAt = h.CreatedAt
- // }
- //b.hitSet.Add(h)
- tmp[u.Host] = r
}
if len(tmp) < 1 {
return nil
diff --git a/referrer_test.go b/referrer_test.go
new file mode 100644
index 0000000..b2a9cb5
--- /dev/null
+++ b/referrer_test.go
@@ -0,0 +1,73 @@
+package sws
+
+import (
+ "testing"
+ "time"
+)
+
+func TestNewReferrerSet(t *testing.T) {
+ now := time.Now()
+ site := Site{Name: "example.com"}
+
+ tests := []struct {
+ hits []*Hit
+ expected ReferrerSet
+ }{
+ {
+ hits: []*Hit{
+ {CreatedAt: now, Referrer: strPtr("http://example1.com")},
+ {CreatedAt: now, Referrer: strPtr("http://example1.com")},
+ {CreatedAt: now, Referrer: strPtr("http://example2.com")},
+ },
+ expected: ReferrerSet{
+ &Referrer{Name: "example1.com", URL: "http://example1.com"},
+ &Referrer{Name: "example2.com", URL: "http://example2.com"},
+ },
+ },
+ {
+ hits: []*Hit{
+ {CreatedAt: now, Referrer: strPtr("http://example1.com")},
+ {CreatedAt: now, Referrer: strPtr("http://example1.com")},
+ {CreatedAt: now, Referrer: nil},
+ },
+ expected: ReferrerSet{
+ &Referrer{Name: "example1.com", URL: "http://example1.com"},
+ &Referrer{Name: "direct", URL: ""},
+ },
+ },
+ {
+ hits: []*Hit{
+ {CreatedAt: now, Referrer: strPtr("http://example1.com")},
+ {CreatedAt: now, Referrer: strPtr("http://example1.com")},
+ {CreatedAt: now, Referrer: strPtr("http://example.com")},
+ {CreatedAt: now, Referrer: strPtr("http://example2.com")},
+ },
+ expected: ReferrerSet{
+ &Referrer{Name: "example1.com", URL: "http://example1.com"},
+ &Referrer{Name: "example2.com", URL: "http://example2.com"},
+ },
+ },
+ }
+
+ for i, tt := range tests {
+ hs, err := NewHitSet(FromHits(tt.hits))
+ if err != nil {
+ t.Fatalf("%d => failed %s", i, err)
+ }
+ rs := NewReferrerSet(hs, site)
+
+ if len(*rs) != len(tt.expected) {
+ t.Errorf("%d => expected %d, got %d", i, len(tt.expected), len(*rs))
+ }
+ for j := range *rs {
+ if (*rs)[j].Name != tt.expected[j].Name {
+ t.Errorf("%d => expected %s, got %s", i, tt.expected[j].Name, (*rs)[j].Name)
+ }
+ if (*rs)[j].URL != tt.expected[j].URL {
+ t.Errorf("%d => expected %s, got %s", i, tt.expected[j].URL, (*rs)[j].URL)
+ }
+ }
+ }
+}
+
+func strPtr(s string) *string { return &s }
diff --git a/site.go b/site.go
index 31e88c3..dc38788 100644
--- a/site.go
+++ b/site.go
@@ -1,6 +1,7 @@
package sws
import (
+ "strings"
"time"
)
@@ -18,10 +19,25 @@ type Site struct {
UpdatedAt *time.Time `json:"updated_at,omitempty" db:"updated_at"`
}
-func (d *Site) Validate() []string {
+func (s *Site) Validate() []string {
var out []string
- if d.Name == "" {
+ if s.Name == "" {
out = append(out, "missing name")
}
return out
}
+
+func (s *Site) IncludesDomain(fqdn string) bool {
+ if fqdn == s.Name {
+ return true
+ }
+ for _, a := range strings.Split(s.Aliases, ",") {
+ if a == fqdn {
+ return true
+ }
+ }
+ if s.AcceptSubdomains && strings.HasSuffix(fqdn, s.Name) {
+ return true
+ }
+ return false
+}