aboutsummaryrefslogtreecommitdiff
path: root/src/vendor/github.com/rs/xhandler/middleware_test.go
blob: 51306e390d5357471bdea24d610beb219bb2f8b9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
package xhandler

import (
	"log"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"golang.org/x/net/context"
)

// TestTimeoutHandler runs the package's test handler behind
// TimeoutHandler(time.Second) and checks the body it writes.
//
// NOTE(review): the expected body assumes the shared test handler (defined
// elsewhere in this package) writes the contextKey value and appends
// " with deadline" when the request context carries a deadline — confirm
// against that handler if this expectation changes.
func TestTimeoutHandler(t *testing.T) {
	req, err := http.NewRequest("GET", "http://example.com/foo", nil)
	if err != nil {
		log.Fatal(err)
	}

	parent := context.WithValue(context.Background(), contextKey, "value")
	wrapped := TimeoutHandler(time.Second)(&handler{})
	rec := httptest.NewRecorder()

	New(parent, wrapped).ServeHTTP(rec, req)

	assert.Equal(t, "value with deadline", rec.Body.String())
}

// closeNotifyWriter is a test ResponseWriter built on an httptest recorder
// that additionally exposes a CloseNotify method whose result is controlled
// by the closed flag.
type closeNotifyWriter struct {
	*httptest.ResponseRecorder
	closed bool // when true, CloseNotify reports the client as already gone
}

// CloseNotify returns a new channel with capacity 1 on every call. If the
// writer is marked closed, the channel is pre-loaded with true so a receive
// succeeds immediately; otherwise nothing is ever sent on it.
func (cw *closeNotifyWriter) CloseNotify() <-chan bool {
	ch := make(chan bool, 1)
	if !cw.closed {
		return ch
	}
	// The buffer lets us send without a receiver, simulating a disconnect
	// that happened before the caller started listening.
	ch <- true
	return ch
}

// TestCloseHandlerClientClose serves a request through CloseHandler using a
// writer whose CloseNotify channel already holds a value (closed: true), and
// expects the test handler's output to include " canceled".
//
// NOTE(review): the expected body assumes the shared test handler (defined
// elsewhere in this package) appends " canceled" when the request context
// has been canceled — confirm against that handler.
func TestCloseHandlerClientClose(t *testing.T) {
	ctx := context.WithValue(context.Background(), contextKey, "value")
	xh := CloseHandler(&handler{})
	h := New(ctx, xh)
	w := &closeNotifyWriter{ResponseRecorder: httptest.NewRecorder(), closed: true}
	r, err := http.NewRequest("GET", "http://example.com/foo", nil)
	if err != nil {
		// t.Fatal fails only this test; log.Fatal would os.Exit the entire
		// test binary and suppress results and cleanup for all other tests.
		t.Fatal(err)
	}
	h.ServeHTTP(w, r)
	assert.Equal(t, "value canceled", w.Body.String())
}

// TestCloseHandlerRequestEnds serves a request through CloseHandler using a
// writer whose CloseNotify channel never fires (closed: false), and expects
// the test handler to write just the context value, with no " canceled" suffix.
func TestCloseHandlerRequestEnds(t *testing.T) {
	ctx := context.WithValue(context.Background(), contextKey, "value")
	xh := CloseHandler(&handler{})
	h := New(ctx, xh)
	w := &closeNotifyWriter{ResponseRecorder: httptest.NewRecorder(), closed: false}
	r, err := http.NewRequest("GET", "http://example.com/foo", nil)
	if err != nil {
		// t.Fatal fails only this test; log.Fatal would os.Exit the entire
		// test binary and suppress results and cleanup for all other tests.
		t.Fatal(err)
	}
	h.ServeHTTP(w, r)
	assert.Equal(t, "value", w.Body.String())
}

// TestIf checks that If diverts a request into the conditional middleware when
// the predicate matches (path "/true") and otherwise hands it straight to the
// next handler.
func TestIf(t *testing.T) {
	// Count invocations so the test fails if a request reaches neither handler;
	// the path assertions inside the handlers alone would pass vacuously then.
	var trueCalls, falseCalls int
	trueHandler := HandlerFuncC(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
		trueCalls++
		assert.Equal(t, "/true", r.URL.Path)
	})
	falseHandler := HandlerFuncC(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
		falseCalls++
		assert.NotEqual(t, "/true", r.URL.Path)
	})
	ctx := context.WithValue(context.Background(), contextKey, "value")
	xh := If(
		func(ctx context.Context, w http.ResponseWriter, r *http.Request) bool {
			return r.URL.Path == "/true"
		},
		// The conditional middleware ignores next and always diverts to trueHandler.
		func(next HandlerC) HandlerC {
			return trueHandler
		},
	)(falseHandler)
	h := New(ctx, xh)

	r, err := http.NewRequest("GET", "http://example.com/true", nil)
	if err != nil {
		t.Fatal(err)
	}
	// Use a real recorder rather than nil so any write in the chain cannot panic.
	h.ServeHTTP(httptest.NewRecorder(), r)
	assert.Equal(t, 1, trueCalls, "matching request should reach the conditional middleware")
	assert.Equal(t, 0, falseCalls, "matching request should bypass next")

	r, err = http.NewRequest("GET", "http://example.com/false", nil)
	if err != nil {
		t.Fatal(err)
	}
	h.ServeHTTP(httptest.NewRecorder(), r)
	assert.Equal(t, 1, trueCalls, "non-matching request should skip the conditional middleware")
	assert.Equal(t, 1, falseCalls, "non-matching request should reach next")
}