fix regression that prevented setting config slices with env variables (#612)

This commit is contained in:
aler9 2021-10-04 08:58:08 +02:00
parent 0d15e2772a
commit b70a4bfe5b
6 changed files with 50 additions and 14 deletions

View File

@ -3,10 +3,39 @@ package conf
import (
"encoding/json"
"fmt"
"strings"
"github.com/aler9/gortsplib/pkg/headers"
)
func unmarshalStringSlice(b []byte) ([]string, error) {
var in interface{}
if err := json.Unmarshal(b, &in); err != nil {
return nil, err
}
var slice []string
switch it := in.(type) {
case string: // from environment variables
slice = strings.Split(it, ",")
case []interface{}: // from yaml
for _, e := range it {
et, ok := e.(string)
if !ok {
return nil, fmt.Errorf("cannot unmarshal from %T", e)
}
slice = append(slice, et)
}
default:
return nil, fmt.Errorf("cannot unmarshal from %T", in)
}
return slice, nil
}
// AuthMethods is the authMethods parameter.
type AuthMethods []headers.AuthMethod
@ -29,12 +58,12 @@ func (d AuthMethods) MarshalJSON() ([]byte, error) {
// UnmarshalJSON unmarshals a AuthMethods from JSON.
func (d *AuthMethods) UnmarshalJSON(b []byte) error {
var in []string
if err := json.Unmarshal(b, &in); err != nil {
slice, err := unmarshalStringSlice(b)
if err != nil {
return err
}
for _, v := range in {
for _, v := range slice {
switch v {
case "basic":
*d = append(*d, headers.AuthBasic)
@ -43,7 +72,7 @@ func (d *AuthMethods) UnmarshalJSON(b []byte) error {
*d = append(*d, headers.AuthDigest)
default:
return fmt.Errorf("invalid authentication method: %s", in)
return fmt.Errorf("invalid authentication method: %s", v)
}
}

View File

@ -86,6 +86,9 @@ func TestConfFromFileAndEnv(t *testing.T) {
os.Setenv("RTSP_PATHS_CAM1_SOURCE", "rtsp://testing")
defer os.Unsetenv("RTSP_PATHS_CAM1_SOURCE")
os.Setenv("RTSP_PROTOCOLS", "tcp")
defer os.Unsetenv("RTSP_PROTOCOLS")
tmpf, err := writeTempFile([]byte("{}"))
require.NoError(t, err)
defer os.Remove(tmpf)
@ -94,6 +97,8 @@ func TestConfFromFileAndEnv(t *testing.T) {
require.NoError(t, err)
require.Equal(t, true, hasFile)
require.Equal(t, Protocols{ProtocolTCP: {}}, conf.Protocols)
pa, ok := conf.Paths["cam1"]
require.Equal(t, true, ok)
require.Equal(t, &PathConf{

View File

@ -22,16 +22,16 @@ func (d IPsOrNets) MarshalJSON() ([]byte, error) {
// UnmarshalJSON unmarshals a IPsOrNets from JSON.
func (d *IPsOrNets) UnmarshalJSON(b []byte) error {
var in []string
if err := json.Unmarshal(b, &in); err != nil {
slice, err := unmarshalStringSlice(b)
if err != nil {
return err
}
if len(in) == 0 {
if len(slice) == 0 {
return nil
}
for _, t := range in {
for _, t := range slice {
if _, ipnet, err := net.ParseCIDR(t); err == nil {
*d = append(*d, ipnet)
} else if ip := net.ParseIP(t); ip != nil {

View File

@ -38,14 +38,14 @@ func (d LogDestinations) MarshalJSON() ([]byte, error) {
// UnmarshalJSON unmarshals a LogDestinations from JSON.
func (d *LogDestinations) UnmarshalJSON(b []byte) error {
var in []string
if err := json.Unmarshal(b, &in); err != nil {
slice, err := unmarshalStringSlice(b)
if err != nil {
return err
}
*d = make(LogDestinations)
for _, proto := range in {
for _, proto := range slice {
switch proto {
case "stdout":
(*d)[logger.DestinationStdout] = struct{}{}

View File

@ -46,14 +46,14 @@ func (d Protocols) MarshalJSON() ([]byte, error) {
// UnmarshalJSON unmarshals a Protocols from JSON.
func (d *Protocols) UnmarshalJSON(b []byte) error {
var in []string
if err := json.Unmarshal(b, &in); err != nil {
slice, err := unmarshalStringSlice(b)
if err != nil {
return err
}
*d = make(Protocols)
for _, proto := range in {
for _, proto := range slice {
switch proto {
case "udp":
(*d)[ProtocolUDP] = struct{}{}

View File

@ -72,6 +72,7 @@ func TestAPIConfigSet(t *testing.T) {
err := httpRequest(http.MethodPost, "http://localhost:9997/v1/config/set", map[string]interface{}{
"rtmpDisable": true,
"readTimeout": "7s",
"protocols": []string{"tcp"},
}, nil)
require.NoError(t, err)
@ -82,6 +83,7 @@ func TestAPIConfigSet(t *testing.T) {
require.NoError(t, err)
require.Equal(t, true, out["rtmpDisable"])
require.Equal(t, "7s", out["readTimeout"])
require.Equal(t, []interface{}{"tcp"}, out["protocols"])
}
func TestAPIConfigPathsAdd(t *testing.T) {