Break notify into submethods to create the session then create the publish input to send. Check we populate a region for all requests.

This reverts commit 4c2a5f156c.

Signed-off-by: Tyler Reid <tyler.reid@grafana.com>
This commit is contained in:
Tyler Reid 2021-07-09 09:24:05 -05:00
parent 51b93681b2
commit a1260af1c3
1 changed files with 66 additions and 42 deletions

View File

@ -62,20 +62,45 @@ func New(c *config.SNSConfig, t *template.Template, l log.Logger, httpOpts ...co
func (n *Notifier) Notify(ctx context.Context, alert ...*types.Alert) (bool, error) {
var (
err error
data = notify.GetTemplateData(ctx, n.tmpl, alert, n.logger)
tmpl = notify.TmplText(n.tmpl, data, &err)
creds *credentials.Credentials = nil
err error
data = notify.GetTemplateData(ctx, n.tmpl, alert, n.logger)
tmpl = notify.TmplText(n.tmpl, data, &err)
)
client, err := createSNSClient(n, tmpl)
if err != nil {
if e, ok := err.(awserr.RequestFailure); ok {
return n.retrier.Check(e.StatusCode(), strings.NewReader(e.Message()))
} else {
return true, err
}
}
publishInput, err := createPublishInput(ctx, n, tmpl)
if err != nil {
return true, err
}
publishOutput, err := client.Publish(publishInput)
if err != nil {
if e, ok := err.(awserr.RequestFailure); ok {
return n.retrier.Check(e.StatusCode(), strings.NewReader(e.Message()))
} else {
return true, err
}
}
level.Debug(n.logger).Log("msg", "SNS message successfully published", "message_id", publishOutput.MessageId, "sequence number", publishOutput.SequenceNumber)
return false, nil
}
func createSNSClient(n *Notifier, tmpl func(string) string) (*sns.SNS, error) {
var creds *credentials.Credentials = nil
// If there are provided sigV4 credentials we want to use those to create a session.
if n.conf.Sigv4.AccessKey != "" && n.conf.Sigv4.SecretKey != "" {
creds = credentials.NewStaticCredentials(n.conf.Sigv4.AccessKey, string(n.conf.Sigv4.SecretKey), "")
}
attributes := make(map[string]*sns.MessageAttributeValue, len(n.conf.Attributes))
for k, v := range n.conf.Attributes {
attributes[tmpl(k)] = &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String(tmpl(v))}
}
sess, err := session.NewSessionWithOptions(session.Options{
Config: aws.Config{
Region: aws.String(n.conf.Sigv4.Region),
@ -84,11 +109,7 @@ func (n *Notifier) Notify(ctx context.Context, alert ...*types.Alert) (bool, err
Profile: n.conf.Sigv4.Profile,
})
if err != nil {
if e, ok := err.(awserr.RequestFailure); ok {
return n.retrier.Check(e.StatusCode(), strings.NewReader(e.Message()))
} else {
return true, err
}
return nil, err
}
if n.conf.Sigv4.RoleARN != "" {
@ -105,32 +126,37 @@ func (n *Notifier) Notify(ctx context.Context, alert ...*types.Alert) (bool, err
Profile: n.conf.Sigv4.Profile,
})
if err != nil {
if e, ok := err.(awserr.RequestFailure); ok {
return n.retrier.Check(e.StatusCode(), strings.NewReader(e.Message()))
} else {
return true, err
}
return nil, err
}
}
creds = stscreds.NewCredentials(stsSess, n.conf.Sigv4.RoleARN)
}
// Use our generated session with credentials to create the SNS Client.
client := sns.New(sess, &aws.Config{Credentials: creds})
// We will always need a region to be set by either the local config or the environment.
if aws.StringValue(sess.Config.Region) == "" {
return nil, fmt.Errorf("region not configured in sns.sigv4.region or in default credentials chain")
}
return client, nil
}
func createPublishInput(ctx context.Context, n *Notifier, tmpl func(string) string) (*sns.PublishInput, error) {
publishInput := &sns.PublishInput{}
messageAttributes := createMessageAttributes(n, tmpl)
// Max message size for a message in a SNS publish request is 256KB, except for SMS messages where the limit is 1600 characters/runes.
messageSizeLimit := 256 * 1024
client := sns.New(sess, &aws.Config{Credentials: creds})
publishInput := &sns.PublishInput{}
if n.conf.TopicARN != "" {
topicTmpl := tmpl(n.conf.TopicARN)
publishInput.SetTopicArn(topicTmpl)
if n.isFifo == nil {
// If we are using a topic ARN it could be a FIFO topic specified by the topic postfix .fifo.
n.isFifo = aws.Bool(n.conf.TopicARN[len(n.conf.TopicARN)-5:] == ".fifo")
}
if *n.isFifo {
// Deduplication key and Message Group ID are only added if it's a FIFO SNS Topic.
key, err := notify.ExtractGroupKey(ctx)
if err != nil {
return false, err
return nil, err
}
publishInput.SetMessageDeduplicationId(key.Hash())
publishInput.SetMessageGroupId(key.Hash())
@ -143,36 +169,25 @@ func (n *Notifier) Notify(ctx context.Context, alert ...*types.Alert) (bool, err
}
if n.conf.TargetARN != "" {
publishInput.SetTargetArn(tmpl(n.conf.TargetARN))
}
messageToSend, isTrunc, err := validateAndTruncateMessage(tmpl(n.conf.Message), messageSizeLimit)
if err != nil {
return false, err
return nil, err
}
if isTrunc {
attributes["truncated"] = &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String("true")}
// If we truncated the message we need to add a message attribute showing that it was truncated.
messageAttributes["truncated"] = &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String("true")}
}
publishInput.SetMessage(messageToSend)
publishInput.SetMessageAttributes(messageAttributes)
if n.conf.Subject != "" {
publishInput.SetSubject(tmpl(n.conf.Subject))
}
publishInput.SetMessageAttributes(attributes)
publishOutput, err := client.Publish(publishInput)
if err != nil {
if e, ok := err.(awserr.RequestFailure); ok {
return n.retrier.Check(e.StatusCode(), strings.NewReader(e.Message()))
} else {
return true, err
}
}
level.Debug(n.logger).Log("msg", "SNS message successfully published", "message_id", publishOutput.MessageId, "sequence number", publishOutput.SequenceNumber)
return false, nil
return publishInput, nil
}
func validateAndTruncateMessage(message string, maxMessageSizeInBytes int) (string, bool, error) {
@ -187,3 +202,12 @@ func validateAndTruncateMessage(message string, maxMessageSizeInBytes int) (stri
copy(truncated, message)
return string(truncated), true, nil
}
func createMessageAttributes(n *Notifier, tmpl func(string) string) map[string]*sns.MessageAttributeValue {
// Convert the given attributes map into the AWS Message Attributes Format
attributes := make(map[string]*sns.MessageAttributeValue, len(n.conf.Attributes))
for k, v := range n.conf.Attributes {
attributes[tmpl(k)] = &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String(tmpl(v))}
}
return attributes
}