aboutsummaryrefslogtreecommitdiff
path: root/vendor/src/github.com/justinas/alice/chain_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/src/github.com/justinas/alice/chain_test.go')
-rw-r--r--vendor/src/github.com/justinas/alice/chain_test.go144
1 files changed, 144 insertions, 0 deletions
diff --git a/vendor/src/github.com/justinas/alice/chain_test.go b/vendor/src/github.com/justinas/alice/chain_test.go
new file mode 100644
index 0000000..49c0470
--- /dev/null
+++ b/vendor/src/github.com/justinas/alice/chain_test.go
@@ -0,0 +1,144 @@
+// Package alice implements a middleware chaining solution.
+package alice
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "reflect"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+// A constructor for middleware
+// that writes its own "tag" into the RW and does nothing else.
+// Useful in checking if a chain is behaving in the right order.
+func tagMiddleware(tag string) Constructor {
+ return func(h http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte(tag))
+ h.ServeHTTP(w, r)
+ })
+ }
+}
+
+// Not recommended (https://golang.org/pkg/reflect/#Value.Pointer),
+// but the best we can do.
+func funcsEqual(f1, f2 interface{}) bool {
+ val1 := reflect.ValueOf(f1)
+ val2 := reflect.ValueOf(f2)
+ return val1.Pointer() == val2.Pointer()
+}
+
+var testApp = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("app\n"))
+})
+
+// Tests creating a new chain
+func TestNew(t *testing.T) {
+ c1 := func(h http.Handler) http.Handler {
+ return nil
+ }
+ c2 := func(h http.Handler) http.Handler {
+ return http.StripPrefix("potato", nil)
+ }
+
+ slice := []Constructor{c1, c2}
+
+ chain := New(slice...)
+ assert.True(t, funcsEqual(chain.constructors[0], slice[0]))
+ assert.True(t, funcsEqual(chain.constructors[1], slice[1]))
+}
+
+func TestThenWorksWithNoMiddleware(t *testing.T) {
+ assert.NotPanics(t, func() {
+ chain := New()
+ final := chain.Then(testApp)
+
+ assert.True(t, funcsEqual(final, testApp))
+ })
+}
+
+func TestThenTreatsNilAsDefaultServeMux(t *testing.T) {
+ chained := New().Then(nil)
+ assert.Equal(t, chained, http.DefaultServeMux)
+}
+
+func TestThenFuncTreatsNilAsDefaultServeMux(t *testing.T) {
+ chained := New().ThenFunc(nil)
+ assert.Equal(t, chained, http.DefaultServeMux)
+}
+
+func TestThenOrdersHandlersRight(t *testing.T) {
+ t1 := tagMiddleware("t1\n")
+ t2 := tagMiddleware("t2\n")
+ t3 := tagMiddleware("t3\n")
+
+ chained := New(t1, t2, t3).Then(testApp)
+
+ w := httptest.NewRecorder()
+ r, err := http.NewRequest("GET", "/", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ chained.ServeHTTP(w, r)
+
+ assert.Equal(t, w.Body.String(), "t1\nt2\nt3\napp\n")
+}
+
+func TestAppendAddsHandlersCorrectly(t *testing.T) {
+ chain := New(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
+ newChain := chain.Append(tagMiddleware("t3\n"), tagMiddleware("t4\n"))
+
+ assert.Equal(t, len(chain.constructors), 2)
+ assert.Equal(t, len(newChain.constructors), 4)
+
+ chained := newChain.Then(testApp)
+
+ w := httptest.NewRecorder()
+ r, err := http.NewRequest("GET", "/", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ chained.ServeHTTP(w, r)
+
+ assert.Equal(t, w.Body.String(), "t1\nt2\nt3\nt4\napp\n")
+}
+
+func TestAppendRespectsImmutability(t *testing.T) {
+ chain := New(tagMiddleware(""))
+ newChain := chain.Append(tagMiddleware(""))
+
+ assert.NotEqual(t, &chain.constructors[0], &newChain.constructors[0])
+}
+
+func TestExtendAddsHandlersCorrectly(t *testing.T) {
+ chain1 := New(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
+ chain2 := New(tagMiddleware("t3\n"), tagMiddleware("t4\n"))
+ newChain := chain1.Extend(chain2)
+
+ assert.Equal(t, len(chain1.constructors), 2)
+ assert.Equal(t, len(chain2.constructors), 2)
+ assert.Equal(t, len(newChain.constructors), 4)
+
+ chained := newChain.Then(testApp)
+
+ w := httptest.NewRecorder()
+ r, err := http.NewRequest("GET", "/", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ chained.ServeHTTP(w, r)
+
+ assert.Equal(t, w.Body.String(), "t1\nt2\nt3\nt4\napp\n")
+}
+
+func TestExtendRespectsImmutability(t *testing.T) {
+ chain := New(tagMiddleware(""))
+ newChain := chain.Extend(New(tagMiddleware("")))
+
+ assert.NotEqual(t, &chain.constructors[0], &newChain.constructors[0])
+}