diff options
Diffstat (limited to 'vendor/github.com/a-h/templ/runtime.go')
| -rw-r--r-- | vendor/github.com/a-h/templ/runtime.go | 714 |
1 files changed, 714 insertions, 0 deletions
diff --git a/vendor/github.com/a-h/templ/runtime.go b/vendor/github.com/a-h/templ/runtime.go new file mode 100644 index 0000000..c804b08 --- /dev/null +++ b/vendor/github.com/a-h/templ/runtime.go @@ -0,0 +1,714 @@ +package templ + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "html" + "html/template" + "io" + "net/http" + "reflect" + "sort" + "strings" + "sync" + + "github.com/a-h/templ/safehtml" +) + +// Types exposed by all components. + +// Component is the interface that all templates implement. +type Component interface { + // Render the template. + Render(ctx context.Context, w io.Writer) error +} + +// ComponentFunc converts a function that matches the Component interface's +// Render method into a Component. +type ComponentFunc func(ctx context.Context, w io.Writer) error + +// Render the template. +func (cf ComponentFunc) Render(ctx context.Context, w io.Writer) error { + return cf(ctx, w) +} + +// WithNonce sets a CSP nonce on the context and returns it. +func WithNonce(ctx context.Context, nonce string) context.Context { + ctx, v := getContext(ctx) + v.nonce = nonce + return ctx +} + +// GetNonce returns the CSP nonce value set with WithNonce, or an +// empty string if none has been set. +func GetNonce(ctx context.Context) (nonce string) { + if ctx == nil { + return "" + } + _, v := getContext(ctx) + return v.nonce +} + +func WithChildren(ctx context.Context, children Component) context.Context { + ctx, v := getContext(ctx) + v.children = &children + return ctx +} + +func ClearChildren(ctx context.Context) context.Context { + _, v := getContext(ctx) + v.children = nil + return ctx +} + +// NopComponent is a component that doesn't render anything. +var NopComponent = ComponentFunc(func(ctx context.Context, w io.Writer) error { return nil }) + +// GetChildren from the context. +func GetChildren(ctx context.Context) Component { + _, v := getContext(ctx) + if v.children == nil { + return NopComponent + } + return *v.children +} + +// EscapeString escapes HTML text within templates. +func EscapeString[T ~string](s T) string { + return html.EscapeString(string(s)) +} + +// Bool attribute value. +func Bool(value bool) bool { + return value +} + +// Classes for CSS. +// Supported types are string, ConstantCSSClass, ComponentCSSClass, map[string]bool. +func Classes(classes ...any) CSSClasses { + return CSSClasses(classes) +} + +// CSSClasses is a slice of CSS classes. +type CSSClasses []any + +// String returns the names of all CSS classes. +func (classes CSSClasses) String() string { + if len(classes) == 0 { + return "" + } + cp := newCSSProcessor() + for _, v := range classes { + cp.Add(v) + } + return cp.String() +} + +func newCSSProcessor() *cssProcessor { + return &cssProcessor{ + classNameToEnabled: make(map[string]bool), + } +} + +type cssProcessor struct { + classNameToEnabled map[string]bool + orderedNames []string +} + +func (cp *cssProcessor) Add(item any) { + switch c := item.(type) { + case []string: + for _, className := range c { + cp.AddClassName(className, true) + } + case string: + cp.AddClassName(c, true) + case ConstantCSSClass: + cp.AddClassName(c.ClassName(), true) + case ComponentCSSClass: + cp.AddClassName(c.ClassName(), true) + case map[string]bool: + // In Go, map keys are iterated in a randomized order. + // So the keys in the map must be sorted to produce consistent output. + keys := make([]string, len(c)) + var i int + for key := range c { + keys[i] = key + i++ + } + sort.Strings(keys) + for _, className := range keys { + cp.AddClassName(className, c[className]) + } + case []KeyValue[string, bool]: + for _, kv := range c { + cp.AddClassName(kv.Key, kv.Value) + } + case KeyValue[string, bool]: + cp.AddClassName(c.Key, c.Value) + case []KeyValue[CSSClass, bool]: + for _, kv := range c { + cp.AddClassName(kv.Key.ClassName(), kv.Value) + } + case KeyValue[CSSClass, bool]: + cp.AddClassName(c.Key.ClassName(), c.Value) + case CSSClasses: + for _, item := range c { + cp.Add(item) + } + case []CSSClass: + for _, item := range c { + cp.Add(item) + } + case func() CSSClass: + cp.AddClassName(c().ClassName(), true) + default: + cp.AddClassName(unknownTypeClassName, true) + } +} + +func (cp *cssProcessor) AddClassName(className string, enabled bool) { + cp.classNameToEnabled[className] = enabled + cp.orderedNames = append(cp.orderedNames, className) +} + +func (cp *cssProcessor) String() string { + // Order the outputs according to how they were input, and remove disabled names. + rendered := make(map[string]any, len(cp.classNameToEnabled)) + var names []string + for _, name := range cp.orderedNames { + if enabled := cp.classNameToEnabled[name]; !enabled { + continue + } + if _, hasBeenRendered := rendered[name]; hasBeenRendered { + continue + } + names = append(names, name) + rendered[name] = struct{}{} + } + + return strings.Join(names, " ") +} + +// KeyValue is a key and value pair. +type KeyValue[TKey comparable, TValue any] struct { + Key TKey `json:"name"` + Value TValue `json:"value"` +} + +// KV creates a new key/value pair from the input key and value. +func KV[TKey comparable, TValue any](key TKey, value TValue) KeyValue[TKey, TValue] { + return KeyValue[TKey, TValue]{ + Key: key, + Value: value, + } +} + +const unknownTypeClassName = "--templ-css-class-unknown-type" + +// Class returns a CSS class name. +// Deprecated: use a string instead. +func Class(name string) CSSClass { + return SafeClass(name) +} + +// SafeClass bypasses CSS class name validation. +// Deprecated: use a string instead. +func SafeClass(name string) CSSClass { + return ConstantCSSClass(name) +} + +// CSSClass provides a class name. +type CSSClass interface { + ClassName() string +} + +// ConstantCSSClass is a string constant of a CSS class name. +// Deprecated: use a string instead. +type ConstantCSSClass string + +// ClassName of the CSS class. +func (css ConstantCSSClass) ClassName() string { + return string(css) +} + +// ComponentCSSClass is a templ.CSS +type ComponentCSSClass struct { + // ID of the class, will be autogenerated. + ID string + // Definition of the CSS. + Class SafeCSS +} + +// ClassName of the CSS class. +func (css ComponentCSSClass) ClassName() string { + return css.ID +} + +// CSSID calculates an ID. +func CSSID(name string, css string) string { + sum := sha256.Sum256([]byte(css)) + hs := hex.EncodeToString(sum[:])[0:8] // NOTE: See issue #978. Minimum recommended hs length is 6. + // Benchmarking showed this was fastest, and with fewest allocations (1). + // Using strings.Builder (2 allocs). + // Using fmt.Sprintf (3 allocs). + return name + "_" + hs +} + +// NewCSSMiddleware creates HTTP middleware that renders a global stylesheet of ComponentCSSClass +// CSS if the request path matches, or updates the HTTP context to ensure that any handlers that +// use templ.Components skip rendering <style> elements for classes that are included in the global +// stylesheet. By default, the stylesheet path is /styles/templ.css +func NewCSSMiddleware(next http.Handler, classes ...CSSClass) CSSMiddleware { + return CSSMiddleware{ + Path: "/styles/templ.css", + CSSHandler: NewCSSHandler(classes...), + Next: next, + } +} + +// CSSMiddleware renders a global stylesheet. +type CSSMiddleware struct { + Path string + CSSHandler CSSHandler + Next http.Handler +} + +func (cssm CSSMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == cssm.Path { + cssm.CSSHandler.ServeHTTP(w, r) + return + } + // Add registered classes to the context. + ctx, v := getContext(r.Context()) + for _, c := range cssm.CSSHandler.Classes { + v.addClass(c.ID) + } + // Serve the request. Templ components will use the updated context + // to know to skip rendering <style> elements for any component CSS + // classes that have been included in the global stylesheet. + cssm.Next.ServeHTTP(w, r.WithContext(ctx)) +} + +// NewCSSHandler creates a handler that serves a stylesheet containing the CSS of the +// classes passed in. This is used by the CSSMiddleware to provide global stylesheets +// for templ components. +func NewCSSHandler(classes ...CSSClass) CSSHandler { + ccssc := make([]ComponentCSSClass, 0, len(classes)) + for _, c := range classes { + ccss, ok := c.(ComponentCSSClass) + if !ok { + continue + } + ccssc = append(ccssc, ccss) + } + return CSSHandler{ + Classes: ccssc, + } +} + +// CSSHandler is a HTTP handler that serves CSS. +type CSSHandler struct { + Logger func(err error) + Classes []ComponentCSSClass +} + +func (cssh CSSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/css") + for _, c := range cssh.Classes { + _, err := w.Write([]byte(c.Class)) + if err != nil && cssh.Logger != nil { + cssh.Logger(err) + } + } +} + +// RenderCSSItems renders the CSS to the writer, if the items haven't already been rendered. +func RenderCSSItems(ctx context.Context, w io.Writer, classes ...any) (err error) { + if len(classes) == 0 { + return nil + } + _, v := getContext(ctx) + sb := new(strings.Builder) + renderCSSItemsToBuilder(sb, v, classes...) + if sb.Len() == 0 { + return nil + } + if _, err = io.WriteString(w, `<style type="text/css"`); err != nil { + return err + } + if v.nonce != "" { + if err = writeStrings(w, ` nonce="`, EscapeString(v.nonce), `"`); err != nil { + return err + } + } + return writeStrings(w, `>`, sb.String(), `</style>`) +} + +func renderCSSItemsToBuilder(sb *strings.Builder, v *contextValue, classes ...any) { + for _, c := range classes { + switch ccc := c.(type) { + case ComponentCSSClass: + if !v.hasClassBeenRendered(ccc.ID) { + sb.WriteString(string(ccc.Class)) + v.addClass(ccc.ID) + } + case KeyValue[ComponentCSSClass, bool]: + if !ccc.Value { + continue + } + renderCSSItemsToBuilder(sb, v, ccc.Key) + case KeyValue[CSSClass, bool]: + if !ccc.Value { + continue + } + renderCSSItemsToBuilder(sb, v, ccc.Key) + case CSSClasses: + renderCSSItemsToBuilder(sb, v, ccc...) + case []CSSClass: + for _, item := range ccc { + renderCSSItemsToBuilder(sb, v, item) + } + case func() CSSClass: + renderCSSItemsToBuilder(sb, v, ccc()) + case []string: + // Skip. These are class names, not CSS classes. + case string: + // Skip. This is a class name, not a CSS class. + case ConstantCSSClass: + // Skip. This is a class name, not a CSS class. + case CSSClass: + // Skip. This is a class name, not a CSS class. + case map[string]bool: + // Skip. These are class names, not CSS classes. + case KeyValue[string, bool]: + // Skip. These are class names, not CSS classes. + case []KeyValue[string, bool]: + // Skip. These are class names, not CSS classes. + case KeyValue[ConstantCSSClass, bool]: + // Skip. These are class names, not CSS classes. + case []KeyValue[ConstantCSSClass, bool]: + // Skip. These are class names, not CSS classes. + } + } +} + +// SafeCSS is CSS that has been sanitized. +type SafeCSS string + +type SafeCSSProperty string + +var safeCSSPropertyType = reflect.TypeOf(SafeCSSProperty("")) + +// SanitizeCSS sanitizes CSS properties to ensure that they are safe. +func SanitizeCSS[T ~string](property string, value T) SafeCSS { + if reflect.TypeOf(value) == safeCSSPropertyType { + return SafeCSS(safehtml.SanitizeCSSProperty(property) + ":" + string(value) + ";") + } + p, v := safehtml.SanitizeCSS(property, string(value)) + return SafeCSS(p + ":" + v + ";") +} + +type Attributer interface { + Items() []KeyValue[string, any] +} + +// Attributes is an alias to map[string]any made for spread attributes. +type Attributes map[string]any + +var _ Attributer = Attributes{} + +// Returns the items of the attributes map in key sorted order. +func (a Attributes) Items() []KeyValue[string, any] { + var ( + items = make([]KeyValue[string, any], len(a)) + i int + ) + for k, v := range a { + items[i] = KeyValue[string, any]{Key: k, Value: v} + i++ + } + sort.Slice(items, func(i, j int) bool { + return items[i].Key < items[j].Key + }) + return items +} + +// OrderedAttributes stores attributes in order of insertion. +type OrderedAttributes []KeyValue[string, any] + +var _ Attributer = OrderedAttributes{} + +func (a OrderedAttributes) Items() []KeyValue[string, any] { + return a +} + +func writeStrings(w io.Writer, ss ...string) (err error) { + for _, s := range ss { + if _, err = io.WriteString(w, s); err != nil { + return err + } + } + return nil +} + +func RenderAttributes(ctx context.Context, w io.Writer, attributes Attributer) (err error) { + for _, item := range attributes.Items() { + key := item.Key + value := item.Value + switch value := value.(type) { + case string: + if err = writeStrings(w, ` `, EscapeString(key), `="`, EscapeString(value), `"`); err != nil { + return err + } + case *string: + if value == nil { + continue + } + if err = writeStrings(w, ` `, EscapeString(key), `="`, EscapeString(*value), `"`); err != nil { + return err + } + case bool: + if !value { + continue + } + if err = writeStrings(w, ` `, EscapeString(key)); err != nil { + return err + } + case *bool: + if value == nil || !*value { + continue + } + if err = writeStrings(w, ` `, EscapeString(key)); err != nil { + return err + } + case int, int8, int16, int32, int64, + uint, uint8, uint16, uint32, uint64, uintptr, + float32, float64, complex64, complex128: + if err = writeStrings(w, ` `, EscapeString(key), `="`, EscapeString(fmt.Sprint(value)), `"`); err != nil { + return err + } + case *int, *int8, *int16, *int32, *int64, + *uint, *uint8, *uint16, *uint32, *uint64, *uintptr, + *float32, *float64, *complex64, *complex128: + value = ptrValue(value) + if value == nil { + continue + } + if err = writeStrings(w, ` `, EscapeString(key), `="`, EscapeString(fmt.Sprint(value)), `"`); err != nil { + return err + } + case KeyValue[string, bool]: + if !value.Value { + continue + } + if err = writeStrings(w, ` `, EscapeString(key), `="`, EscapeString(value.Key), `"`); err != nil { + return err + } + case KeyValue[bool, bool]: + if !value.Value || !value.Key { + continue + } + if err = writeStrings(w, ` `, EscapeString(key)); err != nil { + return err + } + case func() bool: + if !value() { + continue + } + if err = writeStrings(w, ` `, EscapeString(key)); err != nil { + return err + } + } + } + return nil +} + +func ptrValue(v any) any { + if v == nil { + return nil + } + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr { + return v + } + if rv.IsNil() { + return nil + } + return rv.Elem().Interface() +} + +// Context. + +type contextKeyType int + +const contextKey = contextKeyType(0) + +type contextValue struct { + ss map[string]struct{} + onceHandles map[*OnceHandle]struct{} + children *Component + nonce string +} + +func (v *contextValue) setHasBeenRendered(h *OnceHandle) { + if v.onceHandles == nil { + v.onceHandles = map[*OnceHandle]struct{}{} + } + v.onceHandles[h] = struct{}{} +} + +func (v *contextValue) getHasBeenRendered(h *OnceHandle) (ok bool) { + if v.onceHandles == nil { + v.onceHandles = map[*OnceHandle]struct{}{} + } + _, ok = v.onceHandles[h] + return +} + +func (v *contextValue) addScript(s string) { + if v.ss == nil { + v.ss = map[string]struct{}{} + } + v.ss["script_"+s] = struct{}{} +} + +func (v *contextValue) hasScriptBeenRendered(s string) (ok bool) { + if v.ss == nil { + v.ss = map[string]struct{}{} + } + _, ok = v.ss["script_"+s] + return +} + +func (v *contextValue) addClass(s string) { + if v.ss == nil { + v.ss = map[string]struct{}{} + } + v.ss["class_"+s] = struct{}{} +} + +func (v *contextValue) hasClassBeenRendered(s string) (ok bool) { + if v.ss == nil { + v.ss = map[string]struct{}{} + } + _, ok = v.ss["class_"+s] + return +} + +// InitializeContext initializes context used to store internal state used during rendering. +func InitializeContext(ctx context.Context) context.Context { + if _, ok := ctx.Value(contextKey).(*contextValue); ok { + return ctx + } + v := &contextValue{} + ctx = context.WithValue(ctx, contextKey, v) + return ctx +} + +func getContext(ctx context.Context) (context.Context, *contextValue) { + v, ok := ctx.Value(contextKey).(*contextValue) + if !ok { + ctx = InitializeContext(ctx) + v = ctx.Value(contextKey).(*contextValue) + } + return ctx, v +} + +var bufferPool = sync.Pool{ + New: func() any { + return new(bytes.Buffer) + }, +} + +func GetBuffer() *bytes.Buffer { + return bufferPool.Get().(*bytes.Buffer) +} + +func ReleaseBuffer(b *bytes.Buffer) { + b.Reset() + bufferPool.Put(b) +} + +type ints interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 +} + +type uints interface { + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr +} + +type floats interface { + ~float32 | ~float64 +} + +type complexNumbers interface { + ~complex64 | ~complex128 +} + +type stringable interface { + ints | uints | floats | complexNumbers | ~string | ~bool +} + +// JoinStringErrs joins an optional list of errors. +func JoinStringErrs[T stringable](s T, errs ...error) (string, error) { + return fmt.Sprint(s), errors.Join(errs...) +} + +// Error returned during template rendering. +type Error struct { + Err error + // FileName of the template file. + FileName string + // Line index of the error. + Line int + // Col index of the error. + Col int +} + +func (e Error) Error() string { + if e.FileName == "" { + e.FileName = "templ" + } + return fmt.Sprintf("%s: error at line %d, col %d: %v", e.FileName, e.Line, e.Col, e.Err) +} + +func (e Error) Unwrap() error { + return e.Err +} + +// Raw renders the input HTML to the output without applying HTML escaping. +// +// Use of this component presents a security risk - the HTML should come from +// a trusted source, because it will be included as-is in the output. +func Raw[T ~string](html T, errs ...error) Component { + return ComponentFunc(func(ctx context.Context, w io.Writer) (err error) { + if err = errors.Join(errs...); err != nil { + return err + } + _, err = io.WriteString(w, string(html)) + return err + }) +} + +// FromGoHTML creates a templ Component from a Go html/template template. +func FromGoHTML(t *template.Template, data any) Component { + return ComponentFunc(func(ctx context.Context, w io.Writer) (err error) { + return t.Execute(w, data) + }) +} + +// ToGoHTML renders the component to a Go html/template template.HTML string. +func ToGoHTML(ctx context.Context, c Component) (s template.HTML, err error) { + b := GetBuffer() + defer ReleaseBuffer(b) + if err = c.Render(ctx, b); err != nil { + return + } + s = template.HTML(b.String()) + return +} |
