fix regression that prevented setting config slices with env variables (#612)
This commit is contained in:
parent
0d15e2772a
commit
b70a4bfe5b
|
@ -3,10 +3,39 @@ package conf
|
|||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/aler9/gortsplib/pkg/headers"
|
||||
)
|
||||
|
||||
func unmarshalStringSlice(b []byte) ([]string, error) {
|
||||
var in interface{}
|
||||
if err := json.Unmarshal(b, &in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var slice []string
|
||||
|
||||
switch it := in.(type) {
|
||||
case string: // from environment variables
|
||||
slice = strings.Split(it, ",")
|
||||
|
||||
case []interface{}: // from yaml
|
||||
for _, e := range it {
|
||||
et, ok := e.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("cannot unmarshal from %T", e)
|
||||
}
|
||||
slice = append(slice, et)
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("cannot unmarshal from %T", in)
|
||||
}
|
||||
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
// AuthMethods is the authMethods parameter.
|
||||
type AuthMethods []headers.AuthMethod
|
||||
|
||||
|
@ -29,12 +58,12 @@ func (d AuthMethods) MarshalJSON() ([]byte, error) {
|
|||
|
||||
// UnmarshalJSON unmarshals a AuthMethods from JSON.
|
||||
func (d *AuthMethods) UnmarshalJSON(b []byte) error {
|
||||
var in []string
|
||||
if err := json.Unmarshal(b, &in); err != nil {
|
||||
slice, err := unmarshalStringSlice(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, v := range in {
|
||||
for _, v := range slice {
|
||||
switch v {
|
||||
case "basic":
|
||||
*d = append(*d, headers.AuthBasic)
|
||||
|
@ -43,7 +72,7 @@ func (d *AuthMethods) UnmarshalJSON(b []byte) error {
|
|||
*d = append(*d, headers.AuthDigest)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("invalid authentication method: %s", in)
|
||||
return fmt.Errorf("invalid authentication method: %s", v)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -86,6 +86,9 @@ func TestConfFromFileAndEnv(t *testing.T) {
|
|||
os.Setenv("RTSP_PATHS_CAM1_SOURCE", "rtsp://testing")
|
||||
defer os.Unsetenv("RTSP_PATHS_CAM1_SOURCE")
|
||||
|
||||
os.Setenv("RTSP_PROTOCOLS", "tcp")
|
||||
defer os.Unsetenv("RTSP_PROTOCOLS")
|
||||
|
||||
tmpf, err := writeTempFile([]byte("{}"))
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tmpf)
|
||||
|
@ -94,6 +97,8 @@ func TestConfFromFileAndEnv(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
require.Equal(t, true, hasFile)
|
||||
|
||||
require.Equal(t, Protocols{ProtocolTCP: {}}, conf.Protocols)
|
||||
|
||||
pa, ok := conf.Paths["cam1"]
|
||||
require.Equal(t, true, ok)
|
||||
require.Equal(t, &PathConf{
|
||||
|
|
|
@ -22,16 +22,16 @@ func (d IPsOrNets) MarshalJSON() ([]byte, error) {
|
|||
|
||||
// UnmarshalJSON unmarshals a IPsOrNets from JSON.
|
||||
func (d *IPsOrNets) UnmarshalJSON(b []byte) error {
|
||||
var in []string
|
||||
if err := json.Unmarshal(b, &in); err != nil {
|
||||
slice, err := unmarshalStringSlice(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(in) == 0 {
|
||||
if len(slice) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, t := range in {
|
||||
for _, t := range slice {
|
||||
if _, ipnet, err := net.ParseCIDR(t); err == nil {
|
||||
*d = append(*d, ipnet)
|
||||
} else if ip := net.ParseIP(t); ip != nil {
|
||||
|
|
|
@ -38,14 +38,14 @@ func (d LogDestinations) MarshalJSON() ([]byte, error) {
|
|||
|
||||
// UnmarshalJSON unmarshals a LogDestinations from JSON.
|
||||
func (d *LogDestinations) UnmarshalJSON(b []byte) error {
|
||||
var in []string
|
||||
if err := json.Unmarshal(b, &in); err != nil {
|
||||
slice, err := unmarshalStringSlice(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*d = make(LogDestinations)
|
||||
|
||||
for _, proto := range in {
|
||||
for _, proto := range slice {
|
||||
switch proto {
|
||||
case "stdout":
|
||||
(*d)[logger.DestinationStdout] = struct{}{}
|
||||
|
|
|
@ -46,14 +46,14 @@ func (d Protocols) MarshalJSON() ([]byte, error) {
|
|||
|
||||
// UnmarshalJSON unmarshals a Protocols from JSON.
|
||||
func (d *Protocols) UnmarshalJSON(b []byte) error {
|
||||
var in []string
|
||||
if err := json.Unmarshal(b, &in); err != nil {
|
||||
slice, err := unmarshalStringSlice(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*d = make(Protocols)
|
||||
|
||||
for _, proto := range in {
|
||||
for _, proto := range slice {
|
||||
switch proto {
|
||||
case "udp":
|
||||
(*d)[ProtocolUDP] = struct{}{}
|
||||
|
|
|
@ -72,6 +72,7 @@ func TestAPIConfigSet(t *testing.T) {
|
|||
err := httpRequest(http.MethodPost, "http://localhost:9997/v1/config/set", map[string]interface{}{
|
||||
"rtmpDisable": true,
|
||||
"readTimeout": "7s",
|
||||
"protocols": []string{"tcp"},
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -82,6 +83,7 @@ func TestAPIConfigSet(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
require.Equal(t, true, out["rtmpDisable"])
|
||||
require.Equal(t, "7s", out["readTimeout"])
|
||||
require.Equal(t, []interface{}{"tcp"}, out["protocols"])
|
||||
}
|
||||
|
||||
func TestAPIConfigPathsAdd(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue