Update vendoring for github.com/golang/protobuf/proto

This commit is contained in:
Ben Kochie 2016-12-16 12:56:38 +01:00
parent 84c6c66a35
commit 3939e9d0b1
14 changed files with 819 additions and 185 deletions

View File

@ -39,5 +39,5 @@ test: install generate-test-pbs
generate-test-pbs: generate-test-pbs:
make install make install
make -C testdata make -C testdata
protoc --go_out=Mtestdata/test.proto=github.com/golang/protobuf/proto/testdata:. proto3_proto/proto3.proto protoc --go_out=Mtestdata/test.proto=github.com/golang/protobuf/proto/testdata,Mgoogle/protobuf/any.proto=github.com/golang/protobuf/ptypes/any:. proto3_proto/proto3.proto
make make

View File

@ -84,9 +84,15 @@ func mergeStruct(out, in reflect.Value) {
mergeAny(out.Field(i), in.Field(i), false, sprop.Prop[i]) mergeAny(out.Field(i), in.Field(i), false, sprop.Prop[i])
} }
if emIn, ok := in.Addr().Interface().(extendableProto); ok { if emIn, ok := extendable(in.Addr().Interface()); ok {
emOut := out.Addr().Interface().(extendableProto) emOut, _ := extendable(out.Addr().Interface())
mergeExtension(emOut.ExtensionMap(), emIn.ExtensionMap()) mIn, muIn := emIn.extensionsRead()
if mIn != nil {
mOut := emOut.extensionsWrite()
muIn.Lock()
mergeExtension(mOut, mIn)
muIn.Unlock()
}
} }
uf := in.FieldByName("XXX_unrecognized") uf := in.FieldByName("XXX_unrecognized")

View File

