summaryrefslogtreecommitdiff
path: root/vendor/github.com/smallstep/certificates/authority/provisioner/scep.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/smallstep/certificates/authority/provisioner/scep.go')
-rw-r--r--vendor/github.com/smallstep/certificates/authority/provisioner/scep.go63
1 files changed, 39 insertions, 24 deletions
diff --git a/vendor/github.com/smallstep/certificates/authority/provisioner/scep.go b/vendor/github.com/smallstep/certificates/authority/provisioner/scep.go
index 7213285..a97ff8e 100644
--- a/vendor/github.com/smallstep/certificates/authority/provisioner/scep.go
+++ b/vendor/github.com/smallstep/certificates/authority/provisioner/scep.go
@@ -8,15 +8,16 @@ import (
"crypto/x509"
"encoding/pem"
"fmt"
- "net/http"
"time"
"github.com/pkg/errors"
+ "github.com/smallstep/linkedca"
"go.step.sm/crypto/kms"
kmsapi "go.step.sm/crypto/kms/apiv1"
- "go.step.sm/linkedca"
+ "go.step.sm/crypto/x509util"
+ "github.com/smallstep/certificates/internal/httptransport"
"github.com/smallstep/certificates/webhook"
)
@@ -111,13 +112,14 @@ func (s *SCEP) DefaultTLSCertDuration() time.Duration {
}
type challengeValidationController struct {
- client *http.Client
- webhooks []*Webhook
+ client HTTPClient
+ wrapTransport httptransport.Wrapper
+ webhooks []*Webhook
}
// newChallengeValidationController creates a new challengeValidationController
// that performs challenge validation through webhooks.
-func newChallengeValidationController(client *http.Client, webhooks []*Webhook) *challengeValidationController {
+func newChallengeValidationController(client HTTPClient, tw httptransport.Wrapper, webhooks []*Webhook) *challengeValidationController {
scepHooks := []*Webhook{}
for _, wh := range webhooks {
if wh.Kind != linkedca.Webhook_SCEPCHALLENGE.String() {
@@ -129,8 +131,9 @@ func newChallengeValidationController(client *http.Client, webhooks []*Webhook)
scepHooks = append(scepHooks, wh)
}
return &challengeValidationController{
- client: client,
- webhooks: scepHooks,
+ client: client,
+ wrapTransport: tw,
+ webhooks: scepHooks,
}
}
@@ -145,35 +148,44 @@ var (
// that case, the other webhooks will be skipped. If none of
// the webhooks indicates the value of the challenge was accepted,
// an error is returned.
-func (c *challengeValidationController) Validate(ctx context.Context, csr *x509.CertificateRequest, provisionerName, challenge, transactionID string) error {
+func (c *challengeValidationController) Validate(ctx context.Context, csr *x509.CertificateRequest, provisionerName, challenge, transactionID string) ([]SignCSROption, error) {
+ var opts []SignCSROption
+
for _, wh := range c.webhooks {
req, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr))
if err != nil {
- return fmt.Errorf("failed creating new webhook request: %w", err)
+ return nil, fmt.Errorf("failed creating new webhook request: %w", err)
}
req.ProvisionerName = provisionerName
req.SCEPChallenge = challenge
req.SCEPTransactionID = transactionID
- resp, err := wh.DoWithContext(ctx, c.client, req, nil) // TODO(hs): support templated URL? Requires some refactoring
+ resp, err := wh.DoWithContext(ctx, c.client, c.wrapTransport, req, nil) // TODO(hs): support templated URL? Requires some refactoring
if err != nil {
- return fmt.Errorf("failed executing webhook request: %w", err)
+ return nil, fmt.Errorf("failed executing webhook request: %w", err)
}
if resp.Allow {
- return nil // return early when response is positive
+ opts = append(opts, TemplateDataModifierFunc(func(data x509util.TemplateData) {
+ data.SetWebhook(wh.Name, resp.Data)
+ }))
}
}
- return ErrSCEPChallengeInvalid
+ if len(opts) == 0 {
+ return nil, ErrSCEPChallengeInvalid
+ }
+
+ return opts, nil
}
type notificationController struct {
- client *http.Client
- webhooks []*Webhook
+ client HTTPClient
+ wrapTransport httptransport.Wrapper
+ webhooks []*Webhook
}
// newNotificationController creates a new notificationController
// that performs SCEP notifications through webhooks.
-func newNotificationController(client *http.Client, webhooks []*Webhook) *notificationController {
+func newNotificationController(client HTTPClient, tw httptransport.Wrapper, webhooks []*Webhook) *notificationController {
scepHooks := []*Webhook{}
for _, wh := range webhooks {
if wh.Kind != linkedca.Webhook_NOTIFYING.String() {
@@ -185,8 +197,9 @@ func newNotificationController(client *http.Client, webhooks []*Webhook) *notifi
scepHooks = append(scepHooks, wh)
}
return &notificationController{
- client: client,
- webhooks: scepHooks,
+ client: client,
+ wrapTransport: tw,
+ webhooks: scepHooks,
}
}
@@ -198,7 +211,7 @@ func (c *notificationController) Success(ctx context.Context, csr *x509.Certific
}
req.X509Certificate.Raw = cert.Raw // adding the full certificate DER bytes
req.SCEPTransactionID = transactionID
- if _, err = wh.DoWithContext(ctx, c.client, req, nil); err != nil {
+ if _, err = wh.DoWithContext(ctx, c.client, c.wrapTransport, req, nil); err != nil {
return fmt.Errorf("failed executing webhook request: %w: %w", ErrSCEPNotificationFailed, err)
}
}
@@ -215,7 +228,7 @@ func (c *notificationController) Failure(ctx context.Context, csr *x509.Certific
req.SCEPTransactionID = transactionID
req.SCEPErrorCode = errorCode
req.SCEPErrorDescription = errorDescription
- if _, err = wh.DoWithContext(ctx, c.client, req, nil); err != nil {
+ if _, err = wh.DoWithContext(ctx, c.client, c.wrapTransport, req, nil); err != nil {
return fmt.Errorf("failed executing webhook request: %w: %w", ErrSCEPNotificationFailed, err)
}
}
@@ -258,12 +271,14 @@ func (s *SCEP) Init(config Config) (err error) {
// Prepare the SCEP challenge validator
s.challengeValidationController = newChallengeValidationController(
config.WebhookClient,
+ config.WrapTransport,
s.GetOptions().GetWebhooks(),
)
// Prepare the SCEP notification controller
s.notificationController = newNotificationController(
config.WebhookClient,
+ config.WrapTransport,
s.GetOptions().GetWebhooks(),
)
@@ -440,18 +455,18 @@ func (s *SCEP) GetContentEncryptionAlgorithm() int {
// ValidateChallenge validates the provided challenge. It starts by
// selecting the validation method to use, then performs validation
// according to that method.
-func (s *SCEP) ValidateChallenge(ctx context.Context, csr *x509.CertificateRequest, challenge, transactionID string) error {
+func (s *SCEP) ValidateChallenge(ctx context.Context, csr *x509.CertificateRequest, challenge, transactionID string) ([]SignCSROption, error) {
if s.challengeValidationController == nil {
- return fmt.Errorf("provisioner %q wasn't initialized", s.Name)
+ return nil, fmt.Errorf("provisioner %q wasn't initialized", s.Name)
}
switch s.selectValidationMethod() {
case validationMethodWebhook:
return s.challengeValidationController.Validate(ctx, csr, s.Name, challenge, transactionID)
default:
if subtle.ConstantTimeCompare([]byte(s.ChallengePassword), []byte(challenge)) == 0 {
- return errors.New("invalid challenge password provided")
+ return nil, errors.New("invalid challenge password provided")
}
- return nil
+ return []SignCSROption{}, nil
}
}