diff --git a/rvpn/packer/packer.go b/rvpn/packer/packer.go index 2c6f6db..a981273 100644 --- a/rvpn/packer/packer.go +++ b/rvpn/packer/packer.go @@ -5,11 +5,13 @@ import ( "fmt" "net" "strconv" + "strings" ) const ( - packerV1 byte = 255 - 1 - packerV2 byte = 255 - 2 + _ = iota // skip the iota value of 0 + packerV1 byte = 255 - iota + packerV2 ) //Packer -- contains both header and data @@ -26,121 +28,95 @@ func NewPacker() (p *Packer) { return } +func splitHeader(header []byte, names []string) (map[string]string, error) { + parts := strings.Split(string(header), ",") + if p, n := len(parts), len(names); p > n { + return nil, fmt.Errorf("Header contains %d extra fields", p-n) + } else if p < n { + return nil, fmt.Errorf("Header missing fields %q", names[p:]) + } + + result := make(map[string]string, len(names)) + for ind, key := range names { + result[key] = parts[ind] + } + return result, nil +} + //ReadMessage - -func ReadMessage(b []byte) (p *Packer, err error) { +func ReadMessage(b []byte) (*Packer, error) { fmt.Println("ReadMessage") - var pos int - err = nil - // detect protocol in use + // Detect protocol in use if b[0] == packerV1 { - p = NewPacker() + // Separate the header and body using the header length in the second byte. + p := NewPacker() + header := b[2 : b[1]+2] + data := b[b[1]+2:] - // Handle Header Length - pos = pos + 1 - p.Header.HeaderLen = b[pos] - - //handle address family - pos = pos + 1 - end := bytes.IndexAny(b[pos:], ",") - if end == -1 { - err = fmt.Errorf("missing , while parsing address family") + // Handle the different parts of the header. + parts, err := splitHeader(header, []string{"address family", "address", "port", "data length", "service"}) + if err != nil { return nil, err } - bAddrFamily := b[pos : pos+end] - if bytes.ContainsAny(bAddrFamily, addressFamilyText[FamilyIPv4]) { + if familyText := parts["address family"]; familyText == addressFamilyText[FamilyIPv4] { p.Header.family = FamilyIPv4 - } else if bytes.ContainsAny(bAddrFamily, addressFamilyText[FamilyIPv6]) { + } else if familyText == addressFamilyText[FamilyIPv6] { p.Header.family = FamilyIPv6 } else { - err = fmt.Errorf("Address family not supported %d", bAddrFamily) + return nil, fmt.Errorf("Address family %q not supported", familyText) } - //handle address - pos = pos + end + 1 - end = bytes.IndexAny(b[pos:], ",") - if end == -1 { - err = fmt.Errorf("missing , while parsing address") - return nil, err - } - p.Header.address = net.ParseIP(string(b[pos : pos+end])) - - //handle import - pos = pos + end + 1 - end = bytes.IndexAny(b[pos:], ",") - if end == -1 { - err = fmt.Errorf("missing , while parsing address") - return nil, err + p.Header.address = net.ParseIP(parts["address"]) + if p.Header.address == nil { + return nil, fmt.Errorf("Invalid network address %q", parts["address"]) + } else if p.Header.Family() == FamilyIPv4 && p.Header.address.To4() == nil { + return nil, fmt.Errorf("Address %q is not in address family %s", parts["address"], p.Header.FamilyText()) } - p.Header.Port, err = strconv.Atoi(string(b[pos : pos+end])) - if err != nil { - err = fmt.Errorf("error converting port %s", err) + //handle port + if port, err := strconv.Atoi(parts["port"]); err != nil { + return nil, fmt.Errorf("Error converting port %q: %v", parts["port"], err) + } else if port <= 0 || port > 65535 { + return nil, fmt.Errorf("Port %d out of range", port) + } else { + p.Header.Port = port } //handle data length - pos = pos + end + 1 - end = bytes.IndexAny(b[pos:], ",") - if end == -1 { - err = fmt.Errorf("missing , while parsing address") - return nil, err - } - - p.Data.DataLen, err = strconv.Atoi(string(b[pos : pos+end])) - if err != nil { - err = fmt.Errorf("error converting data length %s", err) + if dataLen, err := strconv.Atoi(parts["data length"]); err != nil { + return nil, fmt.Errorf("Error converting data length %q: %v", parts["data length"], err) + } else if dataLen != len(data) { + return nil, fmt.Errorf("Data length %d doesn't match received length %d", dataLen, len(data)) } //handle Service - pos = pos + end + 1 - end = pos + int(p.Header.HeaderLen) - p.Header.Service = string(b[pos : p.Header.HeaderLen+2]) + p.Header.Service = parts["service"] //handle payload - pos = int(p.Header.HeaderLen + 2) - p.Data.AppendBytes(b[pos:]) - - } else { - err = fmt.Errorf("Version %d not supported", b[0:0]) + p.Data.AppendBytes(data) + return p, nil } - return - + return nil, fmt.Errorf("Version %d not supported", 255-b[0]) } //PackV1 -- Outputs version 1 of packer -func (p *Packer) PackV1() (b bytes.Buffer) { - version := packerV1 - - var headerBuf bytes.Buffer - headerBuf.WriteString(p.Header.FamilyText()) - headerBuf.WriteString(",") - headerBuf.Write([]byte(p.Header.Address().String())) - headerBuf.WriteString(",") - headerBuf.WriteString(fmt.Sprintf("%d", p.Header.Port)) - headerBuf.WriteString(",") - headerBuf.WriteString(fmt.Sprintf("%d", p.Data.buffer.Len())) - headerBuf.WriteString(",") - headerBuf.WriteString(p.Header.Service) - - var metaBuf bytes.Buffer - metaBuf.WriteByte(version) - metaBuf.WriteByte(byte(headerBuf.Len())) +func (p *Packer) PackV1() bytes.Buffer { + header := strings.Join([]string{ + p.Header.FamilyText(), + p.Header.AddressString(), + strconv.Itoa(p.Header.Port), + strconv.Itoa(p.Data.DataLen()), + p.Header.Service, + }, ",") var buf bytes.Buffer - buf.Write(metaBuf.Bytes()) - buf.Write(headerBuf.Bytes()) - buf.Write(p.Data.buffer.Bytes()) + buf.WriteByte(packerV1) + buf.WriteByte(byte(len(header))) + buf.WriteString(header) + buf.Write(p.Data.Data()) - //fmt.Println("header: ", headerBuf.String()) - //fmt.Println("meta: ", metaBuf) - //fmt.Println("Data: ", p.Data.buffer) - //fmt.Println("Buffer: ", buf.Bytes()) - //fmt.Println("Buffer: ", hex.Dump(buf.Bytes())) - //fmt.Printf("Buffer %s", buf.Bytes()) - - b = buf - - return + return buf } diff --git a/rvpn/packer/packer_data.go b/rvpn/packer/packer_data.go index 10ad5c0..08775ad 100644 --- a/rvpn/packer/packer_data.go +++ b/rvpn/packer/packer_data.go @@ -6,14 +6,11 @@ import ( //packerData -- Contains packer data type packerData struct { - buffer *bytes.Buffer - DataLen int + buffer bytes.Buffer } -func newPackerData() (p *packerData) { - p = new(packerData) - p.buffer = new(bytes.Buffer) - return +func newPackerData() *packerData { + return new(packerData) } func (p *packerData) AppendString(dataString string) (int, error) { @@ -28,3 +25,7 @@ func (p *packerData) AppendBytes(dataBytes []byte) (int, error) { func (p *packerData) Data() []byte { return p.buffer.Bytes() } + +func (p *packerData) DataLen() int { + return p.buffer.Len() +} diff --git a/rvpn/packer/packer_header.go b/rvpn/packer/packer_header.go index 100fba2..be7cbb2 100644 --- a/rvpn/packer/packer_header.go +++ b/rvpn/packer/packer_header.go @@ -9,11 +9,10 @@ type addressFamily int // packerHeader structure to hold our header information. type packerHeader struct { - family addressFamily - address net.IP - Port int - Service string - HeaderLen byte + family addressFamily + address net.IP + Port int + Service string } //Family -- ENUM for Address Family @@ -32,7 +31,6 @@ func newPackerHeader() (p *packerHeader) { p.SetAddress("127.0.0.1") p.Port = 65535 p.Service = "na" - p.HeaderLen = 0 return }