diff --git a/internal/conf/authmethod.go b/internal/conf/authmethod.go index ff599915..49e9a98e 100644 --- a/internal/conf/authmethod.go +++ b/internal/conf/authmethod.go @@ -3,10 +3,39 @@ package conf import ( "encoding/json" "fmt" + "strings" "github.com/aler9/gortsplib/pkg/headers" ) +func unmarshalStringSlice(b []byte) ([]string, error) { + var in interface{} + if err := json.Unmarshal(b, &in); err != nil { + return nil, err + } + + var slice []string + + switch it := in.(type) { + case string: // from environment variables + slice = strings.Split(it, ",") + + case []interface{}: // from yaml + for _, e := range it { + et, ok := e.(string) + if !ok { + return nil, fmt.Errorf("cannot unmarshal from %T", e) + } + slice = append(slice, et) + } + + default: + return nil, fmt.Errorf("cannot unmarshal from %T", in) + } + + return slice, nil +} + // AuthMethods is the authMethods parameter. type AuthMethods []headers.AuthMethod @@ -29,12 +58,12 @@ func (d AuthMethods) MarshalJSON() ([]byte, error) { // UnmarshalJSON unmarshals a AuthMethods from JSON. func (d *AuthMethods) UnmarshalJSON(b []byte) error { - var in []string - if err := json.Unmarshal(b, &in); err != nil { + slice, err := unmarshalStringSlice(b) + if err != nil { return err } - for _, v := range in { + for _, v := range slice { switch v { case "basic": *d = append(*d, headers.AuthBasic) @@ -43,7 +72,7 @@ func (d *AuthMethods) UnmarshalJSON(b []byte) error { *d = append(*d, headers.AuthDigest) default: - return fmt.Errorf("invalid authentication method: %s", in) + return fmt.Errorf("invalid authentication method: %s", v) } } diff --git a/internal/conf/conf_test.go b/internal/conf/conf_test.go index fce4c100..1924a48f 100644 --- a/internal/conf/conf_test.go +++ b/internal/conf/conf_test.go @@ -86,6 +86,9 @@ func TestConfFromFileAndEnv(t *testing.T) { os.Setenv("RTSP_PATHS_CAM1_SOURCE", "rtsp://testing") defer os.Unsetenv("RTSP_PATHS_CAM1_SOURCE") + os.Setenv("RTSP_PROTOCOLS", "tcp") + defer os.Unsetenv("RTSP_PROTOCOLS") + tmpf, err := writeTempFile([]byte("{}")) require.NoError(t, err) defer os.Remove(tmpf) @@ -94,6 +97,8 @@ func TestConfFromFileAndEnv(t *testing.T) { require.NoError(t, err) require.Equal(t, true, hasFile) + require.Equal(t, Protocols{ProtocolTCP: {}}, conf.Protocols) + pa, ok := conf.Paths["cam1"] require.Equal(t, true, ok) require.Equal(t, &PathConf{ diff --git a/internal/conf/ipsornets.go b/internal/conf/ipsornets.go index 42a90910..d1f83e74 100644 --- a/internal/conf/ipsornets.go +++ b/internal/conf/ipsornets.go @@ -22,16 +22,16 @@ func (d IPsOrNets) MarshalJSON() ([]byte, error) { // UnmarshalJSON unmarshals a IPsOrNets from JSON. func (d *IPsOrNets) UnmarshalJSON(b []byte) error { - var in []string - if err := json.Unmarshal(b, &in); err != nil { + slice, err := unmarshalStringSlice(b) + if err != nil { return err } - if len(in) == 0 { + if len(slice) == 0 { return nil } - for _, t := range in { + for _, t := range slice { if _, ipnet, err := net.ParseCIDR(t); err == nil { *d = append(*d, ipnet) } else if ip := net.ParseIP(t); ip != nil { diff --git a/internal/conf/logdestination.go b/internal/conf/logdestination.go index ac21f475..9eb8efa1 100644 --- a/internal/conf/logdestination.go +++ b/internal/conf/logdestination.go @@ -38,14 +38,14 @@ func (d LogDestinations) MarshalJSON() ([]byte, error) { // UnmarshalJSON unmarshals a LogDestinations from JSON. func (d *LogDestinations) UnmarshalJSON(b []byte) error { - var in []string - if err := json.Unmarshal(b, &in); err != nil { + slice, err := unmarshalStringSlice(b) + if err != nil { return err } *d = make(LogDestinations) - for _, proto := range in { + for _, proto := range slice { switch proto { case "stdout": (*d)[logger.DestinationStdout] = struct{}{} diff --git a/internal/conf/protocol.go b/internal/conf/protocol.go index 83dc015c..009e1180 100644 --- a/internal/conf/protocol.go +++ b/internal/conf/protocol.go @@ -46,14 +46,14 @@ func (d Protocols) MarshalJSON() ([]byte, error) { // UnmarshalJSON unmarshals a Protocols from JSON. func (d *Protocols) UnmarshalJSON(b []byte) error { - var in []string - if err := json.Unmarshal(b, &in); err != nil { + slice, err := unmarshalStringSlice(b) + if err != nil { return err } *d = make(Protocols) - for _, proto := range in { + for _, proto := range slice { switch proto { case "udp": (*d)[ProtocolUDP] = struct{}{} diff --git a/internal/core/api_test.go b/internal/core/api_test.go index b19c4bc7..dde4ee57 100644 --- a/internal/core/api_test.go +++ b/internal/core/api_test.go @@ -72,6 +72,7 @@ func TestAPIConfigSet(t *testing.T) { err := httpRequest(http.MethodPost, "http://localhost:9997/v1/config/set", map[string]interface{}{ "rtmpDisable": true, "readTimeout": "7s", + "protocols": []string{"tcp"}, }, nil) require.NoError(t, err) @@ -82,6 +83,7 @@ func TestAPIConfigSet(t *testing.T) { require.NoError(t, err) require.Equal(t, true, out["rtmpDisable"]) require.Equal(t, "7s", out["readTimeout"]) + require.Equal(t, []interface{}{"tcp"}, out["protocols"]) } func TestAPIConfigPathsAdd(t *testing.T) {