summaryrefslogtreecommitdiff
path: root/vendor/github.com/smallstep/certificates/acme/api/middleware.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/smallstep/certificates/acme/api/middleware.go')
-rw-r--r--vendor/github.com/smallstep/certificates/acme/api/middleware.go98
1 files changed, 48 insertions, 50 deletions
diff --git a/vendor/github.com/smallstep/certificates/acme/api/middleware.go b/vendor/github.com/smallstep/certificates/acme/api/middleware.go
index 628da7e..afccca7 100644
--- a/vendor/github.com/smallstep/certificates/acme/api/middleware.go
+++ b/vendor/github.com/smallstep/certificates/acme/api/middleware.go
@@ -36,7 +36,7 @@ func addNonce(next nextHTTP) nextHTTP {
db := acme.MustDatabaseFromContext(r.Context())
nonce, err := db.CreateNonce(r.Context())
if err != nil {
- render.Error(w, r, err)
+ render.Error(w, err)
return
}
w.Header().Set("Replay-Nonce", string(nonce))
@@ -64,7 +64,7 @@ func verifyContentType(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
p, err := provisionerFromContext(r.Context())
if err != nil {
- render.Error(w, r, err)
+ render.Error(w, err)
return
}
@@ -88,7 +88,7 @@ func verifyContentType(next nextHTTP) nextHTTP {
return
}
}
- render.Error(w, r, acme.NewError(acme.ErrorMalformedType,
+ render.Error(w, acme.NewError(acme.ErrorMalformedType,
"expected content-type to be in %s, but got %s", expected, ct))
}
}
@@ -98,12 +98,12 @@ func parseJWS(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
if err != nil {
- render.Error(w, r, acme.WrapErrorISE(err, "failed to read request body"))
+ render.Error(w, acme.WrapErrorISE(err, "failed to read request body"))
return
}
jws, err := jose.ParseJWS(string(body))
if err != nil {
- render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body"))
+ render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body"))
return
}
ctx := context.WithValue(r.Context(), jwsContextKey, jws)
@@ -133,15 +133,15 @@ func validateJWS(next nextHTTP) nextHTTP {
jws, err := jwsFromContext(ctx)
if err != nil {
- render.Error(w, r, err)
+ render.Error(w, err)
return
}
if len(jws.Signatures) == 0 {
- render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"))
+ render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"))
return
}
if len(jws.Signatures) > 1 {
- render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"))
+ render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"))
return
}
@@ -152,7 +152,7 @@ func validateJWS(next nextHTTP) nextHTTP {
uh.Algorithm != "" ||
uh.Nonce != "" ||
len(uh.ExtraHeaders) > 0 {
- render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"))
+ render.Error(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"))
return
}
hdr := sig.Protected
@@ -162,13 +162,13 @@ func validateJWS(next nextHTTP) nextHTTP {
switch k := hdr.JSONWebKey.Key.(type) {
case *rsa.PublicKey:
if k.Size() < keyutil.MinRSAKeyBytes {
- render.Error(w, r, acme.NewError(acme.ErrorMalformedType,
+ render.Error(w, acme.NewError(acme.ErrorMalformedType,
"rsa keys must be at least %d bits (%d bytes) in size",
8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes))
return
}
default:
- render.Error(w, r, acme.NewError(acme.ErrorMalformedType,
+ render.Error(w, acme.NewError(acme.ErrorMalformedType,
"jws key type and algorithm do not match"))
return
}
@@ -176,35 +176,35 @@ func validateJWS(next nextHTTP) nextHTTP {
case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA:
// we good
default:
- render.Error(w, r, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm))
+ render.Error(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm))
return
}
// Check the validity/freshness of the Nonce.
if err := db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil {
- render.Error(w, r, err)
+ render.Error(w, err)
return
}
// Check that the JWS url matches the requested url.
jwsURL, ok := hdr.ExtraHeaders["url"].(string)
if !ok {
- render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header"))
+ render.Error(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header"))
return
}
reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path}
if jwsURL != reqURL.String() {
- render.Error(w, r, acme.NewError(acme.ErrorMalformedType,
+ render.Error(w, acme.NewError(acme.ErrorMalformedType,
"url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL))
return
}
if hdr.JSONWebKey != nil && hdr.KeyID != "" {
- render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"))
+ render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"))
return
}
if hdr.JSONWebKey == nil && hdr.KeyID == "" {
- render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"))
+ render.Error(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"))
return
}
next(w, r)
@@ -221,23 +221,23 @@ func extractJWK(next nextHTTP) nextHTTP {
jws, err := jwsFromContext(ctx)
if err != nil {
- render.Error(w, r, err)
+ render.Error(w, err)
return
}
jwk := jws.Signatures[0].Protected.JSONWebKey
if jwk == nil {
- render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"))
+ render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"))
return
}
if !jwk.Valid() {
- render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"))
+ render.Error(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"))
return
}
// Overwrite KeyID with the JWK thumbprint.
jwk.KeyID, err = acme.KeyToID(jwk)
if err != nil {
- render.Error(w, r, acme.WrapErrorISE(err, "error getting KeyID from JWK"))
+ render.Error(w, acme.WrapErrorISE(err, "error getting KeyID from JWK"))
return
}
@@ -247,15 +247,15 @@ func extractJWK(next nextHTTP) nextHTTP {
// Get Account OR continue to generate a new one OR continue Revoke with certificate private key
acc, err := db.GetAccountByKeyID(ctx, jwk.KeyID)
switch {
- case acme.IsErrNotFound(err):
+ case errors.Is(err, acme.ErrNotFound):
// For NewAccount and Revoke requests ...
break
case err != nil:
- render.Error(w, r, err)
+ render.Error(w, err)
return
default:
if !acc.IsValid() {
- render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
+ render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
return
}
ctx = context.WithValue(ctx, accContextKey, acc)
@@ -274,11 +274,11 @@ func checkPrerequisites(next nextHTTP) nextHTTP {
if ok {
ok, err := checkFunc(ctx)
if err != nil {
- render.Error(w, r, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
+ render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
return
}
if !ok {
- render.Error(w, r, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
+ render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
return
}
}
@@ -296,13 +296,13 @@ func lookupJWK(next nextHTTP) nextHTTP {
jws, err := jwsFromContext(ctx)
if err != nil {
- render.Error(w, r, err)
+ render.Error(w, err)
return
}
kid := jws.Signatures[0].Protected.KeyID
if kid == "" {
- render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "signature missing 'kid'"))
+ render.Error(w, acme.NewError(acme.ErrorMalformedType, "signature missing 'kid'"))
return
}
@@ -310,14 +310,14 @@ func lookupJWK(next nextHTTP) nextHTTP {
acc, err := db.GetAccount(ctx, accID)
switch {
case acme.IsErrNotFound(err):
- render.Error(w, r, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
+ render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
return
case err != nil:
- render.Error(w, r, err)
+ render.Error(w, err)
return
default:
if !acc.IsValid() {
- render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
+ render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
return
}
@@ -325,7 +325,7 @@ func lookupJWK(next nextHTTP) nextHTTP {
if kid != storedLocation {
// ACME accounts should have a stored location equivalent to the
// kid in the ACME request.
- render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType,
+ render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"kid does not match stored account location; expected %s, but got %s",
storedLocation, kid))
return
@@ -334,16 +334,14 @@ func lookupJWK(next nextHTTP) nextHTTP {
// Verify that the provisioner with which the account was created
// matches the provisioner in the request URL.
reqProv := acme.MustProvisionerFromContext(ctx)
- switch {
- case acc.ProvisionerID == "" && acc.ProvisionerName != reqProv.GetName():
- render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType,
- "account provisioner does not match requested provisioner; account provisioner = %s, requested provisioner = %s",
- acc.ProvisionerName, reqProv.GetName()))
- return
- case acc.ProvisionerID != "" && acc.ProvisionerID != reqProv.GetID():
- render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType,
+ reqProvName := reqProv.GetName()
+ accProvName := acc.ProvisionerName
+ if reqProvName != accProvName {
+ // Provisioner in the URL must match the provisioner with
+ // which the account was created.
+ render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"account provisioner does not match requested provisioner; account provisioner = %s, requested provisioner = %s",
- acc.ProvisionerID, reqProv.GetID()))
+ accProvName, reqProvName))
return
}
} else {
@@ -355,7 +353,7 @@ func lookupJWK(next nextHTTP) nextHTTP {
linker := acme.MustLinkerFromContext(ctx)
kidPrefix := linker.GetLink(ctx, acme.AccountLinkType, "")
if !strings.HasPrefix(kid, kidPrefix) {
- render.Error(w, r, acme.NewError(acme.ErrorMalformedType,
+ render.Error(w, acme.NewError(acme.ErrorMalformedType,
"kid does not have required prefix; expected %s, but got %s",
kidPrefix, kid))
return
@@ -376,7 +374,7 @@ func extractOrLookupJWK(next nextHTTP) nextHTTP {
ctx := r.Context()
jws, err := jwsFromContext(ctx)
if err != nil {
- render.Error(w, r, err)
+ render.Error(w, err)
return
}
@@ -412,16 +410,16 @@ func verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
ctx := r.Context()
jws, err := jwsFromContext(ctx)
if err != nil {
- render.Error(w, r, err)
+ render.Error(w, err)
return
}
jwk, err := jwkFromContext(ctx)
if err != nil {
- render.Error(w, r, err)
+ render.Error(w, err)
return
}
if jwk.Algorithm != "" && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
- render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"))
+ render.Error(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"))
return
}
@@ -430,11 +428,11 @@ func verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
case errors.Is(err, jose.ErrCryptoFailure):
payload, err = retryVerificationWithPatchedSignatures(jws, jwk)
if err != nil {
- render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws with patched signature(s)"))
+ render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws with patched signature(s)"))
return
}
case err != nil:
- render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws"))
+ render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws"))
return
}
@@ -551,11 +549,11 @@ func isPostAsGet(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
payload, err := payloadFromContext(r.Context())
if err != nil {
- render.Error(w, r, err)
+ render.Error(w, err)
return
}
if !payload.isPostAsGet {
- render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET"))
+ render.Error(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET"))
return
}
next(w, r)