You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

203 lines
5.1 KiB

package proxyproto
import (
	"bufio"
	"bytes"
	"encoding/binary"
	"io"
)
// Sizes, in bytes, of the address block that follows the 16-bit length field
// of a version 2 header, together with their big-endian wire encodings.
var (
	lengthV4 = uint16(12) // two 4-byte addresses plus two 2-byte ports
	lengthV6 = uint16(36) // two 16-byte addresses plus two 2-byte ports
	// The PROXY protocol specification defines the AF_UNIX block as two
	// 108-byte paths and no ports, i.e. 216 bytes. The previous value of 218
	// made validateLength reject spec-compliant UNIX headers.
	lengthUnix = uint16(216)

	// lengthV4Bytes is lengthV4 encoded as a 2-byte big-endian value.
	lengthV4Bytes = func() []byte {
		a := make([]byte, 2)
		binary.BigEndian.PutUint16(a, lengthV4)
		return a
	}()
	// lengthV6Bytes is lengthV6 encoded as a 2-byte big-endian value.
	lengthV6Bytes = func() []byte {
		a := make([]byte, 2)
		binary.BigEndian.PutUint16(a, lengthV6)
		return a
	}()
	// lengthUnixBytes is lengthUnix encoded as a 2-byte big-endian value.
	lengthUnixBytes = func() []byte {
		a := make([]byte, 2)
		binary.BigEndian.PutUint16(a, lengthUnix)
		return a
	}()
)
// _ports holds a source/destination port pair in the order the fields are
// declared: source first, then destination.
type _ports struct {
	SrcPort, DstPort uint16
}
// _addr4 is the fixed-size IPv4 address block: source address, destination
// address, source port and destination port, 12 bytes in total. The field
// order is the order in which encoding/binary reads them.
type _addr4 struct {
	Src, Dst         [4]byte
	SrcPort, DstPort uint16
}
// _addr6 is the fixed-size IPv6 address block: source address, destination
// address, source port and destination port, 36 bytes in total.
//
// The port fields are required: the version 2 parser reads SrcPort and
// DstPort from this struct, and the 36-byte IPv6 length only adds up when the
// two 2-byte ports follow the two 16-byte addresses.
type _addr6 struct {
	Src     [16]byte
	Dst     [16]byte
	SrcPort uint16
	DstPort uint16
}
// _addrUnix is the fixed-size UNIX socket address block: a 108-byte source
// path followed by a 108-byte destination path.
type _addrUnix struct {
	Src, Dst [108]byte
}
func parseVersion2(reader *bufio.Reader) (header *Header, err error) {
// Skip first 12 bytes (signature)
for i := 0; i < 12; i++ {
if _, err = reader.ReadByte(); err != nil {
return nil, ErrCantReadProtocolVersionAndCommand
header = new(Header)
header.Version = 2
// Read the 13th byte, protocol version and command
b13, err := reader.ReadByte()
if err != nil {
return nil, ErrCantReadProtocolVersionAndCommand
header.Command = ProtocolVersionAndCommand(b13)
if _, ok := supportedCommand[header.Command]; !ok {
return nil, ErrUnsupportedProtocolVersionAndCommand
// If command is LOCAL, header ends here
if header.Command.IsLocal() {
return header, nil
// Read the 14th byte, address family and protocol
b14, err := reader.ReadByte()
if err != nil {
return nil, ErrCantReadAddressFamilyAndProtocol
header.TransportProtocol = AddressFamilyAndProtocol(b14)
if _, ok := supportedTransportProtocol[header.TransportProtocol]; !ok {
return nil, ErrUnsupportedAddressFamilyAndProtocol
// Make sure there are bytes available as specified in length
var length uint16
if err := binary.Read(io.LimitReader(reader, 2), binary.BigEndian, &length); err != nil {
return nil, ErrCantReadLength
if !header.validateLength(length) {
return nil, ErrInvalidLength
if _, err := reader.Peek(int(length)); err != nil {
return nil, ErrInvalidLength
// Length-limited reader for payload section
payloadReader := io.LimitReader(reader, int64(length))
// Read addresses and ports
if header.TransportProtocol.IsIPv4() {
var addr _addr4
if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
return nil, ErrInvalidAddress
header.SourceAddress = addr.Src[:]
header.DestinationAddress = addr.Dst[:]
header.SourcePort = addr.SrcPort
header.DestinationPort = addr.DstPort
} else if header.TransportProtocol.IsIPv6() {
var addr _addr6
if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
return nil, ErrInvalidAddress
header.SourceAddress = addr.Src[:]
header.DestinationAddress = addr.Dst[:]
header.SourcePort = addr.SrcPort
header.DestinationPort = addr.DstPort
// TODO fully support Unix addresses
// else if header.TransportProtocol.IsUnix() {
// var addr _addrUnix
// if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
// return nil, ErrInvalidAddress
// }
//if header.SourceAddress, err = net.ResolveUnixAddr("unix", string(addr.Src[:])); err != nil {
// return nil, ErrCantResolveSourceUnixAddress
//if header.DestinationAddress, err = net.ResolveUnixAddr("unix", string(addr.Dst[:])); err != nil {
// return nil, ErrCantResolveDestinationUnixAddress
// TODO add encapsulated TLV support
// Drain the remaining padding
payloadReader.Read(make([]byte, length))
return header, nil
func (header *Header) writeVersion2(w io.Writer) (int64, error) {
var buf bytes.Buffer
if !header.Command.IsLocal() {
// TODO add encapsulated TLV length
var addrSrc, addrDst []byte
if header.TransportProtocol.IsIPv4() {
addrSrc = header.SourceAddress.To4()
addrDst = header.DestinationAddress.To4()
} else if header.TransportProtocol.IsIPv6() {
addrSrc = header.SourceAddress.To16()
addrDst = header.DestinationAddress.To16()
} else if header.TransportProtocol.IsUnix() {
// TODO is below right?
addrSrc = []byte(header.SourceAddress.String())
addrDst = []byte(header.DestinationAddress.String())
portSrcBytes := func() []byte {
a := make([]byte, 2)
binary.BigEndian.PutUint16(a, header.SourcePort)
return a
portDstBytes := func() []byte {
a := make([]byte, 2)
binary.BigEndian.PutUint16(a, header.DestinationPort)
return a
return buf.WriteTo(w)
func (header *Header) validateLength(length uint16) bool {
if header.TransportProtocol.IsIPv4() {
return length >= lengthV4
} else if header.TransportProtocol.IsIPv6() {
return length >= lengthV6
} else if header.TransportProtocol.IsUnix() {
return length >= lengthUnix
return false