diff --git a/receive.go b/receive.go index dcc042d..015714a 100644 --- a/receive.go +++ b/receive.go @@ -2,10 +2,7 @@ package btrfs import ( "bytes" - "encoding/binary" "errors" - "fmt" - "hash/crc32" "io" "os" "os/exec" @@ -52,85 +49,10 @@ func Receive(r io.Reader, dstDir string) error { if err != nil { return err } - sr, err := newStreamReader(r) - if err != nil { - return err - } - _, _, _ = dir, subvolID, sr + //sr, err := send.NewStreamReader(r) + //if err != nil { + // return err + //} + _, _ = dir, subvolID panic("not implemented") } - -type streamReader struct { - r io.Reader - hbuf []byte - buf *bytes.Buffer -} -type sendCommandArgs struct { - Type sendCmdAttr - Data []byte -} -type sendCommand struct { - Type sendCmd - Args []sendCommandArgs -} - -func (sr *streamReader) ReadCommand() (*sendCommand, error) { - sr.buf.Reset() - var h cmdHeader - if sr.hbuf == nil { - sr.hbuf = make([]byte, h.Size()) - } - if _, err := io.ReadFull(sr.r, sr.hbuf); err != nil { - return nil, err - } else if err = h.Unmarshal(sr.hbuf); err != nil { - return nil, err - } - if sr.buf == nil { - sr.buf = bytes.NewBuffer(nil) - } - if _, err := io.CopyN(sr.buf, sr.r, int64(h.Len)); err != nil { - return nil, err - } - tbl := crc32.MakeTable(0) - crc := crc32.Checksum(sr.buf.Bytes(), tbl) - if crc != h.Crc { - return nil, fmt.Errorf("crc missmatch in command: %x vs %x", crc, h.Crc) - } - cmd := sendCommand{Type: sendCmd(h.Cmd)} - var th tlvHeader - data := sr.buf.Bytes() - for { - if n := len(data); n < th.Size() { - if n != 0 { - return nil, io.ErrUnexpectedEOF - } - break - } - if err := th.Unmarshal(data); err != nil { - return nil, err - } - data = data[th.Size():] - if sendCmdAttr(th.Type) > sendAttrMax { // || th.Len > _BTRFS_SEND_BUF_SIZE { - return nil, fmt.Errorf("invalid tlv in cmd: %+v", th) - } - b := make([]byte, th.Len) - copy(b, data) - cmd.Args = append(cmd.Args, sendCommandArgs{Type: sendCmdAttr(th.Type), Data: b}) - } - return &cmd, nil -} - -func newStreamReader(r io.Reader) (*streamReader, error) { - buf := make([]byte, sendStreamMagicSize+4) - _, err := io.ReadFull(r, buf) - if err != nil { - return nil, err - } else if bytes.Compare(buf[:sendStreamMagicSize], []byte(sendStreamMagic)) != 0 { - return nil, errors.New("unexpected stream header") - } - version := binary.LittleEndian.Uint32(buf[sendStreamMagicSize:]) - if version > sendStreamVersion { - return nil, fmt.Errorf("stream version %d not supported", version) - } - return &streamReader{r: r}, nil -} diff --git a/send/send.go b/send/send.go new file mode 100644 index 0000000..f5f6315 --- /dev/null +++ b/send/send.go @@ -0,0 +1,484 @@ +package send + +import ( + "errors" + "fmt" + "github.com/dennwc/btrfs" + "io" + "io/ioutil" + "time" +) + +func NewStreamReader(r io.Reader) (*StreamReader, error) { + // read magic and version + buf := make([]byte, len(sendStreamMagic)+4) + _, err := io.ReadFull(r, buf) + if err != nil { + return nil, fmt.Errorf("cannot read magic: %v", err) + } else if string(buf[:sendStreamMagicSize]) != sendStreamMagic { + return nil, errors.New("unexpected stream header") + } + version := sendEndianess.Uint32(buf[sendStreamMagicSize:]) + if version != sendStreamVersion { + return nil, fmt.Errorf("stream version %d not supported", version) + } + return &StreamReader{r: r}, nil +} + +type StreamReader struct { + r io.Reader + buf [cmdHeaderSize]byte +} + +func (r *StreamReader) readCmdHeader() (h cmdHeader, err error) { + _, err = io.ReadFull(r.r, r.buf[:cmdHeaderSize]) + if err == io.EOF { + return + } else if err != nil { + err = fmt.Errorf("cannot read command header: %v", err) + return + } + err = h.Unmarshal(r.buf[:cmdHeaderSize]) + // TODO: check CRC + return +} + +type SendTLV struct { + Attr sendCmdAttr + Val interface{} +} + +func (r *StreamReader) readTLV(rd io.Reader) (*SendTLV, error) { + _, err := io.ReadFull(rd, r.buf[:tlvHeaderSize]) + if err == io.EOF { + return nil, err + } else if err != nil { + return nil, fmt.Errorf("cannot read tlv header: %v", err) + } + var h tlvHeader + if err = h.Unmarshal(r.buf[:tlvHeaderSize]); err != nil { + return nil, err + } + typ := sendCmdAttr(h.Type) + if sendCmdAttr(typ) > sendAttrMax { // || th.Len > _BTRFS_SEND_BUF_SIZE { + return nil, fmt.Errorf("invalid tlv in cmd: %q", typ) + } + buf := make([]byte, h.Len) + _, err = io.ReadFull(rd, buf) + if err != nil { + return nil, fmt.Errorf("cannot read tlv: %v", err) + } + var v interface{} + switch typ { + case sendAttrCtransid, sendAttrCloneCtransid, + sendAttrUid, sendAttrGid, sendAttrMode, + sendAttrIno, sendAttrFileOffset, sendAttrSize, + sendAttrCloneOffset, sendAttrCloneLen: + if len(buf) != 8 { + return nil, fmt.Errorf("unexpected int64 size: %v", h.Len) + } + v = sendEndianess.Uint64(buf[:8]) + case sendAttrPath, sendAttrPathTo, sendAttrClonePath, sendAttrXattrName: + v = string(buf) + case sendAttrData, sendAttrXattrData: + v = buf + case sendAttrUuid, sendAttrCloneUuid: + if h.Len != btrfs.UUIDSize { + return nil, fmt.Errorf("unexpected UUID size: %v", h.Len) + } + var u btrfs.UUID + copy(u[:], buf) + v = u + case sendAttrAtime, sendAttrMtime, sendAttrCtime, sendAttrOtime: + if h.Len != 12 { + return nil, fmt.Errorf("unexpected timestamp size: %v", h.Len) + } + v = time.Unix( // btrfs_timespec + int64(sendEndianess.Uint64(buf[:8])), + int64(sendEndianess.Uint32(buf[8:])), + ) + default: + return nil, fmt.Errorf("unsupported tlv type: %v (len: %v)", typ, h.Len) + } + return &SendTLV{Attr: typ, Val: v}, nil +} +func (r *StreamReader) ReadCommand() (_ Cmd, gerr error) { + h, err := r.readCmdHeader() + if err != nil { + return nil, err + } + var tlvs []SendTLV + rd := io.LimitReader(r.r, int64(h.Len)) + defer io.Copy(ioutil.Discard, rd) + for { + tlv, err := r.readTLV(rd) + if err == io.EOF { + break + } else if err != nil { + return nil, fmt.Errorf("command %v: %v", h.Cmd, err) + } + tlvs = append(tlvs, *tlv) + } + var c Cmd + switch h.Cmd { + case sendCmdEnd: + c = &StreamEnd{} + case sendCmdSubvol: + c = &SubvolCmd{} + case sendCmdSnapshot: + c = &SnapshotCmd{} + case sendCmdChown: + c = &ChownCmd{} + case sendCmdChmod: + c = &ChmodCmd{} + case sendCmdUtimes: + c = &UTimesCmd{} + case sendCmdMkdir: + c = &MkdirCmd{} + case sendCmdRename: + c = &RenameCmd{} + case sendCmdMkfile: + c = &MkfileCmd{} + case sendCmdWrite: + c = &WriteCmd{} + case sendCmdTruncate: + c = &TruncateCmd{} + } + if c == nil { + return &UnknownSendCmd{Kind: h.Cmd, Params: tlvs}, nil + } + if err := c.decode(tlvs); err != nil { + return nil, err + } + return c, nil +} + +type errUnexpectedAttrType struct { + Cmd CmdType + Val SendTLV +} + +func (e errUnexpectedAttrType) Error() string { + return fmt.Sprintf("unexpected type for %q (in %q): %T", + e.Val.Attr, e.Cmd, e.Val.Val) +} + +type errUnexpectedAttr struct { + Cmd CmdType + Val SendTLV +} + +func (e errUnexpectedAttr) Error() string { + return fmt.Sprintf("unexpected attr %q for %q (%T)", + e.Val.Attr, e.Cmd, e.Val.Val) +} + +type Cmd interface { + Type() CmdType + decode(tlvs []SendTLV) error +} + +type UnknownSendCmd struct { + Kind CmdType + Params []SendTLV +} + +func (c UnknownSendCmd) Type() CmdType { + return c.Kind +} +func (c *UnknownSendCmd) decode(tlvs []SendTLV) error { + c.Params = tlvs + return nil +} + +type StreamEnd struct{} + +func (c StreamEnd) Type() CmdType { + return sendCmdEnd +} +func (c *StreamEnd) decode(tlvs []SendTLV) error { + if len(tlvs) != 0 { + return fmt.Errorf("unexpected TLVs for stream end command: %#v", tlvs) + } + return nil +} + +type SubvolCmd struct { + Path string + UUID btrfs.UUID + CTransID uint64 +} + +func (c SubvolCmd) Type() CmdType { + return sendCmdSubvol +} +func (c *SubvolCmd) decode(tlvs []SendTLV) error { + for _, tlv := range tlvs { + var ok bool + switch tlv.Attr { + case sendAttrPath: + c.Path, ok = tlv.Val.(string) + case sendAttrUuid: + c.UUID, ok = tlv.Val.(btrfs.UUID) + case sendAttrCtransid: + c.CTransID, ok = tlv.Val.(uint64) + default: + return errUnexpectedAttr{Val: tlv, Cmd: c.Type()} + } + if !ok { + return errUnexpectedAttrType{Val: tlv, Cmd: c.Type()} + } + } + return nil +} + +type SnapshotCmd struct { + Path string + UUID btrfs.UUID + CTransID uint64 + CloneUUID btrfs.UUID + CloneTransID uint64 +} + +func (c SnapshotCmd) Type() CmdType { + return sendCmdSnapshot +} +func (c *SnapshotCmd) decode(tlvs []SendTLV) error { + for _, tlv := range tlvs { + var ok bool + switch tlv.Attr { + case sendAttrPath: + c.Path, ok = tlv.Val.(string) + case sendAttrUuid: + c.UUID, ok = tlv.Val.(btrfs.UUID) + case sendAttrCtransid: + c.CTransID, ok = tlv.Val.(uint64) + case sendAttrCloneUuid: + c.CloneUUID, ok = tlv.Val.(btrfs.UUID) + case sendAttrCloneCtransid: + c.CloneTransID, ok = tlv.Val.(uint64) + default: + return errUnexpectedAttr{Val: tlv, Cmd: c.Type()} + } + if !ok { + return errUnexpectedAttrType{Val: tlv, Cmd: c.Type()} + } + } + return nil +} + +type ChownCmd struct { + Path string + UID, GID uint64 +} + +func (c ChownCmd) Type() CmdType { + return sendCmdChown +} +func (c *ChownCmd) decode(tlvs []SendTLV) error { + for _, tlv := range tlvs { + var ok bool + switch tlv.Attr { + case sendAttrPath: + c.Path, ok = tlv.Val.(string) + case sendAttrUid: + c.UID, ok = tlv.Val.(uint64) + case sendAttrGid: + c.GID, ok = tlv.Val.(uint64) + default: + return errUnexpectedAttr{Val: tlv, Cmd: c.Type()} + } + if !ok { + return errUnexpectedAttrType{Val: tlv, Cmd: c.Type()} + } + } + return nil +} + +type ChmodCmd struct { + Path string + Mode uint64 +} + +func (c ChmodCmd) Type() CmdType { + return sendCmdChmod +} +func (c *ChmodCmd) decode(tlvs []SendTLV) error { + for _, tlv := range tlvs { + var ok bool + switch tlv.Attr { + case sendAttrPath: + c.Path, ok = tlv.Val.(string) + case sendAttrMode: + c.Mode, ok = tlv.Val.(uint64) + default: + return errUnexpectedAttr{Val: tlv, Cmd: c.Type()} + } + if !ok { + return errUnexpectedAttrType{Val: tlv, Cmd: c.Type()} + } + } + return nil +} + +type UTimesCmd struct { + Path string + ATime, MTime, CTime time.Time +} + +func (c UTimesCmd) Type() CmdType { + return sendCmdUtimes +} +func (c *UTimesCmd) decode(tlvs []SendTLV) error { + for _, tlv := range tlvs { + var ok bool + switch tlv.Attr { + case sendAttrPath: + c.Path, ok = tlv.Val.(string) + case sendAttrAtime: + c.ATime, ok = tlv.Val.(time.Time) + case sendAttrMtime: + c.MTime, ok = tlv.Val.(time.Time) + case sendAttrCtime: + c.CTime, ok = tlv.Val.(time.Time) + default: + return errUnexpectedAttr{Val: tlv, Cmd: c.Type()} + } + if !ok { + return errUnexpectedAttrType{Val: tlv, Cmd: c.Type()} + } + } + return nil +} + +type MkdirCmd struct { + Path string + Ino uint64 +} + +func (c MkdirCmd) Type() CmdType { + return sendCmdMkdir +} +func (c *MkdirCmd) decode(tlvs []SendTLV) error { + for _, tlv := range tlvs { + var ok bool + switch tlv.Attr { + case sendAttrPath: + c.Path, ok = tlv.Val.(string) + case sendAttrIno: + c.Ino, ok = tlv.Val.(uint64) + default: + return errUnexpectedAttr{Val: tlv, Cmd: c.Type()} + } + if !ok { + return errUnexpectedAttrType{Val: tlv, Cmd: c.Type()} + } + } + return nil +} + +type RenameCmd struct { + From, To string +} + +func (c RenameCmd) Type() CmdType { + return sendCmdRename +} +func (c *RenameCmd) decode(tlvs []SendTLV) error { + for _, tlv := range tlvs { + var ok bool + switch tlv.Attr { + case sendAttrPath: + c.From, ok = tlv.Val.(string) + case sendAttrPathTo: + c.To, ok = tlv.Val.(string) + default: + return errUnexpectedAttr{Val: tlv, Cmd: c.Type()} + } + if !ok { + return errUnexpectedAttrType{Val: tlv, Cmd: c.Type()} + } + } + return nil +} + +type MkfileCmd struct { + Path string + Ino uint64 +} + +func (c MkfileCmd) Type() CmdType { + return sendCmdMkfile +} +func (c *MkfileCmd) decode(tlvs []SendTLV) error { + for _, tlv := range tlvs { + var ok bool + switch tlv.Attr { + case sendAttrPath: + c.Path, ok = tlv.Val.(string) + case sendAttrIno: + c.Ino, ok = tlv.Val.(uint64) + default: + return errUnexpectedAttr{Val: tlv, Cmd: c.Type()} + } + if !ok { + return errUnexpectedAttrType{Val: tlv, Cmd: c.Type()} + } + } + return nil +} + +type WriteCmd struct { + Path string + Off uint64 + Data []byte +} + +func (c WriteCmd) Type() CmdType { + return sendCmdWrite +} +func (c *WriteCmd) decode(tlvs []SendTLV) error { + for _, tlv := range tlvs { + var ok bool + switch tlv.Attr { + case sendAttrPath: + c.Path, ok = tlv.Val.(string) + case sendAttrFileOffset: + c.Off, ok = tlv.Val.(uint64) + case sendAttrData: + c.Data, ok = tlv.Val.([]byte) + default: + return errUnexpectedAttr{Val: tlv, Cmd: c.Type()} + } + if !ok { + return errUnexpectedAttrType{Val: tlv, Cmd: c.Type()} + } + } + return nil +} + +type TruncateCmd struct { + Path string + Size uint64 +} + +func (c TruncateCmd) Type() CmdType { + return sendCmdTruncate +} +func (c *TruncateCmd) decode(tlvs []SendTLV) error { + for _, tlv := range tlvs { + var ok bool + switch tlv.Attr { + case sendAttrPath: + c.Path, ok = tlv.Val.(string) + case sendAttrSize: + c.Size, ok = tlv.Val.(uint64) + default: + return errUnexpectedAttr{Val: tlv, Cmd: c.Type()} + } + if !ok { + return errUnexpectedAttrType{Val: tlv, Cmd: c.Type()} + } + } + return nil +} diff --git a/send_h.go b/send/send_h.go similarity index 86% rename from send_h.go rename to send/send_h.go index 64b5058..63e3d32 100644 --- a/send_h.go +++ b/send/send_h.go @@ -1,4 +1,4 @@ -package btrfs +package send import ( "encoding/binary" @@ -19,29 +19,11 @@ const ( sendReadSize = 48 * 1024 ) -type tlvType uint16 - -const ( - tlvU8 = tlvType(iota) - tlvU16 - tlvU32 - tlvU64 - tlvBinary - tlvString - tlvUUID - tlvTimespec -) - -type streamHeader struct { - Magic [sendStreamMagicSize]byte - Version uint32 -} - const cmdHeaderSize = 10 type cmdHeader struct { Len uint32 // len excluding the header - Cmd sendCmd + Cmd CmdType Crc uint32 // crc including the header with zero crc field } @@ -51,7 +33,7 @@ func (h *cmdHeader) Unmarshal(p []byte) error { return io.ErrUnexpectedEOF } h.Len = sendEndianess.Uint32(p[0:]) - h.Cmd = sendCmd(sendEndianess.Uint16(p[4:])) + h.Cmd = CmdType(sendEndianess.Uint16(p[4:])) h.Crc = sendEndianess.Uint32(p[6:]) return nil } @@ -73,12 +55,12 @@ func (h *tlvHeader) Unmarshal(p []byte) error { return nil } -type sendCmd uint16 +type CmdType uint16 -func (c sendCmd) String() string { +func (c CmdType) String() string { var name string - if int(c) < len(sendCmdTypeNames) { - name = sendCmdTypeNames[int(c)] + if int(c) < len(cmdTypeNames) { + name = cmdTypeNames[int(c)] } if name != "" { return name @@ -86,7 +68,7 @@ func (c sendCmd) String() string { return strconv.FormatInt(int64(c), 16) } -var sendCmdTypeNames = []string{ +var cmdTypeNames = []string{ "", "subvol", @@ -121,7 +103,7 @@ var sendCmdTypeNames = []string{ } const ( - sendCmdUnspec = sendCmd(iota) + sendCmdUnspec = CmdType(iota) sendCmdSubvol sendCmdSnapshot