@ -61,7 +61,6 @@ var ErrInternalBadWireType = errors.New("proto: internal error: bad wiretype for
// int32, int64, uint32, uint64, bool, and enum // int32, int64, uint32, uint64, bool, and enum
// protocol buffer types. // protocol buffer types.
func DecodeVarint(buf []byte) (x uint64, n int) { func DecodeVarint(buf []byte) (x uint64, n int) {
// x, n already 0
for shift := uint(0); shift < 64; shift += 7 { for shift := uint(0); shift < 64; shift += 7 {
if n >= len(buf) { if n >= len(buf) {
return 0, 0 return 0, 0
@ -78,13 +77,7 @@ func DecodeVarint(buf []byte) (x uint64, n int) {
return 0, 0 return 0, 0
} }
// DecodeVarint reads a varint-encoded integer from the Buffer. func (p *Buffer) decodeVarintSlow() (x uint64, err error) {
// This is the format for the
// int32, int64, uint32, uint64, bool, and enum
// protocol buffer types.
func (p *Buffer) DecodeVarint() (x uint64, err error) {
// x, err already 0
i := p.index i := p.index
l := len(p.buf) l := len(p.buf)
@ -107,6 +100,107 @@ func (p *Buffer) DecodeVarint() (x uint64, err error) {
return return
} }
// DecodeVarint reads a varint-encoded integer from the Buffer.
// This is the format for the
// int32, int64, uint32, uint64, bool, and enum
// protocol buffer types.
func (p *Buffer) DecodeVarint() (x uint64, err error) {
i := p.index
buf := p.buf
if i >= len(buf) {
return 0, io.ErrUnexpectedEOF
} else if buf[i] < 0x80 {
p.index++
return uint64(buf[i]), nil
} else if len(buf)-i < 10 {
return p.decodeVarintSlow()
}
var b uint64
// we already checked the first byte
x = uint64(buf[i]) - 0x80
i++
b = uint64(buf[i])
i++
x += b << 7
if b&0x80 == 0 {
goto done
}
x -= 0x80 << 7
b = uint64(buf[i])
i++
x += b << 14
if b&0x80 == 0 {
goto done
}
x -= 0x80 << 14
b = uint64(buf[i])
i++
x += b << 21
if b&0x80 == 0 {
goto done
}
x -= 0x80 << 21
b = uint64(buf[i])
i++
x += b << 28
if b&0x80 == 0 {
goto done
}
x -= 0x80 << 28
b = uint64(buf[i])
i++
x += b << 35
if b&0x80 == 0 {
goto done
}
x -= 0x80 << 35
b = uint64(buf[i])
i++
x += b << 42
if b&0x80 == 0 {
goto done
}
x -= 0x80 << 42
b = uint64(buf[i])
i++
x += b << 49
if b&0x80 == 0 {
goto done
}
x -= 0x80 << 49
b = uint64(buf[i])
i++
x += b << 56
if b&0x80 == 0 {
goto done
}
x -= 0x80 << 56
b = uint64(buf[i])
i++
x += b << 63
if b&0x80 == 0 {
goto done
}
// x -= 0x80 << 63 // Always zero.
return 0, errOverflow
done:
p.index = i
return x, nil
}
// DecodeFixed64 reads a 64-bit integer from the Buffer. // DecodeFixed64 reads a 64-bit integer from the Buffer.
// This is the format for the // This is the format for the
// fixed64, sfixed64, and double protocol buffer types. // fixed64, sfixed64, and double protocol buffer types.
@ -340,6 +434,8 @@ func (p *Buffer) DecodeGroup(pb Message) error {
// Buffer and places the decoded result in pb. If the struct // Buffer and places the decoded result in pb. If the struct
// underlying pb does not match the data in the buffer, the results can be // underlying pb does not match the data in the buffer, the results can be
// unpredictable. // unpredictable.
//
// Unlike proto.Unmarshal, this does not reset pb before starting to unmarshal.
func (p *Buffer) Unmarshal(pb Message) error { func (p *Buffer) Unmarshal(pb Message) error {
// If the object can unmarshal itself, let it. // If the object can unmarshal itself, let it.
if u, ok := pb.(Unmarshaler); ok { if u, ok := pb.(Unmarshaler); ok {
@ -378,6 +474,11 @@ func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group
wire := int(u & 0x7) wire := int(u & 0x7)
if wire == WireEndGroup { if wire == WireEndGroup {
if is_group { if is_group {
if required > 0 {
// Not enough information to determine the exact field.
// (See below.)
return &RequiredNotSetError{"{Unknown}"}
}
return nil // input is satisfied return nil // input is satisfied
} }
return fmt.Errorf("proto: %s: wiretype end group for non-group", st) return fmt.Errorf("proto: %s: wiretype end group for non-group", st)
@ -390,11 +491,12 @@ func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group
if !ok { if !ok {
// Maybe it's an extension? // Maybe it's an extension?
if prop.extendable { if prop.extendable {
if e := structPointer_Interface(base, st).(extendableProto); isExtensionField(e, int32(tag)) { if e, _ := extendable(structPointer_Interface(base, st)); isExtensionField(e, int32(tag)) {
if err = o.skip(st, tag, wire); err == nil { if err = o.skip(st, tag, wire); err == nil {
ext := e.ExtensionMap()[int32(tag)] // may be missing extmap := e.extensionsWrite()
ext := extmap[int32(tag)] // may be missing
ext.enc = append(ext.enc, o.buf[oi:o.index]...) ext.enc = append(ext.enc, o.buf[oi:o.index]...)
e.ExtensionMap()[int32(tag)] = ext extmap[int32(tag)] = ext
} }
continue continue
} }
@ -768,10 +870,11 @@ func (o *Buffer) dec_new_map(p *Properties, base structPointer) error {
} }
} }
keyelem, valelem := keyptr.Elem(), valptr.Elem() keyelem, valelem := keyptr.Elem(), valptr.Elem()
if !keyelem.IsValid() || !valelem.IsValid() { if !keyelem.IsValid() {
// We did not decode the key or the value in the map entry. keyelem = reflect.Zero(p.mtype.Key())
// Either way, it's an invalid map entry. }
return fmt.Errorf("proto: bad map data: missing key/val") if !valelem.IsValid() {
valelem = reflect.Zero(p.mtype.Elem())
} }
v.SetMapIndex(keyelem, valelem) v.SetMapIndex(keyelem, valelem)

View File

@ -64,8 +64,16 @@ var (
// a struct with a repeated field containing a nil element. // a struct with a repeated field containing a nil element.
errRepeatedHasNil = errors.New("proto: repeated field has nil element") errRepeatedHasNil = errors.New("proto: repeated field has nil element")
// errOneofHasNil is the error returned if Marshal is called with
// a struct with a oneof field containing a nil element.
errOneofHasNil = errors.New("proto: oneof field has nil value")
// ErrNil is the error returned if Marshal is called with nil. // ErrNil is the error returned if Marshal is called with nil.
ErrNil = errors.New("proto: Marshal called with nil") ErrNil = errors.New("proto: Marshal called with nil")
// ErrTooLarge is the error returned if Marshal is called with a
// message that encodes to >2GB.
ErrTooLarge = errors.New("proto: message encodes to over 2 GB")
) )
// The fundamental encoders that put bytes on the wire. // The fundamental encoders that put bytes on the wire.
@ -74,6 +82,10 @@ var (
const maxVarintBytes = 10 // maximum length of a varint const maxVarintBytes = 10 // maximum length of a varint
// maxMarshalSize is the largest allowed size of an encoded protobuf,
// since C++ and Java use signed int32s for the size.
const maxMarshalSize = 1<<31 - 1
// EncodeVarint returns the varint encoding of x. // EncodeVarint returns the varint encoding of x.
// This is the format for the // This is the format for the
// int32, int64, uint32, uint64, bool, and enum // int32, int64, uint32, uint64, bool, and enum
@ -222,10 +234,6 @@ func Marshal(pb Message) ([]byte, error) {
} }
p := NewBuffer(nil) p := NewBuffer(nil)
err := p.Marshal(pb) err := p.Marshal(pb)
var state errorState
if err != nil && !state.shouldContinue(err, nil) {
return nil, err
}
if p.buf == nil && err == nil { if p.buf == nil && err == nil {
// Return a non-nil slice on success. // Return a non-nil slice on success.
return []byte{}, nil return []byte{}, nil
@ -254,11 +262,8 @@ func (p *Buffer) Marshal(pb Message) error {
// Can the object marshal itself? // Can the object marshal itself?
if m, ok := pb.(Marshaler); ok { if m, ok := pb.(Marshaler); ok {
data, err := m.Marshal() data, err := m.Marshal()
if err != nil {
return err
}
p.buf = append(p.buf, data...) p.buf = append(p.buf, data...)
return nil return err
} }
t, base, err := getbase(pb) t, base, err := getbase(pb)
@ -270,9 +275,12 @@ func (p *Buffer) Marshal(pb Message) error {
} }
if collectStats { if collectStats {
stats.Encode++ (stats).Encode++ // Parens are to work around a goimports bug.
} }
if len(p.buf) > maxMarshalSize {
return ErrTooLarge
}
return err return err
} }
@ -294,7 +302,7 @@ func Size(pb Message) (n int) {
} }
if collectStats { if collectStats {
stats.Size++ (stats).Size++ // Parens are to work around a goimports bug.
} }
return return
@ -999,7 +1007,6 @@ func size_slice_struct_message(p *Properties, base structPointer) (n int) {
if p.isMarshaler { if p.isMarshaler {
m := structPointer_Interface(structp, p.stype).(Marshaler) m := structPointer_Interface(structp, p.stype).(Marshaler)
data, _ := m.Marshal() data, _ := m.Marshal()
n += len(p.tagcode)
n += sizeRawBytes(data) n += sizeRawBytes(data)
continue continue
} }
@ -1058,10 +1065,32 @@ func size_slice_struct_group(p *Properties, base structPointer) (n int) {
// Encode an extension map. // Encode an extension map.
func (o *Buffer) enc_map(p *Properties, base structPointer) error { func (o *Buffer) enc_map(p *Properties, base structPointer) error {
v := *structPointer_ExtMap(base, p.field) exts := structPointer_ExtMap(base, p.field)
if err := encodeExtensionMap(v); err != nil { if err := encodeExtensionsMap(*exts); err != nil {
return err return err
} }
return o.enc_map_body(*exts)
}
func (o *Buffer) enc_exts(p *Properties, base structPointer) error {
exts := structPointer_Extensions(base, p.field)
v, mu := exts.extensionsRead()
if v == nil {
return nil
}
mu.Lock()
defer mu.Unlock()
if err := encodeExtensionsMap(v); err != nil {
return err
}
return o.enc_map_body(v)
}
func (o *Buffer) enc_map_body(v map[int32]Extension) error {
// Fast-path for common cases: zero or one extensions. // Fast-path for common cases: zero or one extensions.
if len(v) <= 1 { if len(v) <= 1 {
for _, e := range v { for _, e := range v {
@ -1084,8 +1113,13 @@ func (o *Buffer) enc_map(p *Properties, base structPointer) error {
} }
func size_map(p *Properties, base structPointer) int { func size_map(p *Properties, base structPointer) int {
v := *structPointer_ExtMap(base, p.field) v := structPointer_ExtMap(base, p.field)
return sizeExtensionMap(v) return extensionsMapSize(*v)
}
func size_exts(p *Properties, base structPointer) int {
v := structPointer_Extensions(base, p.field)
return extensionsSize(v)
} }
// Encode a map field. // Encode a map field.
@ -1114,7 +1148,7 @@ func (o *Buffer) enc_new_map(p *Properties, base structPointer) error {
if err := p.mkeyprop.enc(o, p.mkeyprop, keybase); err != nil { if err := p.mkeyprop.enc(o, p.mkeyprop, keybase); err != nil {
return err return err
} }
if err := p.mvalprop.enc(o, p.mvalprop, valbase); err != nil { if err := p.mvalprop.enc(o, p.mvalprop, valbase); err != nil && err != ErrNil {
return err return err
} }
return nil return nil
@ -1124,11 +1158,6 @@ func (o *Buffer) enc_new_map(p *Properties, base structPointer) error {
for _, key := range v.MapKeys() { for _, key := range v.MapKeys() {
val := v.MapIndex(key) val := v.MapIndex(key)
// The only illegal map entry values are nil message pointers.
if val.Kind() == reflect.Ptr && val.IsNil() {
return errors.New("proto: map has nil element")
}
keycopy.Set(key) keycopy.Set(key)
valcopy.Set(val) valcopy.Set(val)
@ -1216,13 +1245,18 @@ func (o *Buffer) enc_struct(prop *StructProperties, base structPointer) error {
return err return err
} }
} }
if len(o.buf) > maxMarshalSize {
return ErrTooLarge
}
} }
} }
// Do oneof fields. // Do oneof fields.
if prop.oneofMarshaler != nil { if prop.oneofMarshaler != nil {
m := structPointer_Interface(base, prop.stype).(Message) m := structPointer_Interface(base, prop.stype).(Message)
if err := prop.oneofMarshaler(m, o); err != nil { if err := prop.oneofMarshaler(m, o); err == ErrNil {
return errOneofHasNil
} else if err != nil {
return err return err
} }
} }
@ -1230,6 +1264,9 @@ func (o *Buffer) enc_struct(prop *StructProperties, base structPointer) error {
// Add unrecognized fields at the end. // Add unrecognized fields at the end.
if prop.unrecField.IsValid() { if prop.unrecField.IsValid() {
v := *structPointer_Bytes(base, prop.unrecField) v := *structPointer_Bytes(base, prop.unrecField)
if len(o.buf)+len(v) > maxMarshalSize {
return ErrTooLarge
}
if len(v) > 0 { if len(v) > 0 {
o.buf = append(o.buf, v...) o.buf = append(o.buf, v...)
} }

View File

@ -54,13 +54,17 @@ Equality is defined in this way:
in a proto3 .proto file, fields are not "set"; specifically, in a proto3 .proto file, fields are not "set"; specifically,
zero length proto3 "bytes" fields are equal (nil == {}). zero length proto3 "bytes" fields are equal (nil == {}).
- Two repeated fields are equal iff their lengths are the same, - Two repeated fields are equal iff their lengths are the same,
and their corresponding elements are equal (a "bytes" field, and their corresponding elements are equal. Note a "bytes" field,
although represented by []byte, is not a repeated field) although represented by []byte, is not a repeated field and the
rule for the scalar fields described above applies.
- Two unset fields are equal. - Two unset fields are equal.
- Two unknown field sets are equal if their current - Two unknown field sets are equal if their current
encoded state is equal. encoded state is equal.
- Two extension sets are equal iff they have corresponding - Two extension sets are equal iff they have corresponding
elements that are pairwise equal. elements that are pairwise equal.
- Two map fields are equal iff their lengths are the same,
and they contain the same set of elements. Zero-length map
fields are equal.
- Every other combination of things are not equal. - Every other combination of things are not equal.
The return value is undefined if a and b are not protocol buffers. The return value is undefined if a and b are not protocol buffers.
@ -121,9 +125,16 @@ func equalStruct(v1, v2 reflect.Value) bool {
} }
} }
if em1 := v1.FieldByName("XXX_InternalExtensions"); em1.IsValid() {
em2 := v2.FieldByName("XXX_InternalExtensions")
if !equalExtensions(v1.Type(), em1.Interface().(XXX_InternalExtensions), em2.Interface().(XXX_InternalExtensions)) {
return false
}
}
if em1 := v1.FieldByName("XXX_extensions"); em1.IsValid() { if em1 := v1.FieldByName("XXX_extensions"); em1.IsValid() {
em2 := v2.FieldByName("XXX_extensions") em2 := v2.FieldByName("XXX_extensions")
if !equalExtensions(v1.Type(), em1.Interface().(map[int32]Extension), em2.Interface().(map[int32]Extension)) { if !equalExtMap(v1.Type(), em1.Interface().(map[int32]Extension), em2.Interface().(map[int32]Extension)) {
return false return false
} }
} }
@ -184,6 +195,13 @@ func equalAny(v1, v2 reflect.Value, prop *Properties) bool {
} }
return true return true
case reflect.Ptr: case reflect.Ptr:
// Maps may have nil values in them, so check for nil.
if v1.IsNil() && v2.IsNil() {
return true
}
if v1.IsNil() != v2.IsNil() {
return false
}
return equalAny(v1.Elem(), v2.Elem(), prop) return equalAny(v1.Elem(), v2.Elem(), prop)
case reflect.Slice: case reflect.Slice:
if v1.Type().Elem().Kind() == reflect.Uint8 { if v1.Type().Elem().Kind() == reflect.Uint8 {
@ -223,8 +241,14 @@ func equalAny(v1, v2 reflect.Value, prop *Properties) bool {
} }
// base is the struct type that the extensions are based on. // base is the struct type that the extensions are based on.
// em1 and em2 are extension maps. // x1 and x2 are InternalExtensions.
func equalExtensions(base reflect.Type, em1, em2 map[int32]Extension) bool { func equalExtensions(base reflect.Type, x1, x2 XXX_InternalExtensions) bool {
em1, _ := x1.extensionsRead()
em2, _ := x2.extensionsRead()
return equalExtMap(base, em1, em2)
}
func equalExtMap(base reflect.Type, em1, em2 map[int32]Extension) bool {
if len(em1) != len(em2) { if len(em1) != len(em2) {
return false return false
} }

View File

@ -52,14 +52,99 @@ type ExtensionRange struct {
Start, End int32 // both inclusive Start, End int32 // both inclusive
} }
// extendableProto is an interface implemented by any protocol buffer that may be extended. // extendableProto is an interface implemented by any protocol buffer generated by the current
// proto compiler that may be extended.
type extendableProto interface { type extendableProto interface {
Message
ExtensionRangeArray() []ExtensionRange
extensionsWrite() map[int32]Extension
extensionsRead() (map[int32]Extension, sync.Locker)
}
// extendableProtoV1 is an interface implemented by a protocol buffer generated by the previous
// version of the proto compiler that may be extended.
type extendableProtoV1 interface {
Message Message
ExtensionRangeArray() []ExtensionRange ExtensionRangeArray() []ExtensionRange
ExtensionMap() map[int32]Extension ExtensionMap() map[int32]Extension
} }
// extensionAdapter is a wrapper around extendableProtoV1 that implements extendableProto.
type extensionAdapter struct {
extendableProtoV1
}
func (e extensionAdapter) extensionsWrite() map[int32]Extension {
return e.ExtensionMap()
}
func (e extensionAdapter) extensionsRead() (map[int32]Extension, sync.Locker) {
return e.ExtensionMap(), notLocker{}
}
// notLocker is a sync.Locker whose Lock and Unlock methods are nops.
type notLocker struct{}
func (n notLocker) Lock() {}
func (n notLocker) Unlock() {}
// extendable returns the extendableProto interface for the given generated proto message.
// If the proto message has the old extension format, it returns a wrapper that implements
// the extendableProto interface.
func extendable(p interface{}) (extendableProto, bool) {
if ep, ok := p.(extendableProto); ok {
return ep, ok
}
if ep, ok := p.(extendableProtoV1); ok {
return extensionAdapter{ep}, ok
}
return nil, false
}
// XXX_InternalExtensions is an internal representation of proto extensions.
//
// Each generated message struct type embeds an anonymous XXX_InternalExtensions field,
// thus gaining the unexported 'extensions' method, which can be called only from the proto package.
//
// The methods of XXX_InternalExtensions are not concurrency safe in general,
// but calls to logically read-only methods such as has and get may be executed concurrently.
type XXX_InternalExtensions struct {
// The struct must be indirect so that if a user inadvertently copies a
// generated message and its embedded XXX_InternalExtensions, they
// avoid the mayhem of a copied mutex.
//
// The mutex serializes all logically read-only operations to p.extensionMap.
// It is up to the client to ensure that write operations to p.extensionMap are
// mutually exclusive with other accesses.
p *struct {
mu sync.Mutex
extensionMap map[int32]Extension
}
}
// extensionsWrite returns the extension map, creating it on first use.
func (e *XXX_InternalExtensions) extensionsWrite() map[int32]Extension {
if e.p == nil {
e.p = new(struct {
mu sync.Mutex
extensionMap map[int32]Extension
})
e.p.extensionMap = make(map[int32]Extension)
}
return e.p.extensionMap
}
// extensionsRead returns the extensions map for read-only use. It may be nil.
// The caller must hold the returned mutex's lock when accessing Elements within the map.
func (e *XXX_InternalExtensions) extensionsRead() (map[int32]Extension, sync.Locker) {
if e.p == nil {
return nil, nil
}
return e.p.extensionMap, &e.p.mu
}
var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem() var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem()
var extendableProtoV1Type = reflect.TypeOf((*extendableProtoV1)(nil)).Elem()
// ExtensionDesc represents an extension specification. // ExtensionDesc represents an extension specification.
// Used in generated code from the protocol compiler. // Used in generated code from the protocol compiler.
@ -69,6 +154,7 @@ type ExtensionDesc struct {
Field int32 // field number Field int32 // field number
Name string // fully-qualified name of extension, for text formatting Name string // fully-qualified name of extension, for text formatting
Tag string // protobuf tag style Tag string // protobuf tag style
Filename string // name of the file in which the extension is defined
} }
func (ed *ExtensionDesc) repeated() bool { func (ed *ExtensionDesc) repeated() bool {
@ -92,8 +178,13 @@ type Extension struct {
} }
// SetRawExtension is for testing only. // SetRawExtension is for testing only.
func SetRawExtension(base extendableProto, id int32, b []byte) { func SetRawExtension(base Message, id int32, b []byte) {
base.ExtensionMap()[id] = Extension{enc: b} epb, ok := extendable(base)
if !ok {
return
}
extmap := epb.extensionsWrite()
extmap[id] = Extension{enc: b}
} }
// isExtensionField returns true iff the given field number is in an extension range. // isExtensionField returns true iff the given field number is in an extension range.
@ -108,8 +199,12 @@ func isExtensionField(pb extendableProto, field int32) bool {
// checkExtensionTypes checks that the given extension is valid for pb. // checkExtensionTypes checks that the given extension is valid for pb.
func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error { func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error {
var pbi interface{} = pb
// Check the extended type. // Check the extended type.
if a, b := reflect.TypeOf(pb), reflect.TypeOf(extension.ExtendedType); a != b { if ea, ok := pbi.(extensionAdapter); ok {
pbi = ea.extendableProtoV1
}
if a, b := reflect.TypeOf(pbi), reflect.TypeOf(extension.ExtendedType); a != b {
return errors.New("proto: bad extended type; " + b.String() + " does not extend " + a.String()) return errors.New("proto: bad extended type; " + b.String() + " does not extend " + a.String())
} }
// Check the range. // Check the range.
@ -155,8 +250,19 @@ func extensionProperties(ed *ExtensionDesc) *Properties {
return prop return prop
} }
// encodeExtensionMap encodes any unmarshaled (unencoded) extensions in m. // encode encodes any unmarshaled (unencoded) extensions in e.
func encodeExtensionMap(m map[int32]Extension) error { func encodeExtensions(e *XXX_InternalExtensions) error {
m, mu := e.extensionsRead()
if m == nil {
return nil // fast path
}
mu.Lock()
defer mu.Unlock()
return encodeExtensionsMap(m)
}
// encode encodes any unmarshaled (unencoded) extensions in e.
func encodeExtensionsMap(m map[int32]Extension) error {
for k, e := range m { for k, e := range m {
if e.value == nil || e.desc == nil { if e.value == nil || e.desc == nil {
// Extension is only in its encoded form. // Extension is only in its encoded form.
@ -184,7 +290,17 @@ func encodeExtensionMap(m map[int32]Extension) error {
return nil return nil
} }
func sizeExtensionMap(m map[int32]Extension) (n int) { func extensionsSize(e *XXX_InternalExtensions) (n int) {
m, mu := e.extensionsRead()
if m == nil {
return 0
}
mu.Lock()
defer mu.Unlock()
return extensionsMapSize(m)
}
func extensionsMapSize(m map[int32]Extension) (n int) {
for _, e := range m { for _, e := range m {
if e.value == nil || e.desc == nil { if e.value == nil || e.desc == nil {
// Extension is only in its encoded form. // Extension is only in its encoded form.
@ -209,26 +325,51 @@ func sizeExtensionMap(m map[int32]Extension) (n int) {
} }
// HasExtension returns whether the given extension is present in pb. // HasExtension returns whether the given extension is present in pb.
func HasExtension(pb extendableProto, extension *ExtensionDesc) bool { func HasExtension(pb Message, extension *ExtensionDesc) bool {
// TODO: Check types, field numbers, etc.? // TODO: Check types, field numbers, etc.?
_, ok := pb.ExtensionMap()[extension.Field] epb, ok := extendable(pb)
if !ok {
return false
}
extmap, mu := epb.extensionsRead()
if extmap == nil {
return false
}
mu.Lock()
_, ok = extmap[extension.Field]
mu.Unlock()
return ok return ok
} }
// ClearExtension removes the given extension from pb. // ClearExtension removes the given extension from pb.
func ClearExtension(pb extendableProto, extension *ExtensionDesc) { func ClearExtension(pb Message, extension *ExtensionDesc) {
epb, ok := extendable(pb)
if !ok {
return
}
// TODO: Check types, field numbers, etc.? // TODO: Check types, field numbers, etc.?
delete(pb.ExtensionMap(), extension.Field) extmap := epb.extensionsWrite()
delete(extmap, extension.Field)
} }
// GetExtension parses and returns the given extension of pb. // GetExtension parses and returns the given extension of pb.
// If the extension is not present and has no default value it returns ErrMissingExtension. // If the extension is not present and has no default value it returns ErrMissingExtension.
func GetExtension(pb extendableProto, extension *ExtensionDesc) (interface{}, error) { func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
if err := checkExtensionTypes(pb, extension); err != nil { epb, ok := extendable(pb)
if !ok {
return nil, errors.New("proto: not an extendable proto")
}
if err := checkExtensionTypes(epb, extension); err != nil {
return nil, err return nil, err
} }
emap := pb.ExtensionMap() emap, mu := epb.extensionsRead()
if emap == nil {
return defaultExtensionValue(extension)
}
mu.Lock()
defer mu.Unlock()
e, ok := emap[extension.Field] e, ok := emap[extension.Field]
if !ok { if !ok {
// defaultExtensionValue returns the default value or // defaultExtensionValue returns the default value or
@ -332,10 +473,9 @@ func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
// GetExtensions returns a slice of the extensions present in pb that are also listed in es. // GetExtensions returns a slice of the extensions present in pb that are also listed in es.
// The returned slice has the same length as es; missing extensions will appear as nil elements. // The returned slice has the same length as es; missing extensions will appear as nil elements.
func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) { func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) {
epb, ok := pb.(extendableProto) epb, ok := extendable(pb)
if !ok { if !ok {
err = errors.New("proto: not an extendable proto") return nil, errors.New("proto: not an extendable proto")
return
} }
extensions = make([]interface{}, len(es)) extensions = make([]interface{}, len(es))
for i, e := range es { for i, e := range es {
@ -350,9 +490,44 @@ func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, e
return return
} }
// ExtensionDescs returns a new slice containing pb's extension descriptors, in undefined order.
// For non-registered extensions, ExtensionDescs returns an incomplete descriptor containing
// just the Field field, which defines the extension's field number.
func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) {
epb, ok := extendable(pb)
if !ok {
return nil, fmt.Errorf("proto: %T is not an extendable proto.Message", pb)
}
registeredExtensions := RegisteredExtensions(pb)
emap, mu := epb.extensionsRead()
if emap == nil {
return nil, nil
}
mu.Lock()
defer mu.Unlock()
extensions := make([]*ExtensionDesc, 0, len(emap))
for extid, e := range emap {
desc := e.desc
if desc == nil {
desc = registeredExtensions[extid]
if desc == nil {
desc = &ExtensionDesc{Field: extid}
}
}
extensions = append(extensions, desc)
}
return extensions, nil
}
// SetExtension sets the specified extension of pb to the specified value. // SetExtension sets the specified extension of pb to the specified value.
func SetExtension(pb extendableProto, extension *ExtensionDesc, value interface{}) error { func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error {
if err := checkExtensionTypes(pb, extension); err != nil { epb, ok := extendable(pb)
if !ok {
return errors.New("proto: not an extendable proto")
}
if err := checkExtensionTypes(epb, extension); err != nil {
return err return err
} }
typ := reflect.TypeOf(extension.ExtensionType) typ := reflect.TypeOf(extension.ExtensionType)
@ -368,10 +543,23 @@ func SetExtension(pb extendableProto, extension *ExtensionDesc, value interface{
return fmt.Errorf("proto: SetExtension called with nil value of type %T", value) return fmt.Errorf("proto: SetExtension called with nil value of type %T", value)
} }
pb.ExtensionMap()[extension.Field] = Extension{desc: extension, value: value} extmap := epb.extensionsWrite()
extmap[extension.Field] = Extension{desc: extension, value: value}
return nil return nil
} }
// ClearAllExtensions clears all extensions from pb.
func ClearAllExtensions(pb Message) {
epb, ok := extendable(pb)
if !ok {
return
}
m := epb.extensionsWrite()
for k := range m {
delete(m, k)
}
}
// A global registry of extensions. // A global registry of extensions.
// The generated code will register the generated descriptors by calling RegisterExtension. // The generated code will register the generated descriptors by calling RegisterExtension.

View File

@ -235,6 +235,7 @@ To create and play with a Test object:
test := &pb.Test{ test := &pb.Test{
Label: proto.String("hello"), Label: proto.String("hello"),
Type: proto.Int32(17), Type: proto.Int32(17),
Reps: []int64{1, 2, 3},
Optionalgroup: &pb.Test_OptionalGroup{ Optionalgroup: &pb.Test_OptionalGroup{
RequiredField: proto.String("good bye"), RequiredField: proto.String("good bye"),
}, },
@ -307,7 +308,7 @@ func GetStats() Stats { return stats }
// temporary Buffer and are fine for most applications. // temporary Buffer and are fine for most applications.
type Buffer struct { type Buffer struct {
buf []byte // encode/decode byte stream buf []byte // encode/decode byte stream
index int // write point index int // read point
// pools of basic types to amortize allocation. // pools of basic types to amortize allocation.
bools []bool bools []bool
@ -888,6 +889,10 @@ func isProto3Zero(v reflect.Value) bool {
return false return false
} }
// ProtoPackageIsVersion2 is referenced from generated protocol buffer files
// to assert that that code is compatible with this version of the proto package.
const ProtoPackageIsVersion2 = true
// ProtoPackageIsVersion1 is referenced from generated protocol buffer files // ProtoPackageIsVersion1 is referenced from generated protocol buffer files
// to assert that that code is compatible with this version of the proto package. // to assert that that code is compatible with this version of the proto package.
const ProtoPackageIsVersion1 = true const ProtoPackageIsVersion1 = true

View File

@ -149,9 +149,21 @@ func skipVarint(buf []byte) []byte {
// MarshalMessageSet encodes the extension map represented by m in the message set wire format. // MarshalMessageSet encodes the extension map represented by m in the message set wire format.
// It is called by generated Marshal methods on protocol buffer messages with the message_set_wire_format option. // It is called by generated Marshal methods on protocol buffer messages with the message_set_wire_format option.
func MarshalMessageSet(m map[int32]Extension) ([]byte, error) { func MarshalMessageSet(exts interface{}) ([]byte, error) {
if err := encodeExtensionMap(m); err != nil { var m map[int32]Extension
return nil, err switch exts := exts.(type) {
case *XXX_InternalExtensions:
if err := encodeExtensions(exts); err != nil {
return nil, err
}
m, _ = exts.extensionsRead()
case map[int32]Extension:
if err := encodeExtensionsMap(exts); err != nil {
return nil, err
}
m = exts
default:
return nil, errors.New("proto: not an extension map")
} }
// Sort extension IDs to provide a deterministic encoding. // Sort extension IDs to provide a deterministic encoding.
@ -178,7 +190,17 @@ func MarshalMessageSet(m map[int32]Extension) ([]byte, error) {
// UnmarshalMessageSet decodes the extension map encoded in buf in the message set wire format. // UnmarshalMessageSet decodes the extension map encoded in buf in the message set wire format.
// It is called by generated Unmarshal methods on protocol buffer messages with the message_set_wire_format option. // It is called by generated Unmarshal methods on protocol buffer messages with the message_set_wire_format option.
func UnmarshalMessageSet(buf []byte, m map[int32]Extension) error { func UnmarshalMessageSet(buf []byte, exts interface{}) error {
var m map[int32]Extension
switch exts := exts.(type) {
case *XXX_InternalExtensions:
m = exts.extensionsWrite()
case map[int32]Extension:
m = exts
default:
return errors.New("proto: not an extension map")
}
ms := new(messageSet) ms := new(messageSet)
if err := Unmarshal(buf, ms); err != nil { if err := Unmarshal(buf, ms); err != nil {
return err return err
@ -209,7 +231,16 @@ func UnmarshalMessageSet(buf []byte, m map[int32]Extension) error {
// MarshalMessageSetJSON encodes the extension map represented by m in JSON format. // MarshalMessageSetJSON encodes the extension map represented by m in JSON format.
// It is called by generated MarshalJSON methods on protocol buffer messages with the message_set_wire_format option. // It is called by generated MarshalJSON methods on protocol buffer messages with the message_set_wire_format option.
func MarshalMessageSetJSON(m map[int32]Extension) ([]byte, error) { func MarshalMessageSetJSON(exts interface{}) ([]byte, error) {
var m map[int32]Extension
switch exts := exts.(type) {
case *XXX_InternalExtensions:
m, _ = exts.extensionsRead()
case map[int32]Extension:
m = exts
default:
return nil, errors.New("proto: not an extension map")
}
var b bytes.Buffer var b bytes.Buffer
b.WriteByte('{') b.WriteByte('{')
@ -252,7 +283,7 @@ func MarshalMessageSetJSON(m map[int32]Extension) ([]byte, error) {
// UnmarshalMessageSetJSON decodes the extension map encoded in buf in JSON format. // UnmarshalMessageSetJSON decodes the extension map encoded in buf in JSON format.
// It is called by generated UnmarshalJSON methods on protocol buffer messages with the message_set_wire_format option. // It is called by generated UnmarshalJSON methods on protocol buffer messages with the message_set_wire_format option.
func UnmarshalMessageSetJSON(buf []byte, m map[int32]Extension) error { func UnmarshalMessageSetJSON(buf []byte, exts interface{}) error {
// Common-case fast path. // Common-case fast path.
if len(buf) == 0 || bytes.Equal(buf, []byte("{}")) { if len(buf) == 0 || bytes.Equal(buf, []byte("{}")) {
return nil return nil

View File

@ -29,7 +29,7 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// +build appengine // +build appengine js
// This file contains an implementation of proto field accesses using package reflect. // This file contains an implementation of proto field accesses using package reflect.
// It is slower than the code in pointer_unsafe.go but it avoids package unsafe and can // It is slower than the code in pointer_unsafe.go but it avoids package unsafe and can
@ -139,6 +139,11 @@ func structPointer_StringSlice(p structPointer, f field) *[]string {
return structPointer_ifield(p, f).(*[]string) return structPointer_ifield(p, f).(*[]string)
} }
// Extensions returns the address of an extension map field in the struct.
func structPointer_Extensions(p structPointer, f field) *XXX_InternalExtensions {
return structPointer_ifield(p, f).(*XXX_InternalExtensions)
}
// ExtMap returns the address of an extension map field in the struct. // ExtMap returns the address of an extension map field in the struct.
func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension { func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension {
return structPointer_ifield(p, f).(*map[int32]Extension) return structPointer_ifield(p, f).(*map[int32]Extension)

View File

@ -29,7 +29,7 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// +build !appengine // +build !appengine,!js
// This file contains the implementation of the proto field accesses using package unsafe. // This file contains the implementation of the proto field accesses using package unsafe.
@ -126,6 +126,10 @@ func structPointer_StringSlice(p structPointer, f field) *[]string {
} }
// ExtMap returns the address of an extension map field in the struct. // ExtMap returns the address of an extension map field in the struct.
func structPointer_Extensions(p structPointer, f field) *XXX_InternalExtensions {
return (*XXX_InternalExtensions)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension { func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension {
return (*map[int32]Extension)(unsafe.Pointer(uintptr(p) + uintptr(f))) return (*map[int32]Extension)(unsafe.Pointer(uintptr(p) + uintptr(f)))
} }

View File

@ -173,6 +173,7 @@ func (sp *StructProperties) Swap(i, j int) { sp.order[i], sp.order[j] = sp.order
type Properties struct { type Properties struct {
Name string // name of the field, for error messages Name string // name of the field, for error messages
OrigName string // original name before protocol compiler (always set) OrigName string // original name before protocol compiler (always set)
JSONName string // name to use for JSON; determined by protoc
Wire string Wire string
WireType int WireType int
Tag int Tag int
@ -229,8 +230,9 @@ func (p *Properties) String() string {
if p.Packed { if p.Packed {
s += ",packed" s += ",packed"
} }
if p.OrigName != p.Name { s += ",name=" + p.OrigName
s += ",name=" + p.OrigName if p.JSONName != p.OrigName {
s += ",json=" + p.JSONName
} }
if p.proto3 { if p.proto3 {
s += ",proto3" s += ",proto3"
@ -310,6 +312,8 @@ func (p *Properties) Parse(s string) {
p.Packed = true p.Packed = true
case strings.HasPrefix(f, "name="): case strings.HasPrefix(f, "name="):
p.OrigName = f[5:] p.OrigName = f[5:]
case strings.HasPrefix(f, "json="):
p.JSONName = f[5:]
case strings.HasPrefix(f, "enum="): case strings.HasPrefix(f, "enum="):
p.Enum = f[5:] p.Enum = f[5:]
case f == "proto3": case f == "proto3":
@ -469,17 +473,13 @@ func (p *Properties) setEncAndDec(typ reflect.Type, f *reflect.StructField, lock
p.dec = (*Buffer).dec_slice_int64 p.dec = (*Buffer).dec_slice_int64
p.packedDec = (*Buffer).dec_slice_packed_int64 p.packedDec = (*Buffer).dec_slice_packed_int64
case reflect.Uint8: case reflect.Uint8:
p.enc = (*Buffer).enc_slice_byte
p.dec = (*Buffer).dec_slice_byte p.dec = (*Buffer).dec_slice_byte
p.size = size_slice_byte if p.proto3 {
// This is a []byte, which is either a bytes field,
// or the value of a map field. In the latter case,
// we always encode an empty []byte, so we should not
// use the proto3 enc/size funcs.
// f == nil iff this is the key/value of a map field.
if p.proto3 && f != nil {
p.enc = (*Buffer).enc_proto3_slice_byte p.enc = (*Buffer).enc_proto3_slice_byte
p.size = size_proto3_slice_byte p.size = size_proto3_slice_byte
} else {
p.enc = (*Buffer).enc_slice_byte
p.size = size_slice_byte
} }
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
switch t2.Bits() { switch t2.Bits() {
@ -678,7 +678,8 @@ func getPropertiesLocked(t reflect.Type) *StructProperties {
propertiesMap[t] = prop propertiesMap[t] = prop
// build properties // build properties
prop.extendable = reflect.PtrTo(t).Implements(extendableProtoType) prop.extendable = reflect.PtrTo(t).Implements(extendableProtoType) ||
reflect.PtrTo(t).Implements(extendableProtoV1Type)
prop.unrecField = invalidField prop.unrecField = invalidField
prop.Prop = make([]*Properties, t.NumField()) prop.Prop = make([]*Properties, t.NumField())
prop.order = make([]int, t.NumField()) prop.order = make([]int, t.NumField())
@ -689,15 +690,22 @@ func getPropertiesLocked(t reflect.Type) *StructProperties {
name := f.Name name := f.Name
p.init(f.Type, name, f.Tag.Get("protobuf"), &f, false) p.init(f.Type, name, f.Tag.Get("protobuf"), &f, false)
if f.Name == "XXX_extensions" { // special case if f.Name == "XXX_InternalExtensions" { // special case
p.enc = (*Buffer).enc_exts
p.dec = nil // not needed
p.size = size_exts
} else if f.Name == "XXX_extensions" { // special case
p.enc = (*Buffer).enc_map p.enc = (*Buffer).enc_map
p.dec = nil // not needed p.dec = nil // not needed
p.size = size_map p.size = size_map
} } else if f.Name == "XXX_unrecognized" { // special case
if f.Name == "XXX_unrecognized" { // special case
prop.unrecField = toField(&f) prop.unrecField = toField(&f)
} }
oneof := f.Tag.Get("protobuf_oneof") != "" // special case oneof := f.Tag.Get("protobuf_oneof") // special case
if oneof != "" {
// Oneof fields don't use the traditional protobuf tag.
p.OrigName = oneof
}
prop.Prop[i] = p prop.Prop[i] = p
prop.order[i] = i prop.order[i] = i
if debug { if debug {
@ -707,7 +715,7 @@ func getPropertiesLocked(t reflect.Type) *StructProperties {
} }
print("\n") print("\n")
} }
if p.enc == nil && !strings.HasPrefix(f.Name, "XXX_") && !oneof { if p.enc == nil && !strings.HasPrefix(f.Name, "XXX_") && oneof == "" {
fmt.Fprintln(os.Stderr, "proto: no encoder for", f.Name, f.Type.String(), "[GetProperties]") fmt.Fprintln(os.Stderr, "proto: no encoder for", f.Name, f.Type.String(), "[GetProperties]")
} }
} }
@ -836,7 +844,29 @@ func RegisterType(x Message, name string) {
} }
// MessageName returns the fully-qualified proto name for the given message type. // MessageName returns the fully-qualified proto name for the given message type.
func MessageName(x Message) string { return revProtoTypes[reflect.TypeOf(x)] } func MessageName(x Message) string {
type xname interface {
XXX_MessageName() string
}
if m, ok := x.(xname); ok {
return m.XXX_MessageName()
}
return revProtoTypes[reflect.TypeOf(x)]
}
// MessageType returns the message type (pointer to struct) for a named message. // MessageType returns the message type (pointer to struct) for a named message.
func MessageType(name string) reflect.Type { return protoTypes[name] } func MessageType(name string) reflect.Type { return protoTypes[name] }
// A registry of all linked proto files.
var (
protoFiles = make(map[string][]byte) // file name => fileDescriptor
)
// RegisterFile is called from generated code and maps from the
// full file name of a .proto file to its compressed FileDescriptorProto.
func RegisterFile(filename string, fileDescriptor []byte) {
protoFiles[filename] = fileDescriptor
}
// FileDescriptor returns the compressed FileDescriptorProto for a .proto file.
func FileDescriptor(filename string) []byte { return protoFiles[filename] }

View File

@ -154,7 +154,7 @@ func (w *textWriter) indent() { w.ind++ }
func (w *textWriter) unindent() { func (w *textWriter) unindent() {
if w.ind == 0 { if w.ind == 0 {
log.Printf("proto: textWriter unindented too far") log.Print("proto: textWriter unindented too far")
return return
} }
w.ind-- w.ind--
@ -175,7 +175,93 @@ type raw interface {
Bytes() []byte Bytes() []byte
} }
func writeStruct(w *textWriter, sv reflect.Value) error { func requiresQuotes(u string) bool {
// When type URL contains any characters except [0-9A-Za-z./\-]*, it must be quoted.
for _, ch := range u {
switch {
case ch == '.' || ch == '/' || ch == '_':
continue
case '0' <= ch && ch <= '9':
continue
case 'A' <= ch && ch <= 'Z':
continue
case 'a' <= ch && ch <= 'z':
continue
default:
return true
}
}
return false
}
// isAny reports whether sv is a google.protobuf.Any message
func isAny(sv reflect.Value) bool {
type wkt interface {
XXX_WellKnownType() string
}
t, ok := sv.Addr().Interface().(wkt)
return ok && t.XXX_WellKnownType() == "Any"
}
// writeProto3Any writes an expanded google.protobuf.Any message.
//
// It returns (false, nil) if sv value can't be unmarshaled (e.g. because
// required messages are not linked in).
//
// It returns (true, error) when sv was written in expanded format or an error
// was encountered.
func (tm *TextMarshaler) writeProto3Any(w *textWriter, sv reflect.Value) (bool, error) {
turl := sv.FieldByName("TypeUrl")
val := sv.FieldByName("Value")
if !turl.IsValid() || !val.IsValid() {
return true, errors.New("proto: invalid google.protobuf.Any message")
}
b, ok := val.Interface().([]byte)
if !ok {
return true, errors.New("proto: invalid google.protobuf.Any message")
}
parts := strings.Split(turl.String(), "/")
mt := MessageType(parts[len(parts)-1])
if mt == nil {
return false, nil
}
m := reflect.New(mt.Elem())
if err := Unmarshal(b, m.Interface().(Message)); err != nil {
return false, nil
}
w.Write([]byte("["))
u := turl.String()
if requiresQuotes(u) {
writeString(w, u)
} else {
w.Write([]byte(u))
}
if w.compact {
w.Write([]byte("]:<"))
} else {
w.Write([]byte("]: <\n"))
w.ind++
}
if err := tm.writeStruct(w, m.Elem()); err != nil {
return true, err
}
if w.compact {
w.Write([]byte("> "))
} else {
w.ind--
w.Write([]byte(">\n"))
}
return true, nil
}
func (tm *TextMarshaler) writeStruct(w *textWriter, sv reflect.Value) error {
if tm.ExpandAny && isAny(sv) {
if canExpand, err := tm.writeProto3Any(w, sv); canExpand {
return err
}
}
st := sv.Type() st := sv.Type()
sprops := GetProperties(st) sprops := GetProperties(st)
for i := 0; i < sv.NumField(); i++ { for i := 0; i < sv.NumField(); i++ {
@ -227,7 +313,7 @@ func writeStruct(w *textWriter, sv reflect.Value) error {
} }
continue continue
} }
if err := writeAny(w, v, props); err != nil { if err := tm.writeAny(w, v, props); err != nil {
return err return err
} }
if err := w.WriteByte('\n'); err != nil { if err := w.WriteByte('\n'); err != nil {
@ -269,7 +355,7 @@ func writeStruct(w *textWriter, sv reflect.Value) error {
return err return err
} }
} }
if err := writeAny(w, key, props.mkeyprop); err != nil { if err := tm.writeAny(w, key, props.mkeyprop); err != nil {
return err return err
} }
if err := w.WriteByte('\n'); err != nil { if err := w.WriteByte('\n'); err != nil {
@ -286,7 +372,7 @@ func writeStruct(w *textWriter, sv reflect.Value) error {
return err return err
} }
} }
if err := writeAny(w, val, props.mvalprop); err != nil { if err := tm.writeAny(w, val, props.mvalprop); err != nil {
return err return err
} }
if err := w.WriteByte('\n'); err != nil { if err := w.WriteByte('\n'); err != nil {
@ -358,7 +444,7 @@ func writeStruct(w *textWriter, sv reflect.Value) error {
} }
// Enums have a String method, so writeAny will work fine. // Enums have a String method, so writeAny will work fine.
if err := writeAny(w, fv, props); err != nil { if err := tm.writeAny(w, fv, props); err != nil {
return err return err
} }
@ -369,8 +455,8 @@ func writeStruct(w *textWriter, sv reflect.Value) error {
// Extensions (the XXX_extensions field). // Extensions (the XXX_extensions field).
pv := sv.Addr() pv := sv.Addr()
if pv.Type().Implements(extendableProtoType) { if _, ok := extendable(pv.Interface()); ok {
if err := writeExtensions(w, pv); err != nil { if err := tm.writeExtensions(w, pv); err != nil {
return err return err
} }
} }
@ -400,7 +486,7 @@ func writeRaw(w *textWriter, b []byte) error {
} }
// writeAny writes an arbitrary field. // writeAny writes an arbitrary field.
func writeAny(w *textWriter, v reflect.Value, props *Properties) error { func (tm *TextMarshaler) writeAny(w *textWriter, v reflect.Value, props *Properties) error {
v = reflect.Indirect(v) v = reflect.Indirect(v)
// Floats have special cases. // Floats have special cases.
@ -427,7 +513,7 @@ func writeAny(w *textWriter, v reflect.Value, props *Properties) error {
switch v.Kind() { switch v.Kind() {
case reflect.Slice: case reflect.Slice:
// Should only be a []byte; repeated fields are handled in writeStruct. // Should only be a []byte; repeated fields are handled in writeStruct.
if err := writeString(w, string(v.Interface().([]byte))); err != nil { if err := writeString(w, string(v.Bytes())); err != nil {
return err return err
} }
case reflect.String: case reflect.String:
@ -449,15 +535,15 @@ func writeAny(w *textWriter, v reflect.Value, props *Properties) error {
} }
} }
w.indent() w.indent()
if tm, ok := v.Interface().(encoding.TextMarshaler); ok { if etm, ok := v.Interface().(encoding.TextMarshaler); ok {
text, err := tm.MarshalText() text, err := etm.MarshalText()
if err != nil { if err != nil {
return err return err
} }
if _, err = w.Write(text); err != nil { if _, err = w.Write(text); err != nil {
return err return err
} }
} else if err := writeStruct(w, v); err != nil { } else if err := tm.writeStruct(w, v); err != nil {
return err return err
} }
w.unindent() w.unindent()
@ -601,19 +687,24 @@ func (s int32Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
// writeExtensions writes all the extensions in pv. // writeExtensions writes all the extensions in pv.
// pv is assumed to be a pointer to a protocol message struct that is extendable. // pv is assumed to be a pointer to a protocol message struct that is extendable.
func writeExtensions(w *textWriter, pv reflect.Value) error { func (tm *TextMarshaler) writeExtensions(w *textWriter, pv reflect.Value) error {
emap := extensionMaps[pv.Type().Elem()] emap := extensionMaps[pv.Type().Elem()]
ep := pv.Interface().(extendableProto) ep, _ := extendable(pv.Interface())
// Order the extensions by ID. // Order the extensions by ID.
// This isn't strictly necessary, but it will give us // This isn't strictly necessary, but it will give us
// canonical output, which will also make testing easier. // canonical output, which will also make testing easier.
m := ep.ExtensionMap() m, mu := ep.extensionsRead()
if m == nil {
return nil
}
mu.Lock()
ids := make([]int32, 0, len(m)) ids := make([]int32, 0, len(m))
for id := range m { for id := range m {
ids = append(ids, id) ids = append(ids, id)
} }
sort.Sort(int32Slice(ids)) sort.Sort(int32Slice(ids))
mu.Unlock()
for _, extNum := range ids { for _, extNum := range ids {
ext := m[extNum] ext := m[extNum]
@ -636,13 +727,13 @@ func writeExtensions(w *textWriter, pv reflect.Value) error {
// Repeated extensions will appear as a slice. // Repeated extensions will appear as a slice.
if !desc.repeated() { if !desc.repeated() {
if err := writeExtension(w, desc.Name, pb); err != nil { if err := tm.writeExtension(w, desc.Name, pb); err != nil {
return err return err
} }
} else { } else {
v := reflect.ValueOf(pb) v := reflect.ValueOf(pb)
for i := 0; i < v.Len(); i++ { for i := 0; i < v.Len(); i++ {
if err := writeExtension(w, desc.Name, v.Index(i).Interface()); err != nil { if err := tm.writeExtension(w, desc.Name, v.Index(i).Interface()); err != nil {
return err return err
} }
} }
@ -651,7 +742,7 @@ func writeExtensions(w *textWriter, pv reflect.Value) error {
return nil return nil
} }
func writeExtension(w *textWriter, name string, pb interface{}) error { func (tm *TextMarshaler) writeExtension(w *textWriter, name string, pb interface{}) error {
if _, err := fmt.Fprintf(w, "[%s]:", name); err != nil { if _, err := fmt.Fprintf(w, "[%s]:", name); err != nil {
return err return err
} }
@ -660,7 +751,7 @@ func writeExtension(w *textWriter, name string, pb interface{}) error {
return err return err
} }
} }
if err := writeAny(w, reflect.ValueOf(pb), nil); err != nil { if err := tm.writeAny(w, reflect.ValueOf(pb), nil); err != nil {
return err return err
} }
if err := w.WriteByte('\n'); err != nil { if err := w.WriteByte('\n'); err != nil {
@ -685,7 +776,15 @@ func (w *textWriter) writeIndent() {
w.complete = false w.complete = false
} }
func marshalText(w io.Writer, pb Message, compact bool) error { // TextMarshaler is a configurable text format marshaler.
type TextMarshaler struct {
Compact bool // use compact text format (one line).
ExpandAny bool // expand google.protobuf.Any messages of known types
}
// Marshal writes a given protocol buffer in text format.
// The only errors returned are from w.
func (tm *TextMarshaler) Marshal(w io.Writer, pb Message) error {
val := reflect.ValueOf(pb) val := reflect.ValueOf(pb)
if pb == nil || val.IsNil() { if pb == nil || val.IsNil() {
w.Write([]byte("<nil>")) w.Write([]byte("<nil>"))
@ -700,11 +799,11 @@ func marshalText(w io.Writer, pb Message, compact bool) error {
aw := &textWriter{ aw := &textWriter{
w: ww, w: ww,
complete: true, complete: true,
compact: compact, compact: tm.Compact,
} }
if tm, ok := pb.(encoding.TextMarshaler); ok { if etm, ok := pb.(encoding.TextMarshaler); ok {
text, err := tm.MarshalText() text, err := etm.MarshalText()
if err != nil { if err != nil {
return err return err
} }
@ -718,7 +817,7 @@ func marshalText(w io.Writer, pb Message, compact bool) error {
} }
// Dereference the received pointer so we don't have outer < and >. // Dereference the received pointer so we don't have outer < and >.
v := reflect.Indirect(val) v := reflect.Indirect(val)
if err := writeStruct(aw, v); err != nil { if err := tm.writeStruct(aw, v); err != nil {
return err return err
} }
if bw != nil { if bw != nil {
@ -727,25 +826,29 @@ func marshalText(w io.Writer, pb Message, compact bool) error {
return nil return nil
} }
// Text is the same as Marshal, but returns the string directly.
func (tm *TextMarshaler) Text(pb Message) string {
var buf bytes.Buffer
tm.Marshal(&buf, pb)
return buf.String()
}
var (
defaultTextMarshaler = TextMarshaler{}
compactTextMarshaler = TextMarshaler{Compact: true}
)
// TODO: consider removing some of the Marshal functions below.
// MarshalText writes a given protocol buffer in text format. // MarshalText writes a given protocol buffer in text format.
// The only errors returned are from w. // The only errors returned are from w.
func MarshalText(w io.Writer, pb Message) error { func MarshalText(w io.Writer, pb Message) error { return defaultTextMarshaler.Marshal(w, pb) }
return marshalText(w, pb, false)
}
// MarshalTextString is the same as MarshalText, but returns the string directly. // MarshalTextString is the same as MarshalText, but returns the string directly.
func MarshalTextString(pb Message) string { func MarshalTextString(pb Message) string { return defaultTextMarshaler.Text(pb) }
var buf bytes.Buffer
marshalText(&buf, pb, false)
return buf.String()
}
// CompactText writes a given protocol buffer in compact text format (one line). // CompactText writes a given protocol buffer in compact text format (one line).
func CompactText(w io.Writer, pb Message) error { return marshalText(w, pb, true) } func CompactText(w io.Writer, pb Message) error { return compactTextMarshaler.Marshal(w, pb) }
// CompactTextString is the same as CompactText, but returns the string directly. // CompactTextString is the same as CompactText, but returns the string directly.
func CompactTextString(pb Message) string { func CompactTextString(pb Message) string { return compactTextMarshaler.Text(pb) }
var buf bytes.Buffer
marshalText(&buf, pb, true)
return buf.String()
}

View File

@ -44,6 +44,9 @@ import (
"unicode/utf8" "unicode/utf8"
) )
// Error string emitted when deserializing Any and fields are already set
const anyRepeatedlyUnpacked = "Any message unpacked multiple times, or %q already set"
type ParseError struct { type ParseError struct {
Message string Message string
Line int // 1-based line number Line int // 1-based line number
@ -119,6 +122,14 @@ func isWhitespace(c byte) bool {
return false return false
} }
func isQuote(c byte) bool {
switch c {
case '"', '\'':
return true
}
return false
}
func (p *textParser) skipWhitespace() { func (p *textParser) skipWhitespace() {
i := 0 i := 0
for i < len(p.s) && (isWhitespace(p.s[i]) || p.s[i] == '#') { for i < len(p.s) && (isWhitespace(p.s[i]) || p.s[i] == '#') {
@ -155,7 +166,7 @@ func (p *textParser) advance() {
p.cur.offset, p.cur.line = p.offset, p.line p.cur.offset, p.cur.line = p.offset, p.line
p.cur.unquoted = "" p.cur.unquoted = ""
switch p.s[0] { switch p.s[0] {
case '<', '>', '{', '}', ':', '[', ']', ';', ',': case '<', '>', '{', '}', ':', '[', ']', ';', ',', '/':
// Single symbol // Single symbol
p.cur.value, p.s = p.s[0:1], p.s[1:len(p.s)] p.cur.value, p.s = p.s[0:1], p.s[1:len(p.s)]
case '"', '\'': case '"', '\'':
@ -333,13 +344,13 @@ func (p *textParser) next() *token {
p.advance() p.advance()
if p.done { if p.done {
p.cur.value = "" p.cur.value = ""
} else if len(p.cur.value) > 0 && p.cur.value[0] == '"' { } else if len(p.cur.value) > 0 && isQuote(p.cur.value[0]) {
// Look for multiple quoted strings separated by whitespace, // Look for multiple quoted strings separated by whitespace,
// and concatenate them. // and concatenate them.
cat := p.cur cat := p.cur
for { for {
p.skipWhitespace() p.skipWhitespace()
if p.done || p.s[0] != '"' { if p.done || !isQuote(p.s[0]) {
break break
} }
p.advance() p.advance()
@ -443,7 +454,10 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) error {
fieldSet := make(map[string]bool) fieldSet := make(map[string]bool)
// A struct is a sequence of "name: value", terminated by one of // A struct is a sequence of "name: value", terminated by one of
// '>' or '}', or the end of the input. A name may also be // '>' or '}', or the end of the input. A name may also be
// "[extension]". // "[extension]" or "[type/url]".
//
// The whole struct can also be an expanded Any message, like:
// [type/url] < ... struct contents ... >
for { for {
tok := p.next() tok := p.next()
if tok.err != nil { if tok.err != nil {
@ -453,33 +467,74 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) error {
break break
} }
if tok.value == "[" { if tok.value == "[" {
// Looks like an extension. // Looks like an extension or an Any.
// //
// TODO: Check whether we need to handle // TODO: Check whether we need to handle
// namespace rooted names (e.g. ".something.Foo"). // namespace rooted names (e.g. ".something.Foo").
tok = p.next() extName, err := p.consumeExtName()
if tok.err != nil { if err != nil {
return tok.err return err
} }
if s := strings.LastIndex(extName, "/"); s >= 0 {
// If it contains a slash, it's an Any type URL.
messageName := extName[s+1:]
mt := MessageType(messageName)
if mt == nil {
return p.errorf("unrecognized message %q in google.protobuf.Any", messageName)
}
tok = p.next()
if tok.err != nil {
return tok.err
}
// consume an optional colon
if tok.value == ":" {
tok = p.next()
if tok.err != nil {
return tok.err
}
}
var terminator string
switch tok.value {
case "<":
terminator = ">"
case "{":
terminator = "}"
default:
return p.errorf("expected '{' or '<', found %q", tok.value)
}
v := reflect.New(mt.Elem())
if pe := p.readStruct(v.Elem(), terminator); pe != nil {
return pe
}
b, err := Marshal(v.Interface().(Message))
if err != nil {
return p.errorf("failed to marshal message of type %q: %v", messageName, err)
}
if fieldSet["type_url"] {
return p.errorf(anyRepeatedlyUnpacked, "type_url")
}
if fieldSet["value"] {
return p.errorf(anyRepeatedlyUnpacked, "value")
}
sv.FieldByName("TypeUrl").SetString(extName)
sv.FieldByName("Value").SetBytes(b)
fieldSet["type_url"] = true
fieldSet["value"] = true
continue
}
var desc *ExtensionDesc var desc *ExtensionDesc
// This could be faster, but it's functional. // This could be faster, but it's functional.
// TODO: Do something smarter than a linear scan. // TODO: Do something smarter than a linear scan.
for _, d := range RegisteredExtensions(reflect.New(st).Interface().(Message)) { for _, d := range RegisteredExtensions(reflect.New(st).Interface().(Message)) {
if d.Name == tok.value { if d.Name == extName {
desc = d desc = d
break break
} }
} }
if desc == nil { if desc == nil {
return p.errorf("unrecognized extension %q", tok.value) return p.errorf("unrecognized extension %q", extName)
}
// Check the extension terminator.
tok = p.next()
if tok.err != nil {
return tok.err
}
if tok.value != "]" {
return p.errorf("unrecognized extension terminator %q", tok.value)
} }
props := &Properties{} props := &Properties{}
@ -506,7 +561,7 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) error {
} }
reqFieldErr = err reqFieldErr = err
} }
ep := sv.Addr().Interface().(extendableProto) ep := sv.Addr().Interface().(Message)
if !rep { if !rep {
SetExtension(ep, desc, ext.Interface()) SetExtension(ep, desc, ext.Interface())
} else { } else {
@ -537,7 +592,11 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) error {
props = oop.Prop props = oop.Prop
nv := reflect.New(oop.Type.Elem()) nv := reflect.New(oop.Type.Elem())
dst = nv.Elem().Field(0) dst = nv.Elem().Field(0)
sv.Field(oop.Field).Set(nv) field := sv.Field(oop.Field)
if !field.IsNil() {
return p.errorf("field '%s' would overwrite already parsed oneof '%s'", name, sv.Type().Field(oop.Field).Name)
}
field.Set(nv)
} }
if !dst.IsValid() { if !dst.IsValid() {
return p.errorf("unknown field name %q in %v", name, st) return p.errorf("unknown field name %q in %v", name, st)
@ -558,8 +617,9 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) error {
// The map entry should be this sequence of tokens: // The map entry should be this sequence of tokens:
// < key : KEY value : VALUE > // < key : KEY value : VALUE >
// Technically the "key" and "value" could come in any order, // However, implementations may omit key or value, and technically
// but in practice they won't. // we should support them in any order. See b/28924776 for a time
// this went wrong.
tok := p.next() tok := p.next()
var terminator string var terminator string
@ -571,32 +631,39 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) error {
default: default:
return p.errorf("expected '{' or '<', found %q", tok.value) return p.errorf("expected '{' or '<', found %q", tok.value)
} }
if err := p.consumeToken("key"); err != nil { for {
return err tok := p.next()
} if tok.err != nil {
if err := p.consumeToken(":"); err != nil { return tok.err
return err }
} if tok.value == terminator {
if err := p.readAny(key, props.mkeyprop); err != nil { break
return err }
} switch tok.value {
if err := p.consumeOptionalSeparator(); err != nil { case "key":
return err if err := p.consumeToken(":"); err != nil {
} return err
if err := p.consumeToken("value"); err != nil { }
return err if err := p.readAny(key, props.mkeyprop); err != nil {
} return err
if err := p.checkForColon(props.mvalprop, dst.Type().Elem()); err != nil { }
return err if err := p.consumeOptionalSeparator(); err != nil {
} return err
if err := p.readAny(val, props.mvalprop); err != nil { }
return err case "value":
} if err := p.checkForColon(props.mvalprop, dst.Type().Elem()); err != nil {
if err := p.consumeOptionalSeparator(); err != nil { return err
return err }
} if err := p.readAny(val, props.mvalprop); err != nil {
if err := p.consumeToken(terminator); err != nil { return err
return err }
if err := p.consumeOptionalSeparator(); err != nil {
return err
}
default:
p.back()
return p.errorf(`expected "key", "value", or %q, found %q`, terminator, tok.value)
}
} }
dst.SetMapIndex(key, val) dst.SetMapIndex(key, val)
@ -619,7 +686,8 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) error {
return err return err
} }
reqFieldErr = err reqFieldErr = err
} else if props.Required { }
if props.Required {
reqCount-- reqCount--
} }
@ -635,6 +703,35 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) error {
return reqFieldErr return reqFieldErr
} }
// consumeExtName consumes extension name or expanded Any type URL and the
// following ']'. It returns the name or URL consumed.
func (p *textParser) consumeExtName() (string, error) {
tok := p.next()
if tok.err != nil {
return "", tok.err
}
// If extension name or type url is quoted, it's a single token.
if len(tok.value) > 2 && isQuote(tok.value[0]) && tok.value[len(tok.value)-1] == tok.value[0] {
name, err := unquoteC(tok.value[1:len(tok.value)-1], rune(tok.value[0]))
if err != nil {
return "", err
}
return name, p.consumeToken("]")
}
// Consume everything up to "]"
var parts []string
for tok.value != "]" {
parts = append(parts, tok.value)
tok = p.next()
if tok.err != nil {
return "", p.errorf("unrecognized type_url or extension name: %s", tok.err)
}
}
return strings.Join(parts, ""), nil
}
// consumeOptionalSeparator consumes an optional semicolon or comma. // consumeOptionalSeparator consumes an optional semicolon or comma.
// It is used in readStruct to provide backward compatibility. // It is used in readStruct to provide backward compatibility.
func (p *textParser) consumeOptionalSeparator() error { func (p *textParser) consumeOptionalSeparator() error {
@ -699,12 +796,12 @@ func (p *textParser) readAny(v reflect.Value, props *Properties) error {
fv.Set(reflect.Append(fv, reflect.New(at.Elem()).Elem())) fv.Set(reflect.Append(fv, reflect.New(at.Elem()).Elem()))
return p.readAny(fv.Index(fv.Len()-1), props) return p.readAny(fv.Index(fv.Len()-1), props)
case reflect.Bool: case reflect.Bool:
// Either "true", "false", 1 or 0. // true/1/t/True or false/f/0/False.
switch tok.value { switch tok.value {
case "true", "1": case "true", "1", "t", "True":
fv.SetBool(true) fv.SetBool(true)
return nil return nil
case "false", "0": case "false", "0", "f", "False":
fv.SetBool(false) fv.SetBool(false)
return nil return nil
} }

5
vendor/vendor.json vendored
View File

@ -33,9 +33,10 @@
"revisionTime": "2016-12-02T21:35:00Z" "revisionTime": "2016-12-02T21:35:00Z"
}, },
{ {
"checksumSHA1": "kBeNcaKk56FguvPSUCEaH6AxpRc=",
"path": "github.com/golang/protobuf/proto", "path": "github.com/golang/protobuf/proto",
"revision": "2402d76f3d41f928c7902a765dfc872356dd3aad", "revision": "8ee79997227bf9b34611aee7946ae64735e6fd93",
"revisionTime": "2016-01-06T12:58:00+11:00" "revisionTime": "2016-11-17T03:31:26Z"
}, },
{ {
"path": "github.com/kolo/xmlrpc", "path": "github.com/kolo/xmlrpc",