From a1260af1c325e3158400a9179584f592c39b5a98 Mon Sep 17 00:00:00 2001 From: Tyler Reid Date: Fri, 9 Jul 2021 09:24:05 -0500 Subject: [PATCH] Break notify into submethods to create the session then create the publish input to send. Check we populate a region for all requests. This reverts commit 4c2a5f156c0337f658b3bd3de1b9af02400d56f3. Signed-off-by: Tyler Reid --- notify/sns/sns.go | 108 ++++++++++++++++++++++++++++------------------ 1 file changed, 66 insertions(+), 42 deletions(-) diff --git a/notify/sns/sns.go b/notify/sns/sns.go index 77501f66..dbd1ccc1 100644 --- a/notify/sns/sns.go +++ b/notify/sns/sns.go @@ -62,20 +62,45 @@ func New(c *config.SNSConfig, t *template.Template, l log.Logger, httpOpts ...co func (n *Notifier) Notify(ctx context.Context, alert ...*types.Alert) (bool, error) { var ( - err error - data = notify.GetTemplateData(ctx, n.tmpl, alert, n.logger) - tmpl = notify.TmplText(n.tmpl, data, &err) - creds *credentials.Credentials = nil + err error + data = notify.GetTemplateData(ctx, n.tmpl, alert, n.logger) + tmpl = notify.TmplText(n.tmpl, data, &err) ) + + client, err := createSNSClient(n, tmpl) + if err != nil { + if e, ok := err.(awserr.RequestFailure); ok { + return n.retrier.Check(e.StatusCode(), strings.NewReader(e.Message())) + } else { + return true, err + } + } + + publishInput, err := createPublishInput(ctx, n, tmpl) + if err != nil { + return true, err + } + + publishOutput, err := client.Publish(publishInput) + if err != nil { + if e, ok := err.(awserr.RequestFailure); ok { + return n.retrier.Check(e.StatusCode(), strings.NewReader(e.Message())) + } else { + return true, err + } + } + + level.Debug(n.logger).Log("msg", "SNS message successfully published", "message_id", publishOutput.MessageId, "sequence number", publishOutput.SequenceNumber) + + return false, nil +} + +func createSNSClient(n *Notifier, tmpl func(string) string) (*sns.SNS, error) { + var creds *credentials.Credentials = nil + // If there are provided sigV4 credentials we want to use those to create a session. if n.conf.Sigv4.AccessKey != "" && n.conf.Sigv4.SecretKey != "" { creds = credentials.NewStaticCredentials(n.conf.Sigv4.AccessKey, string(n.conf.Sigv4.SecretKey), "") } - - attributes := make(map[string]*sns.MessageAttributeValue, len(n.conf.Attributes)) - for k, v := range n.conf.Attributes { - attributes[tmpl(k)] = &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String(tmpl(v))} - } - sess, err := session.NewSessionWithOptions(session.Options{ Config: aws.Config{ Region: aws.String(n.conf.Sigv4.Region), @@ -84,11 +109,7 @@ func (n *Notifier) Notify(ctx context.Context, alert ...*types.Alert) (bool, err Profile: n.conf.Sigv4.Profile, }) if err != nil { - if e, ok := err.(awserr.RequestFailure); ok { - return n.retrier.Check(e.StatusCode(), strings.NewReader(e.Message())) - } else { - return true, err - } + return nil, err } if n.conf.Sigv4.RoleARN != "" { @@ -105,32 +126,37 @@ func (n *Notifier) Notify(ctx context.Context, alert ...*types.Alert) (bool, err Profile: n.conf.Sigv4.Profile, }) if err != nil { - if e, ok := err.(awserr.RequestFailure); ok { - return n.retrier.Check(e.StatusCode(), strings.NewReader(e.Message())) - } else { - return true, err - } + return nil, err } } creds = stscreds.NewCredentials(stsSess, n.conf.Sigv4.RoleARN) } + // Use our generated session with credentials to create the SNS Client. + client := sns.New(sess, &aws.Config{Credentials: creds}) + // We will always need a region to be set by either the local config or the environment. + if aws.StringValue(sess.Config.Region) == "" { + return nil, fmt.Errorf("region not configured in sns.sigv4.region or in default credentials chain") + } + return client, nil +} + +func createPublishInput(ctx context.Context, n *Notifier, tmpl func(string) string) (*sns.PublishInput, error) { + publishInput := &sns.PublishInput{} + messageAttributes := createMessageAttributes(n, tmpl) // Max message size for a message in a SNS publish request is 256KB, except for SMS messages where the limit is 1600 characters/runes. messageSizeLimit := 256 * 1024 - client := sns.New(sess, &aws.Config{Credentials: creds}) - publishInput := &sns.PublishInput{} - if n.conf.TopicARN != "" { topicTmpl := tmpl(n.conf.TopicARN) publishInput.SetTopicArn(topicTmpl) - if n.isFifo == nil { + // If we are using a topic ARN it could be a FIFO topic specified by the topic postfix .fifo. n.isFifo = aws.Bool(n.conf.TopicARN[len(n.conf.TopicARN)-5:] == ".fifo") } if *n.isFifo { // Deduplication key and Message Group ID are only added if it's a FIFO SNS Topic. key, err := notify.ExtractGroupKey(ctx) if err != nil { - return false, err + return nil, err } publishInput.SetMessageDeduplicationId(key.Hash()) publishInput.SetMessageGroupId(key.Hash()) @@ -143,36 +169,25 @@ func (n *Notifier) Notify(ctx context.Context, alert ...*types.Alert) (bool, err } if n.conf.TargetARN != "" { publishInput.SetTargetArn(tmpl(n.conf.TargetARN)) - } messageToSend, isTrunc, err := validateAndTruncateMessage(tmpl(n.conf.Message), messageSizeLimit) if err != nil { - return false, err + return nil, err } if isTrunc { - attributes["truncated"] = &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String("true")} + // If we truncated the message we need to add a message attribute showing that it was truncated. + messageAttributes["truncated"] = &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String("true")} } + publishInput.SetMessage(messageToSend) + publishInput.SetMessageAttributes(messageAttributes) if n.conf.Subject != "" { publishInput.SetSubject(tmpl(n.conf.Subject)) } - publishInput.SetMessageAttributes(attributes) - - publishOutput, err := client.Publish(publishInput) - if err != nil { - if e, ok := err.(awserr.RequestFailure); ok { - return n.retrier.Check(e.StatusCode(), strings.NewReader(e.Message())) - } else { - return true, err - } - } - - level.Debug(n.logger).Log("msg", "SNS message successfully published", "message_id", publishOutput.MessageId, "sequence number", publishOutput.SequenceNumber) - - return false, nil + return publishInput, nil } func validateAndTruncateMessage(message string, maxMessageSizeInBytes int) (string, bool, error) { @@ -187,3 +202,12 @@ func validateAndTruncateMessage(message string, maxMessageSizeInBytes int) (stri copy(truncated, message) return string(truncated), true, nil } + +func createMessageAttributes(n *Notifier, tmpl func(string) string) map[string]*sns.MessageAttributeValue { + // Convert the given attributes map into the AWS Message Attributes Format + attributes := make(map[string]*sns.MessageAttributeValue, len(n.conf.Attributes)) + for k, v := range n.conf.Attributes { + attributes[tmpl(k)] = &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String(tmpl(v))} + } + return attributes +}