Browse Source

p2p: API cleanup and PoC 7 compatibility

Whoa, one more big commit. I didn't manage to untangle the
changes while working towards compatibility.
Felix Lange 11 năm trước cách đây
mục cha
commit
59b63caf5e
17 tập tin đã thay đổi với 1665 bổ sung1902 xóa
  1. 3 3
      p2p/client_identity.go
  2. 53 9
      p2p/message.go
  3. 0 221
      p2p/messenger.go
  4. 0 203
      p2p/messenger_test.go
  5. 17 17
      p2p/natpmp.go
  6. 102 96
      p2p/natupnp.go
  7. 0 196
      p2p/network.go
  8. 432 44
      p2p/peer.go
  9. 100 50
      p2p/peer_error.go
  10. 0 98
      p2p/peer_error_handler.go
  11. 0 34
      p2p/peer_error_handler_test.go
  12. 220 88
      p2p/peer_test.go
  13. 194 218
      p2p/protocol.go
  14. 346 367
      p2p/server.go
  15. 130 258
      p2p/server_test.go
  16. 28 0
      p2p/testlog_test.go
  17. 40 0
      p2p/testpoc7.go

+ 3 - 3
p2p/client_identity.go

@@ -5,10 +5,10 @@ import (
 	"runtime"
 )
 
-// should be used in Peer handleHandshake, incorporate Caps, ProtocolVersion, Pubkey etc.
+// ClientIdentity represents the identity of a peer.
 type ClientIdentity interface {
-	String() string
-	Pubkey() []byte
+	String() string // human readable identity
+	Pubkey() []byte // 512-bit public key
 }
 
 type SimpleClientIdentity struct {

+ 53 - 9
p2p/message.go

@@ -11,8 +11,6 @@ import (
 	"github.com/ethereum/go-ethereum/ethutil"
 )
 
-type MsgCode uint64
-
 // Msg defines the structure of a p2p message.
 //
 // Note that a Msg can only be sent once since the Payload reader is
@@ -21,13 +19,13 @@ type MsgCode uint64
 // structure, encode the payload into a byte array and create a
 // separate Msg with a bytes.Reader as Payload for each send.
 type Msg struct {
-	Code    MsgCode
+	Code    uint64
 	Size    uint32 // size of the paylod
 	Payload io.Reader
 }
 
 // NewMsg creates an RLP-encoded message with the given code.
-func NewMsg(code MsgCode, params ...interface{}) Msg {
+func NewMsg(code uint64, params ...interface{}) Msg {
 	buf := new(bytes.Buffer)
 	for _, p := range params {
 		buf.Write(ethutil.Encode(p))
@@ -63,6 +61,52 @@ func (msg Msg) Discard() error {
 	return err
 }
 
+type MsgReader interface {
+	ReadMsg() (Msg, error)
+}
+
+type MsgWriter interface {
+	// WriteMsg sends an existing message.
+	// The Payload reader of the message is consumed.
+	// Note that messages can be sent only once.
+	WriteMsg(Msg) error
+
+	// EncodeMsg writes an RLP-encoded message with the given
+	// code and data elements.
+	EncodeMsg(code uint64, data ...interface{}) error
+}
+
+// MsgReadWriter provides reading and writing of encoded messages.
+type MsgReadWriter interface {
+	MsgReader
+	MsgWriter
+}
+
+// MsgLoop reads messages off the given reader and
+// calls the handler function for each decoded message until
+// it returns an error or the peer connection is closed.
+//
+// If a message is larger than the given maximum size,
+// MsgLoop returns an appropriate error.
+func MsgLoop(r MsgReader, maxsize uint32, f func(code uint64, data *ethutil.Value) error) error {
+	for {
+		msg, err := r.ReadMsg()
+		if err != nil {
+			return err
+		}
+		if msg.Size > maxsize {
+			return newPeerError(errInvalidMsg, "size %d exceeds maximum size of %d", msg.Size, maxsize)
+		}
+		value, err := msg.Data()
+		if err != nil {
+			return err
+		}
+		if err := f(msg.Code, value); err != nil {
+			return err
+		}
+	}
+}
+
 var magicToken = []byte{34, 64, 8, 145}
 
 func writeMsg(w io.Writer, msg Msg) error {
@@ -103,10 +147,10 @@ func readMsg(r byteReader) (msg Msg, err error) {
 	// read magic and payload size
 	start := make([]byte, 8)
 	if _, err = io.ReadFull(r, start); err != nil {
-		return msg, NewPeerError(ReadError, "%v", err)
+		return msg, newPeerError(errRead, "%v", err)
 	}
 	if !bytes.HasPrefix(start, magicToken) {
-		return msg, NewPeerError(MagicTokenMismatch, "got %x, want %x", start[:4], magicToken)
+		return msg, newPeerError(errMagicTokenMismatch, "got %x, want %x", start[:4], magicToken)
 	}
 	size := binary.BigEndian.Uint32(start[4:])
 
@@ -152,13 +196,13 @@ func readListHeader(r byteReader) (len uint64, hdrlen uint32, err error) {
 }
 
 // readUint reads an RLP-encoded unsigned integer from r.
-func readMsgCode(r byteReader) (code MsgCode, codelen uint32, err error) {
+func readMsgCode(r byteReader) (code uint64, codelen uint32, err error) {
 	b, err := r.ReadByte()
 	if err != nil {
 		return 0, 0, err
 	}
 	if b < 0x80 {
-		return MsgCode(b), 1, nil
+		return uint64(b), 1, nil
 	} else if b < 0x89 { // max length for uint64 is 8 bytes
 		codelen = uint32(b - 0x80)
 		if codelen == 0 {
@@ -168,7 +212,7 @@ func readMsgCode(r byteReader) (code MsgCode, codelen uint32, err error) {
 		if _, err := io.ReadFull(r, buf[8-codelen:]); err != nil {
 			return 0, 0, err
 		}
-		return MsgCode(binary.BigEndian.Uint64(buf)), codelen, nil
+		return binary.BigEndian.Uint64(buf), codelen, nil
 	}
 	return 0, 0, fmt.Errorf("bad RLP type for message code: %x", b)
 }

+ 0 - 221
p2p/messenger.go

@@ -1,221 +0,0 @@
-package p2p
-
-import (
-	"bufio"
-	"bytes"
-	"fmt"
-	"io"
-	"io/ioutil"
-	"net"
-	"sync"
-	"time"
-)
-
-type Handlers map[string]Protocol
-
-type proto struct {
-	in              chan Msg
-	maxcode, offset MsgCode
-	messenger       *messenger
-}
-
-func (rw *proto) WriteMsg(msg Msg) error {
-	if msg.Code >= rw.maxcode {
-		return NewPeerError(InvalidMsgCode, "not handled")
-	}
-	msg.Code += rw.offset
-	return rw.messenger.writeMsg(msg)
-}
-
-func (rw *proto) ReadMsg() (Msg, error) {
-	msg, ok := <-rw.in
-	if !ok {
-		return msg, io.EOF
-	}
-	msg.Code -= rw.offset
-	return msg, nil
-}
-
-// eofSignal wraps a reader with eof signaling.
-// the eof channel is closed when the wrapped reader
-// reaches EOF.
-type eofSignal struct {
-	wrapped io.Reader
-	eof     chan struct{}
-}
-
-func (r *eofSignal) Read(buf []byte) (int, error) {
-	n, err := r.wrapped.Read(buf)
-	if err != nil {
-		close(r.eof) // tell messenger that msg has been consumed
-	}
-	return n, err
-}
-
-// messenger represents a message-oriented peer connection.
-// It keeps track of the set of protocols understood
-// by the remote peer.
-type messenger struct {
-	peer     *Peer
-	handlers Handlers
-
-	// the mutex protects the connection
-	// so only one protocol can write at a time.
-	writeMu sync.Mutex
-	conn    net.Conn
-	bufconn *bufio.ReadWriter
-
-	protocolLock sync.RWMutex
-	protocols    map[string]*proto
-	offsets      map[MsgCode]*proto
-	protoWG      sync.WaitGroup
-
-	err   chan error
-	pulse chan bool
-}
-
-func newMessenger(peer *Peer, conn net.Conn, errchan chan error, handlers Handlers) *messenger {
-	return &messenger{
-		conn:      conn,
-		bufconn:   bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
-		peer:      peer,
-		handlers:  handlers,
-		protocols: make(map[string]*proto),
-		err:       errchan,
-		pulse:     make(chan bool, 1),
-	}
-}
-
-func (m *messenger) Start() {
-	m.protocols[""] = m.startProto(0, "", &baseProtocol{})
-	go m.readLoop()
-}
-
-func (m *messenger) Stop() {
-	m.conn.Close()
-	m.protoWG.Wait()
-}
-
-const (
-	// maximum amount of time allowed for reading a message
-	msgReadTimeout = 5 * time.Second
-
-	// messages smaller than this many bytes will be read at
-	// once before passing them to a protocol.
-	wholePayloadSize = 64 * 1024
-)
-
-func (m *messenger) readLoop() {
-	defer m.closeProtocols()
-	for {
-		m.conn.SetReadDeadline(time.Now().Add(msgReadTimeout))
-		msg, err := readMsg(m.bufconn)
-		if err != nil {
-			m.err <- err
-			return
-		}
-		// send ping to heartbeat channel signalling time of last message
-		m.pulse <- true
-		proto, err := m.getProto(msg.Code)
-		if err != nil {
-			m.err <- err
-			return
-		}
-		if msg.Size <= wholePayloadSize {
-			// optimization: msg is small enough, read all
-			// of it and move on to the next message
-			buf, err := ioutil.ReadAll(msg.Payload)
-			if err != nil {
-				m.err <- err
-				return
-			}
-			msg.Payload = bytes.NewReader(buf)
-			proto.in <- msg
-		} else {
-			pr := &eofSignal{msg.Payload, make(chan struct{})}
-			msg.Payload = pr
-			proto.in <- msg
-			<-pr.eof
-		}
-	}
-}
-
-func (m *messenger) closeProtocols() {
-	m.protocolLock.RLock()
-	for _, p := range m.protocols {
-		close(p.in)
-	}
-	m.protocolLock.RUnlock()
-}
-
-func (m *messenger) startProto(offset MsgCode, name string, impl Protocol) *proto {
-	proto := &proto{
-		in:        make(chan Msg),
-		offset:    offset,
-		maxcode:   impl.Offset(),
-		messenger: m,
-	}
-	m.protoWG.Add(1)
-	go func() {
-		if err := impl.Start(m.peer, proto); err != nil && err != io.EOF {
-			logger.Errorf("protocol %q error: %v\n", name, err)
-			m.err <- err
-		}
-		m.protoWG.Done()
-	}()
-	return proto
-}
-
-// getProto finds the protocol responsible for handling
-// the given message code.
-func (m *messenger) getProto(code MsgCode) (*proto, error) {
-	m.protocolLock.RLock()
-	defer m.protocolLock.RUnlock()
-	for _, proto := range m.protocols {
-		if code >= proto.offset && code < proto.offset+proto.maxcode {
-			return proto, nil
-		}
-	}
-	return nil, NewPeerError(InvalidMsgCode, "%d", code)
-}
-
-// setProtocols starts all subprotocols shared with the
-// remote peer. the protocols must be sorted alphabetically.
-func (m *messenger) setRemoteProtocols(protocols []string) {
-	m.protocolLock.Lock()
-	defer m.protocolLock.Unlock()
-	offset := baseProtocolOffset
-	for _, name := range protocols {
-		inst, ok := m.handlers[name]
-		if !ok {
-			continue // not handled
-		}
-		m.protocols[name] = m.startProto(offset, name, inst)
-		offset += inst.Offset()
-	}
-}
-
-// writeProtoMsg sends the given message on behalf of the given named protocol.
-func (m *messenger) writeProtoMsg(protoName string, msg Msg) error {
-	m.protocolLock.RLock()
-	proto, ok := m.protocols[protoName]
-	m.protocolLock.RUnlock()
-	if !ok {
-		return fmt.Errorf("protocol %s not handled by peer", protoName)
-	}
-	if msg.Code >= proto.maxcode {
-		return NewPeerError(InvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
-	}
-	msg.Code += proto.offset
-	return m.writeMsg(msg)
-}
-
-// writeMsg writes a message to the connection.
-func (m *messenger) writeMsg(msg Msg) error {
-	m.writeMu.Lock()
-	defer m.writeMu.Unlock()
-	if err := writeMsg(m.bufconn, msg); err != nil {
-		return err
-	}
-	return m.bufconn.Flush()
-}

+ 0 - 203
p2p/messenger_test.go

@@ -1,203 +0,0 @@
-package p2p
-
-import (
-	"bufio"
-	"fmt"
-	"io"
-	"log"
-	"net"
-	"os"
-	"reflect"
-	"testing"
-	"time"
-
-	logpkg "github.com/ethereum/go-ethereum/logger"
-)
-
-func init() {
-	logpkg.AddLogSystem(logpkg.NewStdLogSystem(os.Stdout, log.LstdFlags, logpkg.DebugLevel))
-}
-
-func testMessenger(handlers Handlers) (net.Conn, *Peer, *messenger) {
-	conn1, conn2 := net.Pipe()
-	id := NewSimpleClientIdentity("test", "0", "0", "public key")
-	server := New(nil, conn1.LocalAddr(), id, handlers, 10, NewBlacklist())
-	peer := server.addPeer(conn1, conn1.RemoteAddr(), true, 0)
-	return conn2, peer, peer.messenger
-}
-
-func performTestHandshake(r *bufio.Reader, w io.Writer) error {
-	// read remote handshake
-	msg, err := readMsg(r)
-	if err != nil {
-		return fmt.Errorf("read error: %v", err)
-	}
-	if msg.Code != handshakeMsg {
-		return fmt.Errorf("first message should be handshake, got %d", msg.Code)
-	}
-	if err := msg.Discard(); err != nil {
-		return err
-	}
-	// send empty handshake
-	pubkey := make([]byte, 64)
-	msg = NewMsg(handshakeMsg, p2pVersion, "testid", nil, 9999, pubkey)
-	return writeMsg(w, msg)
-}
-
-type testProtocol struct {
-	offset MsgCode
-	f      func(MsgReadWriter)
-}
-
-func (p *testProtocol) Offset() MsgCode {
-	return p.offset
-}
-
-func (p *testProtocol) Start(peer *Peer, rw MsgReadWriter) error {
-	p.f(rw)
-	return nil
-}
-
-func TestRead(t *testing.T) {
-	done := make(chan struct{})
-	handlers := Handlers{
-		"a": &testProtocol{5, func(rw MsgReadWriter) {
-			msg, err := rw.ReadMsg()
-			if err != nil {
-				t.Errorf("read error: %v", err)
-			}
-			if msg.Code != 2 {
-				t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
-			}
-			data, err := msg.Data()
-			if err != nil {
-				t.Errorf("data decoding error: %v", err)
-			}
-			expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}}
-			if !reflect.DeepEqual(data.Slice(), expdata) {
-				t.Errorf("incorrect msg data %#v", data.Slice())
-			}
-			close(done)
-		}},
-	}
-
-	net, peer, m := testMessenger(handlers)
-	defer peer.Stop()
-	bufr := bufio.NewReader(net)
-	if err := performTestHandshake(bufr, net); err != nil {
-		t.Fatalf("handshake failed: %v", err)
-	}
-	m.setRemoteProtocols([]string{"a"})
-
-	writeMsg(net, NewMsg(18, 1, "000"))
-	select {
-	case <-done:
-	case <-time.After(2 * time.Second):
-		t.Errorf("receive timeout")
-	}
-}
-
-func TestWriteFromProto(t *testing.T) {
-	handlers := Handlers{
-		"a": &testProtocol{2, func(rw MsgReadWriter) {
-			if err := rw.WriteMsg(NewMsg(2)); err == nil {
-				t.Error("expected error for out-of-range msg code, got nil")
-			}
-			if err := rw.WriteMsg(NewMsg(1)); err != nil {
-				t.Errorf("write error: %v", err)
-			}
-		}},
-	}
-	net, peer, mess := testMessenger(handlers)
-	defer peer.Stop()
-	bufr := bufio.NewReader(net)
-	if err := performTestHandshake(bufr, net); err != nil {
-		t.Fatalf("handshake failed: %v", err)
-	}
-	mess.setRemoteProtocols([]string{"a"})
-
-	msg, err := readMsg(bufr)
-	if err != nil {
-		t.Errorf("read error: %v")
-	}
-	if msg.Code != 17 {
-		t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
-	}
-}
-
-var discardProto = &testProtocol{1, func(rw MsgReadWriter) {
-	for {
-		msg, err := rw.ReadMsg()
-		if err != nil {
-			return
-		}
-		if err = msg.Discard(); err != nil {
-			return
-		}
-	}
-}}
-
-func TestMessengerWriteProtoMsg(t *testing.T) {
-	handlers := Handlers{"a": discardProto}
-	net, peer, mess := testMessenger(handlers)
-	defer peer.Stop()
-	bufr := bufio.NewReader(net)
-	if err := performTestHandshake(bufr, net); err != nil {
-		t.Fatalf("handshake failed: %v", err)
-	}
-	mess.setRemoteProtocols([]string{"a"})
-
-	// test write errors
-	if err := mess.writeProtoMsg("b", NewMsg(3)); err == nil {
-		t.Errorf("expected error for unknown protocol, got nil")
-	}
-	if err := mess.writeProtoMsg("a", NewMsg(8)); err == nil {
-		t.Errorf("expected error for out-of-range msg code, got nil")
-	} else if perr, ok := err.(*PeerError); !ok || perr.Code != InvalidMsgCode {
-		t.Errorf("wrong error for out-of-range msg code, got %#v")
-	}
-
-	// test succcessful write
-	read, readerr := make(chan Msg), make(chan error)
-	go func() {
-		if msg, err := readMsg(bufr); err != nil {
-			readerr <- err
-		} else {
-			read <- msg
-		}
-	}()
-	if err := mess.writeProtoMsg("a", NewMsg(0)); err != nil {
-		t.Errorf("expect no error for known protocol: %v", err)
-	}
-	select {
-	case msg := <-read:
-		if msg.Code != 16 {
-			t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
-		}
-		msg.Discard()
-	case err := <-readerr:
-		t.Errorf("read error: %v", err)
-	}
-}
-
-func TestPulse(t *testing.T) {
-	net, peer, _ := testMessenger(nil)
-	defer peer.Stop()
-	bufr := bufio.NewReader(net)
-	if err := performTestHandshake(bufr, net); err != nil {
-		t.Fatalf("handshake failed: %v", err)
-	}
-
-	before := time.Now()
-	msg, err := readMsg(bufr)
-	if err != nil {
-		t.Fatalf("read error: %v", err)
-	}
-	after := time.Now()
-	if msg.Code != pingMsg {
-		t.Errorf("expected ping message, got %d", msg.Code)
-	}
-	if d := after.Sub(before); d < pingTimeout {
-		t.Errorf("ping sent too early after %v, expected at least %v", d, pingTimeout)
-	}
-}

+ 17 - 17
p2p/natpmp.go

@@ -3,6 +3,7 @@ package p2p
 import (
 	"fmt"
 	"net"
+	"time"
 
 	natpmp "github.com/jackpal/go-nat-pmp"
 )
@@ -13,38 +14,37 @@ import (
 //  + Register for changes to the external address.
 //  + Re-register port mapping when router reboots.
 //  + A mechanism for keeping a port mapping registered.
+//  + Discover gateway address automatically.
 
 type natPMPClient struct {
 	client *natpmp.Client
 }
 
-func NewNatPMP(gateway net.IP) (nat NAT) {
+// PMP returns a NAT traverser that uses NAT-PMP. The provided gateway
+// address should be the IP of your router.
+func PMP(gateway net.IP) (nat NAT) {
 	return &natPMPClient{natpmp.NewClient(gateway)}
 }
 
-func (n *natPMPClient) GetExternalAddress() (addr net.IP, err error) {
+func (*natPMPClient) String() string {
+	return "NAT-PMP"
+}
+
+func (n *natPMPClient) GetExternalAddress() (net.IP, error) {
 	response, err := n.client.GetExternalAddress()
 	if err != nil {
-		return
+		return nil, err
 	}
-	ip := response.ExternalIPAddress
-	addr = net.IPv4(ip[0], ip[1], ip[2], ip[3])
-	return
+	return response.ExternalIPAddress[:], nil
 }
 
-func (n *natPMPClient) AddPortMapping(protocol string, externalPort, internalPort int,
-	description string, timeout int) (mappedExternalPort int, err error) {
-	if timeout <= 0 {
-		err = fmt.Errorf("timeout must not be <= 0")
-		return
+func (n *natPMPClient) AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error {
+	if lifetime <= 0 {
+		return fmt.Errorf("lifetime must not be <= 0")
 	}
 	// Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping.
-	response, err := n.client.AddPortMapping(protocol, internalPort, externalPort, timeout)
-	if err != nil {
-		return
-	}
-	mappedExternalPort = int(response.MappedExternalPort)
-	return
+	_, err := n.client.AddPortMapping(protocol, intport, extport, int(lifetime/time.Second))
+	return err
 }
 
 func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {

+ 102 - 96
p2p/natupnp.go

@@ -7,6 +7,7 @@ import (
 	"bytes"
 	"encoding/xml"
 	"errors"
+	"fmt"
 	"net"
 	"net/http"
 	"os"
@@ -15,28 +16,46 @@ import (
 	"time"
 )
 
+const (
+	upnpDiscoverAttempts = 3
+	upnpDiscoverTimeout  = 5 * time.Second
+)
+
+// UPNP returns a NAT port mapper that uses UPnP. It will attempt to
+// discover the address of your router using UDP broadcasts.
+func UPNP() NAT {
+	return &upnpNAT{}
+}
+
 type upnpNAT struct {
 	serviceURL string
 	ourIP      string
 }
 
-func upnpDiscover(attempts int) (nat NAT, err error) {
+func (n *upnpNAT) String() string {
+	return "UPNP"
+}
+
+func (n *upnpNAT) discover() error {
+	if n.serviceURL != "" {
+		// already discovered
+		return nil
+	}
+
 	ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900")
 	if err != nil {
-		return
+		return err
 	}
+	// TODO: try on all network interfaces simultaneously.
+	// Broadcasting on 0.0.0.0 could select a random interface
+	// to send on (platform specific).
 	conn, err := net.ListenPacket("udp4", ":0")
 	if err != nil {
-		return
-	}
-	socket := conn.(*net.UDPConn)
-	defer socket.Close()
-
-	err = socket.SetDeadline(time.Now().Add(10 * time.Second))
-	if err != nil {
-		return
+		return err
 	}
+	defer conn.Close()
 
+	conn.SetDeadline(time.Now().Add(10 * time.Second))
 	st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n"
 	buf := bytes.NewBufferString(
 		"M-SEARCH * HTTP/1.1\r\n" +
@@ -46,19 +65,16 @@ func upnpDiscover(attempts int) (nat NAT, err error) {
 			"MX: 2\r\n\r\n")
 	message := buf.Bytes()
 	answerBytes := make([]byte, 1024)
-	for i := 0; i < attempts; i++ {
-		_, err = socket.WriteToUDP(message, ssdp)
+	for i := 0; i < upnpDiscoverAttempts; i++ {
+		_, err = conn.WriteTo(message, ssdp)
 		if err != nil {
-			return
+			return err
 		}
-		var n int
-		n, _, err = socket.ReadFromUDP(answerBytes)
+		nn, _, err := conn.ReadFrom(answerBytes)
 		if err != nil {
 			continue
-			// socket.Close()
-			// return
 		}
-		answer := string(answerBytes[0:n])
+		answer := string(answerBytes[0:nn])
 		if strings.Index(answer, "\r\n"+st) < 0 {
 			continue
 		}
@@ -79,17 +95,81 @@ func upnpDiscover(attempts int) (nat NAT, err error) {
 		var serviceURL string
 		serviceURL, err = getServiceURL(locURL)
 		if err != nil {
-			return
+			return err
 		}
 		var ourIP string
 		ourIP, err = getOurIP()
 		if err != nil {
-			return
+			return err
 		}
-		nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP}
+		n.serviceURL = serviceURL
+		n.ourIP = ourIP
+		return nil
+	}
+	return errors.New("UPnP port discovery failed.")
+}
+
+func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
+	if err := n.discover(); err != nil {
+		return nil, err
+	}
+	info, err := n.getStatusInfo()
+	return net.ParseIP(info.externalIpAddress), err
+}
+
+func (n *upnpNAT) AddPortMapping(protocol string, extport, intport int, description string, lifetime time.Duration) error {
+	if err := n.discover(); err != nil {
+		return err
+	}
+
+	// A single concatenation would break ARM compilation.
+	message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
+		"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(extport)
+	message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>"
+	message += "<NewInternalPort>" + strconv.Itoa(extport) + "</NewInternalPort>" +
+		"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" +
+		"<NewEnabled>1</NewEnabled><NewPortMappingDescription>"
+	message += description +
+		"</NewPortMappingDescription><NewLeaseDuration>" + fmt.Sprint(lifetime/time.Second) +
+		"</NewLeaseDuration></u:AddPortMapping>"
+
+	// TODO: check response to see if the port was forwarded
+	_, err := soapRequest(n.serviceURL, "AddPortMapping", message)
+	return err
+}
+
+func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) error {
+	if err := n.discover(); err != nil {
+		return err
+	}
+
+	message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
+		"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) +
+		"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" +
+		"</u:DeletePortMapping>"
+
+	// TODO: check response to see if the port was deleted
+	_, err := soapRequest(n.serviceURL, "DeletePortMapping", message)
+	return err
+}
+
+type statusInfo struct {
+	externalIpAddress string
+}
+
+func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) {
+	message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
+		"</u:GetStatusInfo>"
+
+	var response *http.Response
+	response, err = soapRequest(n.serviceURL, "GetStatusInfo", message)
+	if err != nil {
 		return
 	}
-	err = errors.New("UPnP port discovery failed.")
+
+	// TODO: Write a soap reply parser. It has to eat the Body and envelope tags...
+
+	response.Body.Close()
 	return
 }
 
@@ -259,77 +339,3 @@ func soapRequest(url, function, message string) (r *http.Response, err error) {
 	}
 	return
 }
-
-type statusInfo struct {
-	externalIpAddress string
-}
-
-func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) {
-
-	message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
-		"</u:GetStatusInfo>"
-
-	var response *http.Response
-	response, err = soapRequest(n.serviceURL, "GetStatusInfo", message)
-	if err != nil {
-		return
-	}
-
-	// TODO: Write a soap reply parser. It has to eat the Body and envelope tags...
-
-	response.Body.Close()
-	return
-}
-
-func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
-	info, err := n.getStatusInfo()
-	if err != nil {
-		return
-	}
-	addr = net.ParseIP(info.externalIpAddress)
-	return
-}
-
-func (n *upnpNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) {
-	// A single concatenation would break ARM compilation.
-	message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
-		"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort)
-	message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>"
-	message += "<NewInternalPort>" + strconv.Itoa(internalPort) + "</NewInternalPort>" +
-		"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" +
-		"<NewEnabled>1</NewEnabled><NewPortMappingDescription>"
-	message += description +
-		"</NewPortMappingDescription><NewLeaseDuration>" + strconv.Itoa(timeout) +
-		"</NewLeaseDuration></u:AddPortMapping>"
-
-	var response *http.Response
-	response, err = soapRequest(n.serviceURL, "AddPortMapping", message)
-	if err != nil {
-		return
-	}
-
-	// TODO: check response to see if the port was forwarded
-	// log.Println(message, response)
-	mappedExternalPort = externalPort
-	_ = response
-	return
-}
-
-func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
-
-	message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
-		"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) +
-		"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" +
-		"</u:DeletePortMapping>"
-
-	var response *http.Response
-	response, err = soapRequest(n.serviceURL, "DeletePortMapping", message)
-	if err != nil {
-		return
-	}
-
-	// TODO: check response to see if the port was deleted
-	// log.Println(message, response)
-	_ = response
-	return
-}

+ 0 - 196
p2p/network.go

@@ -1,196 +0,0 @@
-package p2p
-
-import (
-	"fmt"
-	"math/rand"
-	"net"
-	"strconv"
-	"time"
-)
-
-const (
-	DialerTimeout             = 180 //seconds
-	KeepAlivePeriod           = 60  //minutes
-	portMappingUpdateInterval = 900 // seconds = 15 mins
-	upnpDiscoverAttempts      = 3
-)
-
-// Dialer is not an interface in net, so we define one
-// *net.Dialer conforms to this
-type Dialer interface {
-	Dial(network, address string) (net.Conn, error)
-}
-
-type Network interface {
-	Start() error
-	Listener(net.Addr) (net.Listener, error)
-	Dialer(net.Addr) (Dialer, error)
-	NewAddr(string, int) (addr net.Addr, err error)
-	ParseAddr(string) (addr net.Addr, err error)
-}
-
-type NAT interface {
-	GetExternalAddress() (addr net.IP, err error)
-	AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error)
-	DeletePortMapping(protocol string, externalPort, internalPort int) (err error)
-}
-
-type TCPNetwork struct {
-	nat     NAT
-	natType NATType
-	quit    chan chan bool
-	ports   chan string
-}
-
-type NATType int
-
-const (
-	NONE = iota
-	UPNP
-	PMP
-)
-
-const (
-	portMappingTimeout = 1200 // 20 mins
-)
-
-func NewTCPNetwork(natType NATType) (net *TCPNetwork) {
-	return &TCPNetwork{
-		natType: natType,
-		ports:   make(chan string),
-	}
-}
-
-func (self *TCPNetwork) Dialer(addr net.Addr) (Dialer, error) {
-	return &net.Dialer{
-		Timeout: DialerTimeout * time.Second,
-		// KeepAlive: KeepAlivePeriod * time.Minute,
-		LocalAddr: addr,
-	}, nil
-}
-
-func (self *TCPNetwork) Listener(addr net.Addr) (net.Listener, error) {
-	if self.natType == UPNP {
-		_, port, _ := net.SplitHostPort(addr.String())
-		if self.quit == nil {
-			self.quit = make(chan chan bool)
-			go self.updatePortMappings()
-		}
-		self.ports <- port
-	}
-	return net.Listen(addr.Network(), addr.String())
-}
-
-func (self *TCPNetwork) Start() (err error) {
-	switch self.natType {
-	case NONE:
-	case UPNP:
-		nat, uerr := upnpDiscover(upnpDiscoverAttempts)
-		if uerr != nil {
-			err = fmt.Errorf("UPNP failed: ", uerr)
-		} else {
-			self.nat = nat
-		}
-	case PMP:
-		err = fmt.Errorf("PMP not implemented")
-	default:
-		err = fmt.Errorf("Invalid NAT type: %v", self.natType)
-	}
-	return
-}
-
-func (self *TCPNetwork) Stop() {
-	q := make(chan bool)
-	self.quit <- q
-	<-q
-}
-
-func (self *TCPNetwork) addPortMapping(lport int) (err error) {
-	_, err = self.nat.AddPortMapping("TCP", lport, lport, "p2p listen port", portMappingTimeout)
-	if err != nil {
-		logger.Errorf("unable to add port mapping on %v: %v", lport, err)
-	} else {
-		logger.Debugf("succesfully added port mapping on %v", lport)
-	}
-	return
-}
-
-func (self *TCPNetwork) updatePortMappings() {
-	timer := time.NewTimer(portMappingUpdateInterval * time.Second)
-	lports := []int{}
-out:
-	for {
-		select {
-		case port := <-self.ports:
-			int64lport, _ := strconv.ParseInt(port, 10, 16)
-			lport := int(int64lport)
-			if err := self.addPortMapping(lport); err != nil {
-				lports = append(lports, lport)
-			}
-		case <-timer.C:
-			for lport := range lports {
-				if err := self.addPortMapping(lport); err != nil {
-				}
-			}
-		case errc := <-self.quit:
-			errc <- true
-			break out
-		}
-	}
-
-	timer.Stop()
-	for lport := range lports {
-		if err := self.nat.DeletePortMapping("TCP", lport, lport); err != nil {
-			logger.Debugf("unable to remove port mapping on %v: %v", lport, err)
-		} else {
-			logger.Debugf("succesfully removed port mapping on %v", lport)
-		}
-	}
-}
-
-func (self *TCPNetwork) NewAddr(host string, port int) (net.Addr, error) {
-	ip, err := self.lookupIP(host)
-	if err == nil {
-		return &net.TCPAddr{
-			IP:   ip,
-			Port: port,
-		}, nil
-	}
-	return nil, err
-}
-
-func (self *TCPNetwork) ParseAddr(address string) (net.Addr, error) {
-	host, port, err := net.SplitHostPort(address)
-	if err == nil {
-		iport, _ := strconv.Atoi(port)
-		addr, e := self.NewAddr(host, iport)
-		return addr, e
-	}
-	return nil, err
-}
-
-func (*TCPNetwork) lookupIP(host string) (ip net.IP, err error) {
-	if ip = net.ParseIP(host); ip != nil {
-		return
-	}
-
-	var ips []net.IP
-	ips, err = net.LookupIP(host)
-	if err != nil {
-		logger.Warnln(err)
-		return
-	}
-	if len(ips) == 0 {
-		err = fmt.Errorf("No IP addresses available for %v", host)
-		logger.Warnln(err)
-		return
-	}
-	if len(ips) > 1 {
-		// Pick a random IP address, simulating round-robin DNS.
-		rand.Seed(time.Now().UTC().UnixNano())
-		ip = ips[rand.Intn(len(ips))]
-	} else {
-		ip = ips[0]
-	}
-	return
-}

+ 432 - 44
p2p/peer.go

@@ -1,66 +1,454 @@
 package p2p
 
 import (
+	"bufio"
+	"bytes"
 	"fmt"
+	"io"
+	"io/ioutil"
 	"net"
-	"strconv"
+	"sort"
+	"sync"
+	"time"
+
+	"github.com/ethereum/go-ethereum/event"
+	"github.com/ethereum/go-ethereum/logger"
 )
 
+// peerAddr is the structure of a peer list element.
+// It is also a valid net.Addr.
+type peerAddr struct {
+	IP     net.IP
+	Port   uint64
+	Pubkey []byte // optional
+}
+
+func newPeerAddr(addr net.Addr, pubkey []byte) *peerAddr {
+	n := addr.Network()
+	if n != "tcp" && n != "tcp4" && n != "tcp6" {
+		// for testing with non-TCP
+		return &peerAddr{net.ParseIP("127.0.0.1"), 30303, pubkey}
+	}
+	ta := addr.(*net.TCPAddr)
+	return &peerAddr{ta.IP, uint64(ta.Port), pubkey}
+}
+
+func (d peerAddr) Network() string {
+	if d.IP.To4() != nil {
+		return "tcp4"
+	} else {
+		return "tcp6"
+	}
+}
+
+func (d peerAddr) String() string {
+	return fmt.Sprintf("%v:%d", d.IP, d.Port)
+}
+
+func (d peerAddr) RlpData() interface{} {
+	return []interface{}{d.IP, d.Port, d.Pubkey}
+}
+
+// Peer represents a remote peer.
 type Peer struct {
-	Inbound          bool // inbound (via listener) or outbound (via dialout)
-	Address          net.Addr
-	Host             []byte
-	Port             uint16
-	Pubkey           []byte
-	Id               string
-	Caps             []string
-	peerErrorChan    chan error
-	messenger        *messenger
-	peerErrorHandler *PeerErrorHandler
-	server           *Server
-}
-
-func NewPeer(conn net.Conn, address net.Addr, inbound bool, server *Server) *Peer {
-	peerErrorChan := NewPeerErrorChannel()
-	host, port, _ := net.SplitHostPort(address.String())
-	intport, _ := strconv.Atoi(port)
-	peer := &Peer{
-		Inbound:       inbound,
-		Address:       address,
-		Port:          uint16(intport),
-		Host:          net.ParseIP(host),
-		peerErrorChan: peerErrorChan,
-		server:        server,
-	}
-	peer.messenger = newMessenger(peer, conn, peerErrorChan, server.Handlers())
-	peer.peerErrorHandler = NewPeerErrorHandler(address, server.PeerDisconnect(), peerErrorChan)
+	// Peers have all the log methods.
+	// Use them to display messages related to the peer.
+	*logger.Logger
+
+	infolock   sync.Mutex
+	identity   ClientIdentity
+	caps       []Cap
+	listenAddr *peerAddr // what remote peer is listening on
+	dialAddr   *peerAddr // non-nil if dialing
+
+	// The mutex protects the connection
+	// so only one protocol can write at a time.
+	writeMu sync.Mutex
+	conn    net.Conn
+	bufconn *bufio.ReadWriter
+
+	// These fields maintain the running protocols.
+	protocols       []Protocol
+	runBaseProtocol bool // for testing
+
+	runlock sync.RWMutex // protects running
+	running map[string]*proto
+
+	protoWG  sync.WaitGroup
+	protoErr chan error
+	closed   chan struct{}
+	disc     chan DiscReason
+
+	activity event.TypeMux // for activity events
+
+	slot int // index into Server peer list
+
+	// These fields are kept so base protocol can access them.
+	// TODO: this should be one or more interfaces
+	ourID         ClientIdentity        // client id of the Server
+	ourListenAddr *peerAddr             // listen addr of Server, nil if not listening
+	newPeerAddr   chan<- *peerAddr      // tell server about received peers
+	otherPeers    func() []*Peer        // should return the list of all peers
+	pubkeyHook    func(*peerAddr) error // called at end of handshake to validate pubkey
+}
+
+// NewPeer returns a peer for testing purposes.
+func NewPeer(id ClientIdentity, caps []Cap) *Peer {
+	conn, _ := net.Pipe()
+	peer := newPeer(conn, nil, nil)
+	peer.setHandshakeInfo(id, nil, caps)
 	return peer
 }
 
-func (self *Peer) String() string {
-	var kind string
-	if self.Inbound {
-		kind = "inbound"
-	} else {
+func newServerPeer(server *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
+	p := newPeer(conn, server.Protocols, dialAddr)
+	p.ourID = server.Identity
+	p.newPeerAddr = server.peerConnect
+	p.otherPeers = server.Peers
+	p.pubkeyHook = server.verifyPeer
+	p.runBaseProtocol = true
+
+	// laddr can be updated concurrently by NAT traversal.
+	// newServerPeer must be called with the server lock held.
+	if server.laddr != nil {
+		p.ourListenAddr = newPeerAddr(server.laddr, server.Identity.Pubkey())
+	}
+	return p
+}
+
+func newPeer(conn net.Conn, protocols []Protocol, dialAddr *peerAddr) *Peer {
+	p := &Peer{
+		Logger:    logger.NewLogger("P2P " + conn.RemoteAddr().String()),
+		conn:      conn,
+		dialAddr:  dialAddr,
+		bufconn:   bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
+		protocols: protocols,
+		running:   make(map[string]*proto),
+		disc:      make(chan DiscReason),
+		protoErr:  make(chan error),
+		closed:    make(chan struct{}),
+	}
+	return p
+}
+
+// Identity returns the client identity of the remote peer. The
+// identity can be nil if the peer has not yet completed the
+// handshake.
+func (p *Peer) Identity() ClientIdentity {
+	p.infolock.Lock()
+	defer p.infolock.Unlock()
+	return p.identity
+}
+
+// Caps returns the capabilities (supported subprotocols) of the remote peer.
+func (p *Peer) Caps() []Cap {
+	p.infolock.Lock()
+	defer p.infolock.Unlock()
+	return p.caps
+}
+
+func (p *Peer) setHandshakeInfo(id ClientIdentity, laddr *peerAddr, caps []Cap) {
+	p.infolock.Lock()
+	p.identity = id
+	p.listenAddr = laddr
+	p.caps = caps
+	p.infolock.Unlock()
+}
+
+// RemoteAddr returns the remote address of the network connection.
+func (p *Peer) RemoteAddr() net.Addr {
+	return p.conn.RemoteAddr()
+}
+
+// LocalAddr returns the local address of the network connection.
+func (p *Peer) LocalAddr() net.Addr {
+	return p.conn.LocalAddr()
+}
+
+// Disconnect terminates the peer connection with the given reason.
+// It returns immediately and does not wait until the connection is closed.
+func (p *Peer) Disconnect(reason DiscReason) {
+	select {
+	case p.disc <- reason:
+	case <-p.closed:
+	}
+}
+
+// String implements fmt.Stringer.
+func (p *Peer) String() string {
+	kind := "inbound"
+	p.infolock.Lock()
+	if p.dialAddr != nil {
 		kind = "outbound"
 	}
-	return fmt.Sprintf("%v:%v (%s) v%v %v", self.Host, self.Port, kind, self.Id, self.Caps)
+	p.infolock.Unlock()
+	return fmt.Sprintf("Peer(%p %v %s)", p, p.conn.RemoteAddr(), kind)
+}
+
+const (
+	// maximum amount of time allowed for reading a message
+	msgReadTimeout = 5 * time.Second
+	// maximum amount of time allowed for writing a message
+	msgWriteTimeout = 5 * time.Second
+	// messages smaller than this many bytes will be read at
+	// once before passing them to a protocol.
+	wholePayloadSize = 64 * 1024
+)
+
+var (
+	inactivityTimeout     = 2 * time.Second
+	disconnectGracePeriod = 2 * time.Second
+)
+
+func (p *Peer) loop() (reason DiscReason, err error) {
+	defer p.activity.Stop()
+	defer p.closeProtocols()
+	defer close(p.closed)
+	defer p.conn.Close()
+
+	// read loop
+	readMsg := make(chan Msg)
+	readErr := make(chan error)
+	readNext := make(chan bool, 1)
+	protoDone := make(chan struct{}, 1)
+	go p.readLoop(readMsg, readErr, readNext)
+	readNext <- true
+
+	if p.runBaseProtocol {
+		p.startBaseProtocol()
+	}
+
+loop:
+	for {
+		select {
+		case msg := <-readMsg:
+			// a new message has arrived.
+			var wait bool
+			if wait, err = p.dispatch(msg, protoDone); err != nil {
+				p.Errorf("msg dispatch error: %v\n", err)
+				reason = discReasonForError(err)
+				break loop
+			}
+			if !wait {
+				// Msg has already been read completely, continue with next message.
+				readNext <- true
+			}
+			p.activity.Post(time.Now())
+		case <-protoDone:
+			// protocol has consumed the message payload,
+			// we can continue reading from the socket.
+			readNext <- true
+
+		case err := <-readErr:
+			// read failed. there is no need to run the
+			// polite disconnect sequence because the connection
+			// is probably dead anyway.
+			// TODO: handle write errors as well
+			return DiscNetworkError, err
+		case err = <-p.protoErr:
+			reason = discReasonForError(err)
+			break loop
+		case reason = <-p.disc:
+			break loop
+		}
+	}
+
+	// wait for read loop to return.
+	close(readNext)
+	<-readErr
+	// tell the remote end to disconnect
+	done := make(chan struct{})
+	go func() {
+		p.conn.SetDeadline(time.Now().Add(disconnectGracePeriod))
+		p.writeMsg(NewMsg(discMsg, reason), disconnectGracePeriod)
+		io.Copy(ioutil.Discard, p.conn)
+		close(done)
+	}()
+	select {
+	case <-done:
+	case <-time.After(disconnectGracePeriod):
+	}
+	return reason, err
+}
+
+func (p *Peer) readLoop(msgc chan<- Msg, errc chan<- error, unblock <-chan bool) {
+	for _ = range unblock {
+		p.conn.SetReadDeadline(time.Now().Add(msgReadTimeout))
+		if msg, err := readMsg(p.bufconn); err != nil {
+			errc <- err
+		} else {
+			msgc <- msg
+		}
+	}
+	close(errc)
+}
+
+func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error) {
+	proto, err := p.getProto(msg.Code)
+	if err != nil {
+		return false, err
+	}
+	if msg.Size <= wholePayloadSize {
+		// optimization: msg is small enough, read all
+		// of it and move on to the next message
+		buf, err := ioutil.ReadAll(msg.Payload)
+		if err != nil {
+			return false, err
+		}
+		msg.Payload = bytes.NewReader(buf)
+		proto.in <- msg
+	} else {
+		wait = true
+		pr := &eofSignal{msg.Payload, protoDone}
+		msg.Payload = pr
+		proto.in <- msg
+	}
+	return wait, nil
+}
+
+func (p *Peer) startBaseProtocol() {
+	p.runlock.Lock()
+	defer p.runlock.Unlock()
+	p.running[""] = p.startProto(0, Protocol{
+		Length: baseProtocolLength,
+		Run:    runBaseProtocol,
+	})
+}
+
+// startProtocols starts matching named subprotocols.
+func (p *Peer) startSubprotocols(caps []Cap) {
+	sort.Sort(capsByName(caps))
+
+	p.runlock.Lock()
+	defer p.runlock.Unlock()
+	offset := baseProtocolLength
+outer:
+	for _, cap := range caps {
+		for _, proto := range p.protocols {
+			if proto.Name == cap.Name &&
+				proto.Version == cap.Version &&
+				p.running[cap.Name] == nil {
+				p.running[cap.Name] = p.startProto(offset, proto)
+				offset += proto.Length
+				continue outer
+			}
+		}
+	}
+}
+
+func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
+	rw := &proto{
+		in:      make(chan Msg),
+		offset:  offset,
+		maxcode: impl.Length,
+		peer:    p,
+	}
+	p.protoWG.Add(1)
+	go func() {
+		err := impl.Run(p, rw)
+		if err == nil {
+			p.Infof("protocol %q returned", impl.Name)
+			err = newPeerError(errMisc, "protocol returned")
+		} else {
+			p.Errorf("protocol %q error: %v\n", impl.Name, err)
+		}
+		select {
+		case p.protoErr <- err:
+		case <-p.closed:
+		}
+		p.protoWG.Done()
+	}()
+	return rw
+}
+
+// getProto finds the protocol responsible for handling
+// the given message code.
+func (p *Peer) getProto(code uint64) (*proto, error) {
+	p.runlock.RLock()
+	defer p.runlock.RUnlock()
+	for _, proto := range p.running {
+		if code >= proto.offset && code < proto.offset+proto.maxcode {
+			return proto, nil
+		}
+	}
+	return nil, newPeerError(errInvalidMsgCode, "%d", code)
+}
+
+func (p *Peer) closeProtocols() {
+	p.runlock.RLock()
+	for _, p := range p.running {
+		close(p.in)
+	}
+	p.runlock.RUnlock()
+	p.protoWG.Wait()
+}
+
+// writeProtoMsg sends the given message on behalf of the given named protocol.
+func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
+	p.runlock.RLock()
+	proto, ok := p.running[protoName]
+	p.runlock.RUnlock()
+	if !ok {
+		return fmt.Errorf("protocol %s not handled by peer", protoName)
+	}
+	if msg.Code >= proto.maxcode {
+		return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
+	}
+	msg.Code += proto.offset
+	return p.writeMsg(msg, msgWriteTimeout)
+}
+
+// writeMsg writes a message to the connection.
+func (p *Peer) writeMsg(msg Msg, timeout time.Duration) error {
+	p.writeMu.Lock()
+	defer p.writeMu.Unlock()
+	p.conn.SetWriteDeadline(time.Now().Add(timeout))
+	if err := writeMsg(p.bufconn, msg); err != nil {
+		return newPeerError(errWrite, "%v", err)
+	}
+	return p.bufconn.Flush()
+}
+
+type proto struct {
+	name            string
+	in              chan Msg
+	maxcode, offset uint64
+	peer            *Peer
+}
+
+func (rw *proto) WriteMsg(msg Msg) error {
+	if msg.Code >= rw.maxcode {
+		return newPeerError(errInvalidMsgCode, "not handled")
+	}
+	msg.Code += rw.offset
+	return rw.peer.writeMsg(msg, msgWriteTimeout)
 }
 
-func (self *Peer) Write(protocol string, msg Msg) error {
-	return self.messenger.writeProtoMsg(protocol, msg)
+func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error {
+	return rw.WriteMsg(NewMsg(code, data))
 }
 
-func (self *Peer) Start() {
-	self.peerErrorHandler.Start()
-	self.messenger.Start()
+func (rw *proto) ReadMsg() (Msg, error) {
+	msg, ok := <-rw.in
+	if !ok {
+		return msg, io.EOF
+	}
+	msg.Code -= rw.offset
+	return msg, nil
 }
 
-func (self *Peer) Stop() {
-	self.peerErrorHandler.Stop()
-	self.messenger.Stop()
+// eofSignal wraps a reader with eof signaling.
+// the eof channel is closed when the wrapped reader
+// reaches EOF.
+type eofSignal struct {
+	wrapped io.Reader
+	eof     chan<- struct{}
 }
 
-func (p *Peer) Encode() []interface{} {
-	return []interface{}{p.Host, p.Port, p.Pubkey}
+func (r *eofSignal) Read(buf []byte) (int, error) {
+	n, err := r.wrapped.Read(buf)
+	if err != nil {
+		r.eof <- struct{}{} // tell Peer that msg has been consumed
+	}
+	return n, err
 }

+ 100 - 50
p2p/peer_error.go

@@ -4,71 +4,121 @@ import (
 	"fmt"
 )
 
-type ErrorCode int
-
-const errorChanCapacity = 10
-
 const (
-	PacketTooLong = iota
-	PayloadTooShort
-	MagicTokenMismatch
-	ReadError
-	WriteError
-	MiscError
-	InvalidMsgCode
-	InvalidMsg
-	P2PVersionMismatch
-	PubkeyMissing
-	PubkeyInvalid
-	PubkeyForbidden
-	ProtocolBreach
-	PortMismatch
-	PingTimeout
-	InvalidGenesis
-	InvalidNetworkId
-	InvalidProtocolVersion
+	errMagicTokenMismatch = iota
+	errRead
+	errWrite
+	errMisc
+	errInvalidMsgCode
+	errInvalidMsg
+	errP2PVersionMismatch
+	errPubkeyMissing
+	errPubkeyInvalid
+	errPubkeyForbidden
+	errProtocolBreach
+	errPingTimeout
+	errInvalidNetworkId
+	errInvalidProtocolVersion
 )
 
-var errorToString = map[ErrorCode]string{
-	PacketTooLong:          "Packet too long",
-	PayloadTooShort:        "Payload too short",
-	MagicTokenMismatch:     "Magic token mismatch",
-	ReadError:              "Read error",
-	WriteError:             "Write error",
-	MiscError:              "Misc error",
-	InvalidMsgCode:         "Invalid message code",
-	InvalidMsg:             "Invalid message",
-	P2PVersionMismatch:     "P2P Version Mismatch",
-	PubkeyMissing:          "Public key missing",
-	PubkeyInvalid:          "Public key invalid",
-	PubkeyForbidden:        "Public key forbidden",
-	ProtocolBreach:         "Protocol Breach",
-	PortMismatch:           "Port mismatch",
-	PingTimeout:            "Ping timeout",
-	InvalidGenesis:         "Invalid genesis block",
-	InvalidNetworkId:       "Invalid network id",
-	InvalidProtocolVersion: "Invalid protocol version",
+var errorToString = map[int]string{
+	errMagicTokenMismatch:     "Magic token mismatch",
+	errRead:                   "Read error",
+	errWrite:                  "Write error",
+	errMisc:                   "Misc error",
+	errInvalidMsgCode:         "Invalid message code",
+	errInvalidMsg:             "Invalid message",
+	errP2PVersionMismatch:     "P2P Version Mismatch",
+	errPubkeyMissing:          "Public key missing",
+	errPubkeyInvalid:          "Public key invalid",
+	errPubkeyForbidden:        "Public key forbidden",
+	errProtocolBreach:         "Protocol Breach",
+	errPingTimeout:            "Ping timeout",
+	errInvalidNetworkId:       "Invalid network id",
+	errInvalidProtocolVersion: "Invalid protocol version",
 }
 
-type PeerError struct {
-	Code    ErrorCode
+type peerError struct {
+	Code    int
 	message string
 }
 
-func NewPeerError(code ErrorCode, format string, v ...interface{}) *PeerError {
+func newPeerError(code int, format string, v ...interface{}) *peerError {
 	desc, ok := errorToString[code]
 	if !ok {
 		panic("invalid error code")
 	}
-	format = desc + ": " + format
-	message := fmt.Sprintf(format, v...)
-	return &PeerError{code, message}
+	err := &peerError{code, desc}
+	if format != "" {
+		err.message += ": " + fmt.Sprintf(format, v...)
+	}
+	return err
 }
 
-func (self *PeerError) Error() string {
+func (self *peerError) Error() string {
 	return self.message
 }
 
-func NewPeerErrorChannel() chan error {
-	return make(chan error, errorChanCapacity)
+type DiscReason byte
+
+const (
+	DiscRequested           DiscReason = 0x00
+	DiscNetworkError                   = 0x01
+	DiscProtocolError                  = 0x02
+	DiscUselessPeer                    = 0x03
+	DiscTooManyPeers                   = 0x04
+	DiscAlreadyConnected               = 0x05
+	DiscIncompatibleVersion            = 0x06
+	DiscInvalidIdentity                = 0x07
+	DiscQuitting                       = 0x08
+	DiscUnexpectedIdentity             = 0x09
+	DiscSelf                           = 0x0a
+	DiscReadTimeout                    = 0x0b
+	DiscSubprotocolError               = 0x10
+)
+
+var discReasonToString = [DiscSubprotocolError + 1]string{
+	DiscRequested:           "Disconnect requested",
+	DiscNetworkError:        "Network error",
+	DiscProtocolError:       "Breach of protocol",
+	DiscUselessPeer:         "Useless peer",
+	DiscTooManyPeers:        "Too many peers",
+	DiscAlreadyConnected:    "Already connected",
+	DiscIncompatibleVersion: "Incompatible P2P protocol version",
+	DiscInvalidIdentity:     "Invalid node identity",
+	DiscQuitting:            "Client quitting",
+	DiscUnexpectedIdentity:  "Unexpected identity",
+	DiscSelf:                "Connected to self",
+	DiscReadTimeout:         "Read timeout",
+	DiscSubprotocolError:    "Subprotocol error",
+}
+
+func (d DiscReason) String() string {
+	if len(discReasonToString) < int(d) {
+		return fmt.Sprintf("Unknown Reason(%d)", d)
+	}
+	return discReasonToString[d]
+}
+
+func discReasonForError(err error) DiscReason {
+	peerError, ok := err.(*peerError)
+	if !ok {
+		return DiscSubprotocolError
+	}
+	switch peerError.Code {
+	case errP2PVersionMismatch:
+		return DiscIncompatibleVersion
+	case errPubkeyMissing, errPubkeyInvalid:
+		return DiscInvalidIdentity
+	case errPubkeyForbidden:
+		return DiscUselessPeer
+	case errInvalidMsgCode, errMagicTokenMismatch, errProtocolBreach:
+		return DiscProtocolError
+	case errPingTimeout:
+		return DiscReadTimeout
+	case errRead, errWrite, errMisc:
+		return DiscNetworkError
+	default:
+		return DiscSubprotocolError
+	}
 }

+ 0 - 98
p2p/peer_error_handler.go

@@ -1,98 +0,0 @@
-package p2p
-
-import (
-	"net"
-)
-
-const (
-	severityThreshold = 10
-)
-
-type DisconnectRequest struct {
-	addr   net.Addr
-	reason DiscReason
-}
-
-type PeerErrorHandler struct {
-	quit           chan chan bool
-	address        net.Addr
-	peerDisconnect chan DisconnectRequest
-	severity       int
-	errc           chan error
-}
-
-func NewPeerErrorHandler(address net.Addr, peerDisconnect chan DisconnectRequest, errc chan error) *PeerErrorHandler {
-	return &PeerErrorHandler{
-		quit:           make(chan chan bool),
-		address:        address,
-		peerDisconnect: peerDisconnect,
-		errc:           errc,
-	}
-}
-
-func (self *PeerErrorHandler) Start() {
-	go self.listen()
-}
-
-func (self *PeerErrorHandler) Stop() {
-	q := make(chan bool)
-	self.quit <- q
-	<-q
-}
-
-func (self *PeerErrorHandler) listen() {
-	for {
-		select {
-		case err, ok := <-self.errc:
-			if ok {
-				logger.Debugf("error %v\n", err)
-				go self.handle(err)
-			} else {
-				return
-			}
-		case q := <-self.quit:
-			q <- true
-			return
-		}
-	}
-}
-
-func (self *PeerErrorHandler) handle(err error) {
-	reason := DiscReason(' ')
-	peerError, ok := err.(*PeerError)
-	if !ok {
-		peerError = NewPeerError(MiscError, " %v", err)
-	}
-	switch peerError.Code {
-	case P2PVersionMismatch:
-		reason = DiscIncompatibleVersion
-	case PubkeyMissing, PubkeyInvalid:
-		reason = DiscInvalidIdentity
-	case PubkeyForbidden:
-		reason = DiscUselessPeer
-	case InvalidMsgCode, PacketTooLong, PayloadTooShort, MagicTokenMismatch, ProtocolBreach:
-		reason = DiscProtocolError
-	case PingTimeout:
-		reason = DiscReadTimeout
-	case ReadError, WriteError, MiscError:
-		reason = DiscNetworkError
-	case InvalidGenesis, InvalidNetworkId, InvalidProtocolVersion:
-		reason = DiscSubprotocolError
-	default:
-		self.severity += self.getSeverity(peerError)
-	}
-
-	if self.severity >= severityThreshold {
-		reason = DiscSubprotocolError
-	}
-	if reason != DiscReason(' ') {
-		self.peerDisconnect <- DisconnectRequest{
-			addr:   self.address,
-			reason: reason,
-		}
-	}
-}
-
-func (self *PeerErrorHandler) getSeverity(peerError *PeerError) int {
-	return 1
-}

+ 0 - 34
p2p/peer_error_handler_test.go

@@ -1,34 +0,0 @@
-package p2p
-
-import (
-	// "fmt"
-	"net"
-	"testing"
-	"time"
-)
-
-func TestPeerErrorHandler(t *testing.T) {
-	address := &net.TCPAddr{IP: net.IP([]byte{1, 2, 3, 4}), Port: 30303}
-	peerDisconnect := make(chan DisconnectRequest)
-	peerErrorChan := NewPeerErrorChannel()
-	peh := NewPeerErrorHandler(address, peerDisconnect, peerErrorChan)
-	peh.Start()
-	defer peh.Stop()
-	for i := 0; i < 11; i++ {
-		select {
-		case <-peerDisconnect:
-			t.Errorf("expected no disconnect request")
-		default:
-		}
-		peerErrorChan <- NewPeerError(MiscError, "")
-	}
-	time.Sleep(1 * time.Millisecond)
-	select {
-	case request := <-peerDisconnect:
-		if request.addr.String() != address.String() {
-			t.Errorf("incorrect address %v != %v", request.addr, address)
-		}
-	default:
-		t.Errorf("expected disconnect request")
-	}
-}

+ 220 - 88
p2p/peer_test.go

@@ -1,90 +1,222 @@
 package p2p
 
-// "net"
-
-// func TestPeer(t *testing.T) {
-// 	handlers := make(Handlers)
-// 	testProtocol := &TestProtocol{recv: make(chan testMsg)}
-// 	handlers["aaa"] = func(p *Peer) Protocol { return testProtocol }
-// 	handlers["ccc"] = func(p *Peer) Protocol { return testProtocol }
-// 	addr := &TestAddr{"test:30"}
-// 	conn := NewTestNetworkConnection(addr)
-// 	_, server := SetupTestServer(handlers)
-// 	server.Handshake()
-// 	peer := NewPeer(conn, addr, true, server)
-// 	// peer.Messenger().AddProtocols([]string{"aaa", "ccc"})
-// 	peer.Start()
-// 	defer peer.Stop()
-// 	time.Sleep(2 * time.Millisecond)
-// 	if len(conn.Out) != 1 {
-// 		t.Errorf("handshake not sent")
-// 	} else {
-// 		out := conn.Out[0]
-// 		packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:])
-// 		if bytes.Compare(out, packet) != 0 {
-// 			t.Errorf("incorrect handshake packet %v != %v", out, packet)
-// 		}
-// 	}
-
-// 	packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000"))
-// 	conn.In(0, packet)
-// 	time.Sleep(10 * time.Millisecond)
-
-// 	pro, _ := peer.Messenger().protocols[0].(*BaseProtocol)
-// 	if pro.state != handshakeReceived {
-// 		t.Errorf("handshake not received")
-// 	}
-// 	if peer.Port != 30 {
-// 		t.Errorf("port incorrectly set")
-// 	}
-// 	if peer.Id != "peer" {
-// 		t.Errorf("id incorrectly set")
-// 	}
-// 	if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" {
-// 		t.Errorf("pubkey incorrectly set")
-// 	}
-// 	fmt.Println(peer.Caps)
-// 	if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" {
-// 		t.Errorf("protocols incorrectly set")
-// 	}
-
-// 	msg := NewMsg(3)
-// 	err := peer.Write("aaa", msg)
-// 	if err != nil {
-// 		t.Errorf("expect no error for known protocol: %v", err)
-// 	} else {
-// 		time.Sleep(1 * time.Millisecond)
-// 		if len(conn.Out) != 2 {
-// 			t.Errorf("msg not written")
-// 		} else {
-// 			out := conn.Out[1]
-// 			packet := Packet(16, 3)
-// 			if bytes.Compare(out, packet) != 0 {
-// 				t.Errorf("incorrect packet %v != %v", out, packet)
-// 			}
-// 		}
-// 	}
-
-// 	msg = NewMsg(2)
-// 	err = peer.Write("ccc", msg)
-// 	if err != nil {
-// 		t.Errorf("expect no error for known protocol: %v", err)
-// 	} else {
-// 		time.Sleep(1 * time.Millisecond)
-// 		if len(conn.Out) != 3 {
-// 			t.Errorf("msg not written")
-// 		} else {
-// 			out := conn.Out[2]
-// 			packet := Packet(21, 2)
-// 			if bytes.Compare(out, packet) != 0 {
-// 				t.Errorf("incorrect packet %v != %v", out, packet)
-// 			}
-// 		}
-// 	}
-
-// 	err = peer.Write("bbb", msg)
-// 	time.Sleep(1 * time.Millisecond)
-// 	if err == nil {
-// 		t.Errorf("expect error for unknown protocol")
-// 	}
-// }
+import (
+	"bufio"
+	"net"
+	"reflect"
+	"testing"
+	"time"
+)
+
+var discard = Protocol{
+	Name:   "discard",
+	Length: 1,
+	Run: func(p *Peer, rw MsgReadWriter) error {
+		for {
+			msg, err := rw.ReadMsg()
+			if err != nil {
+				return err
+			}
+			if err = msg.Discard(); err != nil {
+				return err
+			}
+		}
+	},
+}
+
+func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) {
+	conn1, conn2 := net.Pipe()
+	id := NewSimpleClientIdentity("test", "0", "0", "public key")
+	peer := newPeer(conn1, protos, nil)
+	peer.ourID = id
+	peer.pubkeyHook = func(*peerAddr) error { return nil }
+	errc := make(chan error, 1)
+	go func() {
+		_, err := peer.loop()
+		errc <- err
+	}()
+	return conn2, peer, errc
+}
+
+func TestPeerProtoReadMsg(t *testing.T) {
+	defer testlog(t).detach()
+
+	done := make(chan struct{})
+	proto := Protocol{
+		Name:   "a",
+		Length: 5,
+		Run: func(peer *Peer, rw MsgReadWriter) error {
+			msg, err := rw.ReadMsg()
+			if err != nil {
+				t.Errorf("read error: %v", err)
+			}
+			if msg.Code != 2 {
+				t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
+			}
+			data, err := msg.Data()
+			if err != nil {
+				t.Errorf("data decoding error: %v", err)
+			}
+			expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}}
+			if !reflect.DeepEqual(data.Slice(), expdata) {
+				t.Errorf("incorrect msg data %#v", data.Slice())
+			}
+			close(done)
+			return nil
+		},
+	}
+
+	net, peer, errc := testPeer([]Protocol{proto})
+	defer net.Close()
+	peer.startSubprotocols([]Cap{proto.cap()})
+
+	writeMsg(net, NewMsg(18, 1, "000"))
+	select {
+	case <-done:
+	case err := <-errc:
+		t.Errorf("peer returned: %v", err)
+	case <-time.After(2 * time.Second):
+		t.Errorf("receive timeout")
+	}
+}
+
+func TestPeerProtoReadLargeMsg(t *testing.T) {
+	defer testlog(t).detach()
+
+	msgsize := uint32(10 * 1024 * 1024)
+	done := make(chan struct{})
+	proto := Protocol{
+		Name:   "a",
+		Length: 5,
+		Run: func(peer *Peer, rw MsgReadWriter) error {
+			msg, err := rw.ReadMsg()
+			if err != nil {
+				t.Errorf("read error: %v", err)
+			}
+			if msg.Size != msgsize+4 {
+				t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, msgsize)
+			}
+			msg.Discard()
+			close(done)
+			return nil
+		},
+	}
+
+	net, peer, errc := testPeer([]Protocol{proto})
+	defer net.Close()
+	peer.startSubprotocols([]Cap{proto.cap()})
+
+	writeMsg(net, NewMsg(18, make([]byte, msgsize)))
+	select {
+	case <-done:
+	case err := <-errc:
+		t.Errorf("peer returned: %v", err)
+	case <-time.After(2 * time.Second):
+		t.Errorf("receive timeout")
+	}
+}
+
+func TestPeerProtoEncodeMsg(t *testing.T) {
+	defer testlog(t).detach()
+
+	proto := Protocol{
+		Name:   "a",
+		Length: 2,
+		Run: func(peer *Peer, rw MsgReadWriter) error {
+			if err := rw.EncodeMsg(2); err == nil {
+				t.Error("expected error for out-of-range msg code, got nil")
+			}
+			if err := rw.EncodeMsg(1); err != nil {
+				t.Errorf("write error: %v", err)
+			}
+			return nil
+		},
+	}
+	net, peer, _ := testPeer([]Protocol{proto})
+	defer net.Close()
+	peer.startSubprotocols([]Cap{proto.cap()})
+
+	bufr := bufio.NewReader(net)
+	msg, err := readMsg(bufr)
+	if err != nil {
+		t.Errorf("read error: %v", err)
+	}
+	if msg.Code != 17 {
+		t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
+	}
+}
+
+func TestPeerWrite(t *testing.T) {
+	defer testlog(t).detach()
+
+	net, peer, peerErr := testPeer([]Protocol{discard})
+	defer net.Close()
+	peer.startSubprotocols([]Cap{discard.cap()})
+
+	// test write errors
+	if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
+		t.Errorf("expected error for unknown protocol, got nil")
+	}
+	if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil {
+		t.Errorf("expected error for out-of-range msg code, got nil")
+	} else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode {
+		t.Errorf("wrong error for out-of-range msg code, got %#v", err)
+	}
+
+	// setup for reading the message on the other end
+	read := make(chan struct{})
+	go func() {
+		bufr := bufio.NewReader(net)
+		msg, err := readMsg(bufr)
+		if err != nil {
+			t.Errorf("read error: %v", err)
+		} else if msg.Code != 16 {
+			t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
+		}
+		msg.Discard()
+		close(read)
+	}()
+
+	// test succcessful write
+	if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil {
+		t.Errorf("expect no error for known protocol: %v", err)
+	}
+	select {
+	case <-read:
+	case err := <-peerErr:
+		t.Fatalf("peer stopped: %v", err)
+	}
+}
+
+func TestPeerActivity(t *testing.T) {
+	// shorten inactivityTimeout while this test is running
+	oldT := inactivityTimeout
+	defer func() { inactivityTimeout = oldT }()
+	inactivityTimeout = 20 * time.Millisecond
+
+	net, peer, peerErr := testPeer([]Protocol{discard})
+	defer net.Close()
+	peer.startSubprotocols([]Cap{discard.cap()})
+
+	sub := peer.activity.Subscribe(time.Time{})
+	defer sub.Unsubscribe()
+
+	for i := 0; i < 6; i++ {
+		writeMsg(net, NewMsg(16))
+		select {
+		case <-sub.Chan():
+		case <-time.After(inactivityTimeout / 2):
+			t.Fatal("no event within ", inactivityTimeout/2)
+		case err := <-peerErr:
+			t.Fatal("peer error", err)
+		}
+	}
+
+	select {
+	case <-time.After(inactivityTimeout * 2):
+	case <-sub.Chan():
+		t.Fatal("got activity event while connection was inactive")
+	case err := <-peerErr:
+		t.Fatal("peer error", err)
+	}
+}

+ 194 - 218
p2p/protocol.go

@@ -3,249 +3,185 @@ package p2p
 import (
 	"bytes"
 	"net"
-	"sort"
 	"time"
 
 	"github.com/ethereum/go-ethereum/ethutil"
 )
 
-// Protocol is implemented by P2P subprotocols.
-type Protocol interface {
-	// Start is called when the protocol becomes active.
-	// It should read and write messages from rw.
-	// Messages must be fully consumed.
-	//
-	// The connection is closed when Start returns. It should return
-	// any protocol-level error (such as an I/O error) that is
-	// encountered.
-	Start(peer *Peer, rw MsgReadWriter) error
+// Protocol represents a P2P subprotocol implementation.
+type Protocol struct {
+	// Name should contain the official protocol name,
+	// often a three-letter word.
+	Name string
 
-	// Offset should return the number of message codes
-	// used by the protocol.
-	Offset() MsgCode
-}
+	// Version should contain the version number of the protocol.
+	Version uint
 
-type MsgReader interface {
-	ReadMsg() (Msg, error)
-}
-
-type MsgWriter interface {
-	WriteMsg(Msg) error
-}
-
-// MsgReadWriter is passed to protocols. Protocol implementations can
-// use it to write messages back to a connected peer.
-type MsgReadWriter interface {
-	MsgReader
-	MsgWriter
-}
+	// Length should contain the number of message codes used
+	// by the protocol.
+	Length uint64
 
-type MsgHandler func(code MsgCode, data *ethutil.Value) error
-
-// MsgLoop reads messages off the given reader and
-// calls the handler function for each decoded message until
-// it returns an error or the peer connection is closed.
-//
-// If a message is larger than the given maximum size, RunProtocol
-// returns an appropriate error.n
-func MsgLoop(r MsgReader, maxsize uint32, handler MsgHandler) error {
-	for {
-		msg, err := r.ReadMsg()
-		if err != nil {
-			return err
-		}
-		if msg.Size > maxsize {
-			return NewPeerError(InvalidMsg, "size %d exceeds maximum size of %d", msg.Size, maxsize)
-		}
-		value, err := msg.Data()
-		if err != nil {
-			return err
-		}
-		if err := handler(msg.Code, value); err != nil {
-			return err
-		}
-	}
-}
-
-// the ÐΞVp2p base protocol
-type baseProtocol struct {
-	rw   MsgReadWriter
-	peer *Peer
+	// Run is called in a new groutine when the protocol has been
+	// negotiated with a peer. It should read and write messages from
+	// rw. The Payload for each message must be fully consumed.
+	//
+	// The peer connection is closed when Start returns. It should return
+	// any protocol-level error (such as an I/O error) that is
+	// encountered.
+	Run func(peer *Peer, rw MsgReadWriter) error
 }
 
-type bpMsg struct {
-	code MsgCode
-	data *ethutil.Value
+func (p Protocol) cap() Cap {
+	return Cap{p.Name, p.Version}
 }
 
 const (
-	p2pVersion      = 0
-	pingTimeout     = 2 * time.Second
-	pingGracePeriod = 2 * time.Second
+	baseProtocolVersion    = 2
+	baseProtocolLength     = uint64(16)
+	baseProtocolMaxMsgSize = 10 * 1024 * 1024
 )
 
 const (
-	// message codes
-	handshakeMsg = iota
-	discMsg
-	pingMsg
-	pongMsg
-	getPeersMsg
-	peersMsg
+	// devp2p message codes
+	handshakeMsg = 0x00
+	discMsg      = 0x01
+	pingMsg      = 0x02
+	pongMsg      = 0x03
+	getPeersMsg  = 0x04
+	peersMsg     = 0x05
 )
 
-const (
-	baseProtocolOffset     MsgCode = 16
-	baseProtocolMaxMsgSize         = 500 * 1024
-)
-
-type DiscReason byte
+// handshake is the structure of a handshake list.
+type handshake struct {
+	Version    uint64
+	ID         string
+	Caps       []Cap
+	ListenPort uint64
+	NodeID     []byte
+}
 
-const (
-	// Values are given explicitly instead of by iota because these values are
-	// defined by the wire protocol spec; it is easier for humans to ensure
-	// correctness when values are explicit.
-	DiscRequested           = 0x00
-	DiscNetworkError        = 0x01
-	DiscProtocolError       = 0x02
-	DiscUselessPeer         = 0x03
-	DiscTooManyPeers        = 0x04
-	DiscAlreadyConnected    = 0x05
-	DiscIncompatibleVersion = 0x06
-	DiscInvalidIdentity     = 0x07
-	DiscQuitting            = 0x08
-	DiscUnexpectedIdentity  = 0x09
-	DiscSelf                = 0x0a
-	DiscReadTimeout         = 0x0b
-	DiscSubprotocolError    = 0x10
-)
+func (h *handshake) String() string {
+	return h.ID
+}
+func (h *handshake) Pubkey() []byte {
+	return h.NodeID
+}
 
-var discReasonToString = [DiscSubprotocolError + 1]string{
-	DiscRequested:           "Disconnect requested",
-	DiscNetworkError:        "Network error",
-	DiscProtocolError:       "Breach of protocol",
-	DiscUselessPeer:         "Useless peer",
-	DiscTooManyPeers:        "Too many peers",
-	DiscAlreadyConnected:    "Already connected",
-	DiscIncompatibleVersion: "Incompatible P2P protocol version",
-	DiscInvalidIdentity:     "Invalid node identity",
-	DiscQuitting:            "Client quitting",
-	DiscUnexpectedIdentity:  "Unexpected identity",
-	DiscSelf:                "Connected to self",
-	DiscReadTimeout:         "Read timeout",
-	DiscSubprotocolError:    "Subprotocol error",
+// Cap is the structure of a peer capability.
+type Cap struct {
+	Name    string
+	Version uint
 }
 
-func (d DiscReason) String() string {
-	if len(discReasonToString) < int(d) {
-		return "Unknown"
-	}
-	return discReasonToString[d]
+func (cap Cap) RlpData() interface{} {
+	return []interface{}{cap.Name, cap.Version}
 }
 
-func (bp *baseProtocol) Offset() MsgCode {
-	return baseProtocolOffset
+type capsByName []Cap
+
+func (cs capsByName) Len() int           { return len(cs) }
+func (cs capsByName) Less(i, j int) bool { return cs[i].Name < cs[j].Name }
+func (cs capsByName) Swap(i, j int)      { cs[i], cs[j] = cs[j], cs[i] }
+
+type baseProtocol struct {
+	rw   MsgReadWriter
+	peer *Peer
 }
 
-func (bp *baseProtocol) Start(peer *Peer, rw MsgReadWriter) error {
-	bp.peer, bp.rw = peer, rw
+func runBaseProtocol(peer *Peer, rw MsgReadWriter) error {
+	bp := &baseProtocol{rw, peer}
 
-	// Do the handshake.
-	// TODO: disconnect is valid before handshake, too.
-	rw.WriteMsg(bp.peer.server.handshakeMsg())
+	// do handshake
+	if err := rw.WriteMsg(bp.handshakeMsg()); err != nil {
+		return err
+	}
 	msg, err := rw.ReadMsg()
 	if err != nil {
 		return err
 	}
 	if msg.Code != handshakeMsg {
-		return NewPeerError(ProtocolBreach, " first message must be handshake")
+		return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code)
 	}
 	data, err := msg.Data()
 	if err != nil {
-		return NewPeerError(InvalidMsg, "%v", err)
+		return newPeerError(errInvalidMsg, "%v", err)
 	}
 	if err := bp.handleHandshake(data); err != nil {
 		return err
 	}
 
-	msgin := make(chan bpMsg)
-	done := make(chan error, 1)
+	// run main loop
+	quit := make(chan error, 1)
 	go func() {
-		done <- MsgLoop(rw, baseProtocolMaxMsgSize,
-			func(code MsgCode, data *ethutil.Value) error {
-				msgin <- bpMsg{code, data}
-				return nil
-			})
+		quit <- MsgLoop(rw, baseProtocolMaxMsgSize, bp.handle)
 	}()
-	return bp.loop(msgin, done)
+	return bp.loop(quit)
 }
 
-func (bp *baseProtocol) loop(msgin <-chan bpMsg, quit <-chan error) error {
-	logger.Debugf("pingpong keepalive started at %v\n", time.Now())
-	messenger := bp.rw.(*proto).messenger
-	pingTimer := time.NewTimer(pingTimeout)
-	pinged := true
+var pingTimeout = 2 * time.Second
+
+func (bp *baseProtocol) loop(quit <-chan error) error {
+	ping := time.NewTimer(pingTimeout)
+	activity := bp.peer.activity.Subscribe(time.Time{})
+	lastActive := time.Time{}
+	defer ping.Stop()
+	defer activity.Unsubscribe()
 
-	for {
+	getPeersTick := time.NewTicker(10 * time.Second)
+	defer getPeersTick.Stop()
+	err := bp.rw.EncodeMsg(getPeersMsg)
+
+	for err == nil {
 		select {
-		case msg := <-msgin:
-			if err := bp.handle(msg.code, msg.data); err != nil {
-				return err
-			}
-		case err := <-quit:
+		case err = <-quit:
 			return err
-		case <-messenger.pulse:
-			pingTimer.Reset(pingTimeout)
-			pinged = false
-		case <-pingTimer.C:
-			if pinged {
-				return NewPeerError(PingTimeout, "")
+		case <-getPeersTick.C:
+			err = bp.rw.EncodeMsg(getPeersMsg)
+		case event := <-activity.Chan():
+			ping.Reset(pingTimeout)
+			lastActive = event.(time.Time)
+		case t := <-ping.C:
+			if lastActive.Add(pingTimeout * 2).Before(t) {
+				err = newPeerError(errPingTimeout, "")
+			} else if lastActive.Add(pingTimeout).Before(t) {
+				err = bp.rw.EncodeMsg(pingMsg)
 			}
-			logger.Debugf("pinging at %v\n", time.Now())
-			if err := bp.rw.WriteMsg(NewMsg(pingMsg)); err != nil {
-				return NewPeerError(WriteError, "%v", err)
-			}
-			pinged = true
-			pingTimer.Reset(pingTimeout)
 		}
 	}
+	return err
 }
 
-func (bp *baseProtocol) handle(code MsgCode, data *ethutil.Value) error {
+func (bp *baseProtocol) handle(code uint64, data *ethutil.Value) error {
 	switch code {
 	case handshakeMsg:
-		return NewPeerError(ProtocolBreach, " extra handshake received")
+		return newPeerError(errProtocolBreach, "extra handshake received")
 
 	case discMsg:
-		logger.Infof("Disconnect requested from peer %v, reason", DiscReason(data.Get(0).Uint()))
-		bp.peer.server.PeerDisconnect() <- DisconnectRequest{
-			addr:   bp.peer.Address,
-			reason: DiscRequested,
-		}
+		bp.peer.Disconnect(DiscReason(data.Get(0).Uint()))
+		return nil
 
 	case pingMsg:
-		return bp.rw.WriteMsg(NewMsg(pongMsg))
+		return bp.rw.EncodeMsg(pongMsg)
 
 	case pongMsg:
-		// reply for ping
 
 	case getPeersMsg:
-		// Peer asked for list of connected peers.
-		peersRLP := bp.peer.server.encodedPeerList()
-		if peersRLP != nil {
-			msg := Msg{
-				Code:    peersMsg,
-				Size:    uint32(len(peersRLP)),
-				Payload: bytes.NewReader(peersRLP),
-			}
-			return bp.rw.WriteMsg(msg)
+		peers := bp.peerList()
+		// this is dangerous. the spec says that we should _delay_
+		// sending the response if no new information is available.
+		// this means that would need to send a response later when
+		// new peers become available.
+		//
+		// TODO: add event mechanism to notify baseProtocol for new peers
+		if len(peers) > 0 {
+			return bp.rw.EncodeMsg(peersMsg, peers)
 		}
 
 	case peersMsg:
 		bp.handlePeers(data)
 
 	default:
-		return NewPeerError(InvalidMsgCode, "unknown message code %v", code)
+		return newPeerError(errInvalidMsgCode, "unknown message code %v", code)
 	}
 	return nil
 }
@@ -253,62 +189,102 @@ func (bp *baseProtocol) handle(code MsgCode, data *ethutil.Value) error {
 func (bp *baseProtocol) handlePeers(data *ethutil.Value) {
 	it := data.NewIterator()
 	for it.Next() {
-		ip := net.IP(it.Value().Get(0).Bytes())
-		port := it.Value().Get(1).Uint()
-		address := &net.TCPAddr{IP: ip, Port: int(port)}
-		go bp.peer.server.PeerConnect(address)
+		addr := &peerAddr{
+			IP:     net.IP(it.Value().Get(0).Bytes()),
+			Port:   it.Value().Get(1).Uint(),
+			Pubkey: it.Value().Get(2).Bytes(),
+		}
+		bp.peer.Debugf("received peer suggestion: %v", addr)
+		bp.peer.newPeerAddr <- addr
 	}
 }
 
 func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error {
-	var (
-		remoteVersion = c.Get(0).Uint()
-		id            = c.Get(1).Str()
-		caps          = c.Get(2)
-		port          = c.Get(3).Uint()
-		pubkey        = c.Get(4).Bytes()
-	)
-	// Check correctness of p2p protocol version
-	if remoteVersion != p2pVersion {
-		return NewPeerError(P2PVersionMismatch, "Require protocol %d, received %d\n", p2pVersion, remoteVersion)
+	hs := handshake{
+		Version:    c.Get(0).Uint(),
+		ID:         c.Get(1).Str(),
+		Caps:       nil, // decoded below
+		ListenPort: c.Get(3).Uint(),
+		NodeID:     c.Get(4).Bytes(),
 	}
-
-	// Handle the pub key (validation, uniqueness)
-	if len(pubkey) == 0 {
-		return NewPeerError(PubkeyMissing, "not supplied in handshake.")
+	if hs.Version != baseProtocolVersion {
+		return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n",
+			baseProtocolVersion, hs.Version)
 	}
-
-	if len(pubkey) != 64 {
-		return NewPeerError(PubkeyInvalid, "require 512 bit, got %v", len(pubkey)*8)
+	if len(hs.NodeID) == 0 {
+		return newPeerError(errPubkeyMissing, "")
+	}
+	if len(hs.NodeID) != 64 {
+		return newPeerError(errPubkeyInvalid, "require 512 bit, got %v", len(hs.NodeID)*8)
+	}
+	if da := bp.peer.dialAddr; da != nil {
+		// verify that the peer we wanted to connect to
+		// actually holds the target public key.
+		if da.Pubkey != nil && !bytes.Equal(da.Pubkey, hs.NodeID) {
+			return newPeerError(errPubkeyForbidden, "dial address pubkey mismatch")
+		}
+	}
+	pa := newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
+	if err := bp.peer.pubkeyHook(pa); err != nil {
+		return newPeerError(errPubkeyForbidden, "%v", err)
+	}
+	capsIt := c.Get(2).NewIterator()
+	for capsIt.Next() {
+		cap := capsIt.Value()
+		name := cap.Get(0).Str()
+		if name != "" {
+			hs.Caps = append(hs.Caps, Cap{Name: name, Version: uint(cap.Get(1).Uint())})
+		}
 	}
 
-	// self connect detection
-	if bytes.Compare(bp.peer.server.ClientIdentity().Pubkey()[1:], pubkey) == 0 {
-		return NewPeerError(PubkeyForbidden, "not allowed to connect to self")
+	var addr *peerAddr
+	if hs.ListenPort != 0 {
+		addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
+		addr.Port = hs.ListenPort
 	}
+	bp.peer.setHandshakeInfo(&hs, addr, hs.Caps)
+	bp.peer.startSubprotocols(hs.Caps)
+	return nil
+}
 
-	// register pubkey on server. this also sets the pubkey on the peer (need lock)
-	if err := bp.peer.server.RegisterPubkey(bp.peer, pubkey); err != nil {
-		return NewPeerError(PubkeyForbidden, err.Error())
+func (bp *baseProtocol) handshakeMsg() Msg {
+	var (
+		port uint64
+		caps []interface{}
+	)
+	if bp.peer.ourListenAddr != nil {
+		port = bp.peer.ourListenAddr.Port
 	}
+	for _, proto := range bp.peer.protocols {
+		caps = append(caps, proto.cap())
+	}
+	return NewMsg(handshakeMsg,
+		baseProtocolVersion,
+		bp.peer.ourID.String(),
+		caps,
+		port,
+		bp.peer.ourID.Pubkey()[1:],
+	)
+}
 
-	// check port
-	if bp.peer.Inbound {
-		uint16port := uint16(port)
-		if bp.peer.Port > 0 && bp.peer.Port != uint16port {
-			return NewPeerError(PortMismatch, "port mismatch: %v != %v", bp.peer.Port, port)
-		} else {
-			bp.peer.Port = uint16port
+func (bp *baseProtocol) peerList() []ethutil.RlpEncodable {
+	peers := bp.peer.otherPeers()
+	ds := make([]ethutil.RlpEncodable, 0, len(peers))
+	for _, p := range peers {
+		p.infolock.Lock()
+		addr := p.listenAddr
+		p.infolock.Unlock()
+		// filter out this peer and peers that are not listening or
+		// have not completed the handshake.
+		// TODO: track previously sent peers and exclude them as well.
+		if p == bp.peer || addr == nil {
+			continue
 		}
+		ds = append(ds, addr)
 	}
-
-	capsIt := caps.NewIterator()
-	for capsIt.Next() {
-		cap := capsIt.Value().Str()
-		bp.peer.Caps = append(bp.peer.Caps, cap)
+	ourAddr := bp.peer.ourListenAddr
+	if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() {
+		ds = append(ds, ourAddr)
 	}
-	sort.Strings(bp.peer.Caps)
-	bp.rw.(*proto).messenger.setRemoteProtocols(bp.peer.Caps)
-	bp.peer.Id = id
-	return nil
+	return ds
 }

+ 346 - 367
p2p/server.go

@@ -2,155 +2,101 @@ package p2p
 
 import (
 	"bytes"
+	"errors"
 	"fmt"
 	"net"
-	"sort"
-	"strconv"
 	"sync"
 	"time"
 
-	logpkg "github.com/ethereum/go-ethereum/logger"
+	"github.com/ethereum/go-ethereum/logger"
 )
 
 const (
-	outboundAddressPoolSize = 10
-	disconnectGracePeriod   = 2
+	outboundAddressPoolSize   = 500
+	defaultDialTimeout        = 10 * time.Second
+	portMappingUpdateInterval = 15 * time.Minute
+	portMappingTimeout        = 20 * time.Minute
 )
 
-type Blacklist interface {
-	Get([]byte) (bool, error)
-	Put([]byte) error
-	Delete([]byte) error
-	Exists(pubkey []byte) (ok bool)
-}
-
-type BlacklistMap struct {
-	blacklist map[string]bool
-	lock      sync.RWMutex
-}
-
-func NewBlacklist() *BlacklistMap {
-	return &BlacklistMap{
-		blacklist: make(map[string]bool),
-	}
-}
-
-func (self *BlacklistMap) Get(pubkey []byte) (bool, error) {
-	self.lock.RLock()
-	defer self.lock.RUnlock()
-	v, ok := self.blacklist[string(pubkey)]
-	var err error
-	if !ok {
-		err = fmt.Errorf("not found")
-	}
-	return v, err
-}
-
-func (self *BlacklistMap) Exists(pubkey []byte) (ok bool) {
-	self.lock.RLock()
-	defer self.lock.RUnlock()
-	_, ok = self.blacklist[string(pubkey)]
-	return
-}
-
-func (self *BlacklistMap) Put(pubkey []byte) error {
-	self.lock.RLock()
-	defer self.lock.RUnlock()
-	self.blacklist[string(pubkey)] = true
-	return nil
-}
-
-func (self *BlacklistMap) Delete(pubkey []byte) error {
-	self.lock.RLock()
-	defer self.lock.RUnlock()
-	delete(self.blacklist, string(pubkey))
-	return nil
-}
+var srvlog = logger.NewLogger("P2P Server")
 
+// Server manages all peer connections.
+//
+// The fields of Server are used as configuration parameters.
+// You should set them before starting the Server. Fields may not be
+// modified while the server is running.
 type Server struct {
-	network   Network
-	listening bool //needed?
-	dialing   bool //needed?
-	closed    bool
-	identity  ClientIdentity
-	addr      net.Addr
-	port      uint16
-	protocols []string
-
-	quit      chan chan bool
-	peersLock sync.RWMutex
-
-	maxPeers           int
-	peers              []*Peer
-	peerSlots          chan int
-	peersTable         map[string]int
-	peerCount          int
-	cachedEncodedPeers []byte
-
-	peerConnect    chan net.Addr
-	peerDisconnect chan DisconnectRequest
-	blacklist      Blacklist
-	handlers       Handlers
-}
-
-var logger = logpkg.NewLogger("P2P")
-
-func New(network Network, addr net.Addr, identity ClientIdentity, handlers Handlers, maxPeers int, blacklist Blacklist) *Server {
-	// get alphabetical list of protocol names from handlers map
-	protocols := []string{}
-	for protocol := range handlers {
-		protocols = append(protocols, protocol)
-	}
-	sort.Strings(protocols)
-
-	_, port, _ := net.SplitHostPort(addr.String())
-	intport, _ := strconv.Atoi(port)
-
-	self := &Server{
-		// NewSimpleClientIdentity(clientIdentifier, version, customIdentifier)
-		network:   network,
-		identity:  identity,
-		addr:      addr,
-		port:      uint16(intport),
-		protocols: protocols,
-
-		quit: make(chan chan bool),
-
-		maxPeers:   maxPeers,
-		peers:      make([]*Peer, maxPeers),
-		peerSlots:  make(chan int, maxPeers),
-		peersTable: make(map[string]int),
+	// This field must be set to a valid client identity.
+	Identity ClientIdentity
+
+	// MaxPeers is the maximum number of peers that can be
+	// connected. It must be greater than zero.
+	MaxPeers int
+
+	// Protocols should contain the protocols supported
+	// by the server. Matching protocols are launched for
+	// each peer.
+	Protocols []Protocol
+
+	// If Blacklist is set to a non-nil value, the given Blacklist
+	// is used to verify peer connections.
+	Blacklist Blacklist
+
+	// If ListenAddr is set to a non-nil address, the server
+	// will listen for incoming connections.
+	//
+	// If the port is zero, the operating system will pick a port. The
+	// ListenAddr field will be updated with the actual address when
+	// the server is started.
+	ListenAddr string
+
+	// If set to a non-nil value, the given NAT port mapper
+	// is used to make the listening port available to the
+	// Internet.
+	NAT NAT
+
+	// If Dialer is set to a non-nil value, the given Dialer
+	// is used to dial outbound peer connections.
+	Dialer *net.Dialer
+
+	// If NoDial is true, the server will not dial any peers.
+	NoDial bool
+
+	// Hook for testing. This is useful because we can inhibit
+	// the whole protocol stack.
+	newPeerFunc peerFunc
 
-		peerConnect:    make(chan net.Addr, outboundAddressPoolSize),
-		peerDisconnect: make(chan DisconnectRequest),
-		blacklist:      blacklist,
-
-		handlers: handlers,
-	}
-	for i := 0; i < maxPeers; i++ {
-		self.peerSlots <- i // fill up with indexes
-	}
-	return self
+	lock      sync.RWMutex
+	running   bool
+	listener  net.Listener
+	laddr     *net.TCPAddr // real listen addr
+	peers     []*Peer
+	peerSlots chan int
+	peerCount int
+
+	quit           chan struct{}
+	wg             sync.WaitGroup
+	peerConnect    chan *peerAddr
+	peerDisconnect chan *Peer
 }
 
-func (self *Server) NewAddr(host string, port int) (addr net.Addr, err error) {
-	addr, err = self.network.NewAddr(host, port)
-	return
-}
+// NAT is implemented by NAT traversal methods.
+type NAT interface {
+	GetExternalAddress() (net.IP, error)
+	AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error
+	DeletePortMapping(protocol string, extport, intport int) error
 
-func (self *Server) ParseAddr(address string) (addr net.Addr, err error) {
-	addr, err = self.network.ParseAddr(address)
-	return
+	// Should return name of the method.
+	String() string
 }
 
-func (self *Server) ClientIdentity() ClientIdentity {
-	return self.identity
-}
+type peerFunc func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer
 
-func (self *Server) Peers() (peers []*Peer) {
-	self.peersLock.RLock()
-	defer self.peersLock.RUnlock()
-	for _, peer := range self.peers {
+// Peers returns all connected peers.
+func (srv *Server) Peers() (peers []*Peer) {
+	srv.lock.RLock()
+	defer srv.lock.RUnlock()
+	for _, peer := range srv.peers {
 		if peer != nil {
 			peers = append(peers, peer)
 		}
@@ -158,331 +104,364 @@ func (self *Server) Peers() (peers []*Peer) {
 	return
 }
 
-func (self *Server) PeerCount() int {
-	self.peersLock.RLock()
-	defer self.peersLock.RUnlock()
-	return self.peerCount
+// PeerCount returns the number of connected peers.
+func (srv *Server) PeerCount() int {
+	srv.lock.RLock()
+	defer srv.lock.RUnlock()
+	return srv.peerCount
 }
 
-func (self *Server) PeerConnect(addr net.Addr) {
-	// TODO: should buffer, filter and uniq
-	// send GetPeersMsg if not blocking
+// SuggestPeer injects an address into the outbound address pool.
+func (srv *Server) SuggestPeer(ip net.IP, port int, nodeID []byte) {
 	select {
-	case self.peerConnect <- addr: // not enough peers
-		self.Broadcast("", getPeersMsg)
-	default: // we dont care
+	case srv.peerConnect <- &peerAddr{ip, uint64(port), nodeID}:
+	default: // don't block
 	}
 }
 
-func (self *Server) PeerDisconnect() chan DisconnectRequest {
-	return self.peerDisconnect
-}
-
-func (self *Server) Blacklist() Blacklist {
-	return self.blacklist
-}
-
-func (self *Server) Handlers() Handlers {
-	return self.handlers
-}
-
-func (self *Server) Broadcast(protocol string, code MsgCode, data ...interface{}) {
+// Broadcast sends an RLP-encoded message to all connected peers.
+// This method is deprecated and will be removed later.
+func (srv *Server) Broadcast(protocol string, code uint64, data ...interface{}) {
 	var payload []byte
 	if data != nil {
 		payload = encodePayload(data...)
 	}
-	self.peersLock.RLock()
-	defer self.peersLock.RUnlock()
-	for _, peer := range self.peers {
+	srv.lock.RLock()
+	defer srv.lock.RUnlock()
+	for _, peer := range srv.peers {
 		if peer != nil {
 			var msg = Msg{Code: code}
 			if data != nil {
 				msg.Payload = bytes.NewReader(payload)
 				msg.Size = uint32(len(payload))
 			}
-			peer.messenger.writeProtoMsg(protocol, msg)
+			peer.writeProtoMsg(protocol, msg)
 		}
 	}
 }
 
-// Start the server
-func (self *Server) Start(listen bool, dial bool) {
-	self.network.Start()
-	if listen {
-		listener, err := self.network.Listener(self.addr)
-		if err != nil {
-			logger.Warnf("Error initializing listener: %v", err)
-			logger.Warnf("Connection listening disabled")
-			self.listening = false
-		} else {
-			self.listening = true
-			logger.Infoln("Listen on %v: ready and accepting connections", listener.Addr())
-			go self.inboundPeerHandler(listener)
-		}
+// Start starts running the server.
+// Servers can be re-used and started again after stopping.
+func (srv *Server) Start() (err error) {
+	srv.lock.Lock()
+	defer srv.lock.Unlock()
+	if srv.running {
+		return errors.New("server already running")
+	}
+	srvlog.Infoln("Starting Server")
+
+	// initialize fields
+	if srv.Identity == nil {
+		return fmt.Errorf("Server.Identity must be set to a non-nil identity")
 	}
-	if dial {
-		dialer, err := self.network.Dialer(self.addr)
-		if err != nil {
-			logger.Warnf("Error initializing dialer: %v", err)
-			logger.Warnf("Connection dialout disabled")
-			self.dialing = false
-		} else {
-			self.dialing = true
-			logger.Infoln("Dial peers watching outbound address pool")
-			go self.outboundPeerHandler(dialer)
+	if srv.MaxPeers <= 0 {
+		return fmt.Errorf("Server.MaxPeers must be > 0")
+	}
+	srv.quit = make(chan struct{})
+	srv.peers = make([]*Peer, srv.MaxPeers)
+	srv.peerSlots = make(chan int, srv.MaxPeers)
+	srv.peerConnect = make(chan *peerAddr, outboundAddressPoolSize)
+	srv.peerDisconnect = make(chan *Peer)
+	if srv.newPeerFunc == nil {
+		srv.newPeerFunc = newServerPeer
+	}
+	if srv.Blacklist == nil {
+		srv.Blacklist = NewBlacklist()
+	}
+	if srv.Dialer == nil {
+		srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
+	}
+
+	if srv.ListenAddr != "" {
+		if err := srv.startListening(); err != nil {
+			return err
 		}
 	}
-	logger.Infoln("server started")
+	if !srv.NoDial {
+		srv.wg.Add(1)
+		go srv.dialLoop()
+	}
+	if srv.NoDial && srv.ListenAddr == "" {
+		srvlog.Warnln("I will be kind-of useless, neither dialing nor listening.")
+	}
+
+	// make all slots available
+	for i := range srv.peers {
+		srv.peerSlots <- i
+	}
+	// note: discLoop is not part of WaitGroup
+	go srv.discLoop()
+	srv.running = true
+	return nil
 }
 
-func (self *Server) Stop() {
-	logger.Infoln("server stopping...")
-	// // quit one loop if dialing
-	if self.dialing {
-		logger.Infoln("stop dialout...")
-		dialq := make(chan bool)
-		self.quit <- dialq
-		<-dialq
-		fmt.Println("quit another")
-	}
-	// quit the other loop if listening
-	if self.listening {
-		logger.Infoln("stop listening...")
-		listenq := make(chan bool)
-		self.quit <- listenq
-		<-listenq
-		fmt.Println("quit one")
-	}
-
-	fmt.Println("quit waited")
-
-	logger.Infoln("stopping peers...")
-	peers := []net.Addr{}
-	self.peersLock.RLock()
-	self.closed = true
-	for _, peer := range self.peers {
-		if peer != nil {
-			peers = append(peers, peer.Address)
-		}
+func (srv *Server) startListening() error {
+	listener, err := net.Listen("tcp", srv.ListenAddr)
+	if err != nil {
+		return err
+	}
+	srv.ListenAddr = listener.Addr().String()
+	srv.laddr = listener.Addr().(*net.TCPAddr)
+	srv.listener = listener
+	srv.wg.Add(1)
+	go srv.listenLoop()
+	if !srv.laddr.IP.IsLoopback() && srv.NAT != nil {
+		srv.wg.Add(1)
+		go srv.natLoop(srv.laddr.Port)
+	}
+	return nil
+}
+
+// Stop terminates the server and all active peer connections.
+// It blocks until all active connections have been closed.
+func (srv *Server) Stop() {
+	srv.lock.Lock()
+	if !srv.running {
+		srv.lock.Unlock()
+		return
 	}
-	self.peersLock.RUnlock()
-	for _, address := range peers {
-		go self.removePeer(DisconnectRequest{
-			addr:   address,
-			reason: DiscQuitting,
-		})
+	srv.running = false
+	srv.lock.Unlock()
+
+	srvlog.Infoln("Stopping server")
+	if srv.listener != nil {
+		// this unblocks listener Accept
+		srv.listener.Close()
+	}
+	close(srv.quit)
+	for _, peer := range srv.Peers() {
+		peer.Disconnect(DiscQuitting)
 	}
+	srv.wg.Wait()
+
 	// wait till they actually disconnect
-	// this is checked by draining the peerSlots (slots are released back if a peer is removed)
-	i := 0
-	fmt.Println("draining peers")
+	// this is checked by claiming all peerSlots.
+	// slots become available as the peers disconnect.
+	for i := 0; i < cap(srv.peerSlots); i++ {
+		<-srv.peerSlots
+	}
+	// terminate discLoop
+	close(srv.peerDisconnect)
+}
+
+func (srv *Server) discLoop() {
+	for peer := range srv.peerDisconnect {
+		// peer has just disconnected. free up its slot.
+		srvlog.Infof("%v is gone", peer)
+		srv.peerSlots <- peer.slot
+		srv.lock.Lock()
+		srv.peers[peer.slot] = nil
+		srv.lock.Unlock()
+	}
+}
 
-FOR:
+// main loop for adding connections via listening
+func (srv *Server) listenLoop() {
+	defer srv.wg.Done()
+
+	srvlog.Infoln("Listening on", srv.listener.Addr())
 	for {
 		select {
-		case slot := <-self.peerSlots:
-			i++
-			fmt.Printf("%v: found slot %v\n", i, slot)
-			if i == self.maxPeers {
-				break FOR
+		case slot := <-srv.peerSlots:
+			conn, err := srv.listener.Accept()
+			if err != nil {
+				srv.peerSlots <- slot
+				return
 			}
+			srvlog.Debugf("Accepted conn %v (slot %d)\n", conn.RemoteAddr(), slot)
+			srv.addPeer(conn, nil, slot)
+		case <-srv.quit:
+			return
 		}
 	}
-	logger.Infoln("server stopped")
 }
 
-// main loop for adding connections via listening
-func (self *Server) inboundPeerHandler(listener net.Listener) {
+func (srv *Server) natLoop(port int) {
+	defer srv.wg.Done()
 	for {
+		srv.updatePortMapping(port)
 		select {
-		case slot := <-self.peerSlots:
-			go self.connectInboundPeer(listener, slot)
-		case errc := <-self.quit:
-			listener.Close()
-			fmt.Println("quit listenloop")
-			errc <- true
+		case <-time.After(portMappingUpdateInterval):
+			// one more round
+		case <-srv.quit:
+			srv.removePortMapping(port)
 			return
 		}
 	}
 }
 
-// main loop for adding outbound peers based on peerConnect address pool
-// this same loop handles peer disconnect requests as well
-func (self *Server) outboundPeerHandler(dialer Dialer) {
-	// addressChan initially set to nil (only watches peerConnect if we need more peers)
-	var addressChan chan net.Addr
-	slots := self.peerSlots
-	var slot *int
+func (srv *Server) updatePortMapping(port int) {
+	srvlog.Infoln("Attempting to map port", port, "with", srv.NAT)
+	err := srv.NAT.AddPortMapping("tcp", port, port, "ethereum p2p", portMappingTimeout)
+	if err != nil {
+		srvlog.Errorln("Port mapping error:", err)
+		return
+	}
+	extip, err := srv.NAT.GetExternalAddress()
+	if err != nil {
+		srvlog.Errorln("Error getting external IP:", err)
+		return
+	}
+	srv.lock.Lock()
+	extaddr := *(srv.listener.Addr().(*net.TCPAddr))
+	extaddr.IP = extip
+	srvlog.Infoln("Mapped port, external addr is", &extaddr)
+	srv.laddr = &extaddr
+	srv.lock.Unlock()
+}
+
+func (srv *Server) removePortMapping(port int) {
+	srvlog.Infoln("Removing port mapping for", port, "with", srv.NAT)
+	srv.NAT.DeletePortMapping("tcp", port, port)
+}
+
+func (srv *Server) dialLoop() {
+	defer srv.wg.Done()
+	var (
+		suggest chan *peerAddr
+		slot    *int
+		slots   = srv.peerSlots
+	)
 	for {
 		select {
 		case i := <-slots:
 			// we need a peer in slot i, slot reserved
 			slot = &i
 			// now we can watch for candidate peers in the next loop
-			addressChan = self.peerConnect
+			suggest = srv.peerConnect
 			// do not consume more until candidate peer is found
 			slots = nil
-		case address := <-addressChan:
+
+		case desc := <-suggest:
 			// candidate peer found, will dial out asyncronously
 			// if connection fails slot will be released
-			go self.connectOutboundPeer(dialer, address, *slot)
+			go srv.dialPeer(desc, *slot)
 			// we can watch if more peers needed in the next loop
-			slots = self.peerSlots
+			slots = srv.peerSlots
 			// until then we dont care about candidate peers
-			addressChan = nil
-		case request := <-self.peerDisconnect:
-			go self.removePeer(request)
-		case errc := <-self.quit:
-			if addressChan != nil && slot != nil {
-				self.peerSlots <- *slot
+			suggest = nil
+
+		case <-srv.quit:
+			// give back the currently reserved slot
+			if slot != nil {
+				srv.peerSlots <- *slot
 			}
-			fmt.Println("quit dialloop")
-			errc <- true
 			return
 		}
 	}
 }
 
-// check if peer address already connected
-func (self *Server) isConnected(address net.Addr) bool {
-	self.peersLock.RLock()
-	defer self.peersLock.RUnlock()
-	_, found := self.peersTable[address.String()]
-	return found
-}
-
-// connect to peer via listener.Accept()
-func (self *Server) connectInboundPeer(listener net.Listener, slot int) {
-	var address net.Addr
-	conn, err := listener.Accept()
-	if err != nil {
-		logger.Debugln(err)
-		self.peerSlots <- slot
-		return
-	}
-	address = conn.RemoteAddr()
-	// XXX: this won't work because the remote socket
-	// address does not identify the peer. we should
-	// probably get rid of this check and rely on public
-	// key detection in the base protocol.
-	if self.isConnected(address) {
-		conn.Close()
-		self.peerSlots <- slot
-		return
-	}
-	fmt.Printf("adding %v\n", address)
-	go self.addPeer(conn, address, true, slot)
-}
-
 // connect to peer via dial out
-func (self *Server) connectOutboundPeer(dialer Dialer, address net.Addr, slot int) {
-	if self.isConnected(address) {
-		return
-	}
-	conn, err := dialer.Dial(address.Network(), address.String())
+func (srv *Server) dialPeer(desc *peerAddr, slot int) {
+	srvlog.Debugf("Dialing %v (slot %d)\n", desc, slot)
+	conn, err := srv.Dialer.Dial(desc.Network(), desc.String())
 	if err != nil {
-		self.peerSlots <- slot
+		srvlog.Errorf("Dial error: %v", err)
+		srv.peerSlots <- slot
 		return
 	}
-	go self.addPeer(conn, address, false, slot)
+	go srv.addPeer(conn, desc, slot)
 }
 
 // creates the new peer object and inserts it into its slot
-func (self *Server) addPeer(conn net.Conn, address net.Addr, inbound bool, slot int) *Peer {
-	self.peersLock.Lock()
-	defer self.peersLock.Unlock()
-	if self.closed {
-		fmt.Println("oopsy, not no longer need peer")
-		conn.Close()           //oopsy our bad
-		self.peerSlots <- slot // release slot
+func (srv *Server) addPeer(conn net.Conn, desc *peerAddr, slot int) *Peer {
+	srv.lock.Lock()
+	defer srv.lock.Unlock()
+	if !srv.running {
+		conn.Close()
+		srv.peerSlots <- slot // release slot
 		return nil
 	}
-	logger.Infoln("adding new peer", address)
-	peer := NewPeer(conn, address, inbound, self)
-	self.peers[slot] = peer
-	self.peersTable[address.String()] = slot
-	self.peerCount++
-	self.cachedEncodedPeers = nil
-	fmt.Printf("added peer %v %v (slot %v)\n", address, peer, slot)
-	peer.Start()
+	peer := srv.newPeerFunc(srv, conn, desc)
+	peer.slot = slot
+	srv.peers[slot] = peer
+	srv.peerCount++
+	go func() { peer.loop(); srv.peerDisconnect <- peer }()
 	return peer
 }
 
 // removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot
-func (self *Server) removePeer(request DisconnectRequest) {
-	self.peersLock.Lock()
-
-	address := request.addr
-	slot := self.peersTable[address.String()]
-	peer := self.peers[slot]
-	fmt.Printf("removing peer %v %v (slot %v)\n", address, peer, slot)
-	if peer == nil {
-		logger.Debugf("already removed peer on %v", address)
-		self.peersLock.Unlock()
+func (srv *Server) removePeer(peer *Peer) {
+	srv.lock.Lock()
+	defer srv.lock.Unlock()
+	srvlog.Debugf("Removing peer %v %v (slot %v)\n", peer, peer.slot)
+	if srv.peers[peer.slot] != peer {
+		srvlog.Warnln("Invalid peer to remove:", peer)
 		return
 	}
 	// remove from list and index
-	self.peerCount--
-	self.peers[slot] = nil
-	delete(self.peersTable, address.String())
-	self.cachedEncodedPeers = nil
-	fmt.Printf("removed peer %v (slot %v)\n", peer, slot)
-	self.peersLock.Unlock()
-
-	// sending disconnect message
-	disconnectMsg := NewMsg(discMsg, request.reason)
-	peer.Write("", disconnectMsg)
-	// be nice and wait
-	time.Sleep(disconnectGracePeriod * time.Second)
-	// switch off peer and close connections etc.
-	fmt.Println("stopping peer")
-	peer.Stop()
-	fmt.Println("stopped peer")
+	srv.peerCount--
+	srv.peers[peer.slot] = nil
 	// release slot to signal need for a new peer, last!
-	self.peerSlots <- slot
+	srv.peerSlots <- peer.slot
 }
 
-// encodedPeerList returns an RLP-encoded list of peers.
-// the returned slice will be nil if there are no peers.
-func (self *Server) encodedPeerList() []byte {
-	// TODO: memoize and reset when peers change
-	self.peersLock.RLock()
-	defer self.peersLock.RUnlock()
-	if self.cachedEncodedPeers == nil && self.peerCount > 0 {
-		var peerData []interface{}
-		for _, i := range self.peersTable {
-			peer := self.peers[i]
-			peerData = append(peerData, peer.Encode())
+func (srv *Server) verifyPeer(addr *peerAddr) error {
+	if srv.Blacklist.Exists(addr.Pubkey) {
+		return errors.New("blacklisted")
+	}
+	if bytes.Equal(srv.Identity.Pubkey()[1:], addr.Pubkey) {
+		return newPeerError(errPubkeyForbidden, "not allowed to connect to srv")
+	}
+	srv.lock.RLock()
+	defer srv.lock.RUnlock()
+	for _, peer := range srv.peers {
+		if peer != nil {
+			id := peer.Identity()
+			if id != nil && bytes.Equal(id.Pubkey(), addr.Pubkey) {
+				return errors.New("already connected")
+			}
 		}
-		self.cachedEncodedPeers = encodePayload(peerData)
 	}
-	return self.cachedEncodedPeers
+	return nil
 }
 
-// fix handshake message to push to peers
-func (self *Server) handshakeMsg() Msg {
-	return NewMsg(handshakeMsg,
-		p2pVersion,
-		[]byte(self.identity.String()),
-		[]interface{}{self.protocols},
-		self.port,
-		self.identity.Pubkey()[1:],
-	)
+type Blacklist interface {
+	Get([]byte) (bool, error)
+	Put([]byte) error
+	Delete([]byte) error
+	Exists(pubkey []byte) (ok bool)
+}
+
+type BlacklistMap struct {
+	blacklist map[string]bool
+	lock      sync.RWMutex
 }
 
-func (self *Server) RegisterPubkey(candidate *Peer, pubkey []byte) error {
-	// Check for blacklisting
-	if self.blacklist.Exists(pubkey) {
-		return fmt.Errorf("blacklisted")
+func NewBlacklist() *BlacklistMap {
+	return &BlacklistMap{
+		blacklist: make(map[string]bool),
 	}
+}
 
-	self.peersLock.RLock()
-	defer self.peersLock.RUnlock()
-	for _, peer := range self.peers {
-		if peer != nil && peer != candidate && bytes.Compare(peer.Pubkey, pubkey) == 0 {
-			return fmt.Errorf("already connected")
-		}
+func (self *BlacklistMap) Get(pubkey []byte) (bool, error) {
+	self.lock.RLock()
+	defer self.lock.RUnlock()
+	v, ok := self.blacklist[string(pubkey)]
+	var err error
+	if !ok {
+		err = fmt.Errorf("not found")
 	}
-	candidate.Pubkey = pubkey
+	return v, err
+}
+
+func (self *BlacklistMap) Exists(pubkey []byte) (ok bool) {
+	self.lock.RLock()
+	defer self.lock.RUnlock()
+	_, ok = self.blacklist[string(pubkey)]
+	return
+}
+
+func (self *BlacklistMap) Put(pubkey []byte) error {
+	self.lock.RLock()
+	defer self.lock.RUnlock()
+	self.blacklist[string(pubkey)] = true
+	return nil
+}
+
+func (self *BlacklistMap) Delete(pubkey []byte) error {
+	self.lock.RLock()
+	defer self.lock.RUnlock()
+	delete(self.blacklist, string(pubkey))
 	return nil
 }

+ 130 - 258
p2p/server_test.go

@@ -1,289 +1,161 @@
 package p2p
 
 import (
-	"fmt"
+	"bytes"
 	"io"
 	"net"
+	"sync"
 	"testing"
 	"time"
 )
 
-type TestNetwork struct {
-	connections map[string]*TestNetworkConnection
-	dialer      Dialer
-	maxinbound  int
-}
-
-func NewTestNetwork(maxinbound int) *TestNetwork {
-	connections := make(map[string]*TestNetworkConnection)
-	return &TestNetwork{
-		connections: connections,
-		dialer:      &TestDialer{connections},
-		maxinbound:  maxinbound,
+func startTestServer(t *testing.T, pf peerFunc) *Server {
+	server := &Server{
+		Identity:    NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey"),
+		MaxPeers:    10,
+		ListenAddr:  "127.0.0.1:0",
+		newPeerFunc: pf,
 	}
-}
-
-func (self *TestNetwork) Dialer(addr net.Addr) (Dialer, error) {
-	return self.dialer, nil
-}
-
-func (self *TestNetwork) Listener(addr net.Addr) (net.Listener, error) {
-	return &TestListener{
-		connections: self.connections,
-		addr:        addr,
-		max:         self.maxinbound,
-		close:       make(chan struct{}),
-	}, nil
-}
-
-func (self *TestNetwork) Start() error {
-	return nil
-}
-
-func (self *TestNetwork) NewAddr(string, int) (addr net.Addr, err error) {
-	return
-}
-
-func (self *TestNetwork) ParseAddr(string) (addr net.Addr, err error) {
-	return
-}
-
-type TestAddr struct {
-	name string
-}
-
-func (self *TestAddr) String() string {
-	return self.name
-}
-
-func (*TestAddr) Network() string {
-	return "test"
-}
-
-type TestDialer struct {
-	connections map[string]*TestNetworkConnection
-}
-
-func (self *TestDialer) Dial(network string, addr string) (conn net.Conn, err error) {
-	address := &TestAddr{addr}
-	tconn := NewTestNetworkConnection(address)
-	self.connections[addr] = tconn
-	conn = net.Conn(tconn)
-	return
-}
-
-type TestListener struct {
-	connections map[string]*TestNetworkConnection
-	addr        net.Addr
-	max         int
-	i           int
-	close       chan struct{}
-}
-
-func (self *TestListener) Accept() (net.Conn, error) {
-	self.i++
-	if self.i > self.max {
-		<-self.close
-		return nil, io.EOF
+	if err := server.Start(); err != nil {
+		t.Fatalf("Could not start server: %v", err)
 	}
-	addr := &TestAddr{fmt.Sprintf("inboundpeer-%d", self.i)}
-	tconn := NewTestNetworkConnection(addr)
-	key := tconn.RemoteAddr().String()
-	self.connections[key] = tconn
-	fmt.Printf("accepted connection from: %v \n", addr)
-	return tconn, nil
-}
-
-func (self *TestListener) Close() error {
-	close(self.close)
-	return nil
-}
-
-func (self *TestListener) Addr() net.Addr {
-	return self.addr
+	return server
 }
 
-type TestNetworkConnection struct {
-	in      chan []byte
-	close   chan struct{}
-	current []byte
-	Out     [][]byte
-	addr    net.Addr
-}
+func TestServerListen(t *testing.T) {
+	defer testlog(t).detach()
 
-func NewTestNetworkConnection(addr net.Addr) *TestNetworkConnection {
-	return &TestNetworkConnection{
-		in:      make(chan []byte),
-		close:   make(chan struct{}),
-		current: []byte{},
-		Out:     [][]byte{},
-		addr:    addr,
+	// start the test server
+	connected := make(chan *Peer)
+	srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
+		if conn == nil {
+			t.Error("peer func called with nil conn")
+		}
+		if dialAddr != nil {
+			t.Error("peer func called with non-nil dialAddr")
+		}
+		peer := newPeer(conn, nil, dialAddr)
+		connected <- peer
+		return peer
+	})
+	defer close(connected)
+	defer srv.Stop()
+
+	// dial the test server
+	conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second)
+	if err != nil {
+		t.Fatalf("could not dial: %v", err)
 	}
-}
+	defer conn.Close()
 
-func (self *TestNetworkConnection) In(latency time.Duration, packets ...[]byte) {
-	time.Sleep(latency)
-	for _, s := range packets {
-		self.in <- s
+	select {
+	case peer := <-connected:
+		if peer.conn.LocalAddr().String() != conn.RemoteAddr().String() {
+			t.Errorf("peer started with wrong conn: got %v, want %v",
+				peer.conn.LocalAddr(), conn.RemoteAddr())
+		}
+	case <-time.After(1 * time.Second):
+		t.Error("server did not accept within one second")
 	}
 }
 
-func (self *TestNetworkConnection) Read(buff []byte) (n int, err error) {
-	if len(self.current) == 0 {
-		var ok bool
+func TestServerDial(t *testing.T) {
+	defer testlog(t).detach()
+
+	// run a fake TCP server to handle the connection.
+	listener, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("could not setup listener: %v")
+	}
+	defer listener.Close()
+	accepted := make(chan net.Conn)
+	go func() {
+		conn, err := listener.Accept()
+		if err != nil {
+			t.Error("acccept error:", err)
+		}
+		conn.Close()
+		accepted <- conn
+	}()
+
+	// start the test server
+	connected := make(chan *Peer)
+	srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
+		if conn == nil {
+			t.Error("peer func called with nil conn")
+		}
+		peer := newPeer(conn, nil, dialAddr)
+		connected <- peer
+		return peer
+	})
+	defer close(connected)
+	defer srv.Stop()
+
+	// tell the server to connect.
+	connAddr := newPeerAddr(listener.Addr(), nil)
+	srv.peerConnect <- connAddr
+
+	select {
+	case conn := <-accepted:
 		select {
-		case self.current, ok = <-self.in:
-			if !ok {
-				return 0, io.EOF
+		case peer := <-connected:
+			if peer.conn.RemoteAddr().String() != conn.LocalAddr().String() {
+				t.Errorf("peer started with wrong conn: got %v, want %v",
+					peer.conn.RemoteAddr(), conn.LocalAddr())
+			}
+			if peer.dialAddr != connAddr {
+				t.Errorf("peer started with wrong dialAddr: got %v, want %v",
+					peer.dialAddr, connAddr)
 			}
-		case <-self.close:
-			return 0, io.EOF
+		case <-time.After(1 * time.Second):
+			t.Error("server did not launch peer within one second")
 		}
-	}
-	length := len(self.current)
-	if length > len(buff) {
-		copy(buff[:], self.current[:len(buff)])
-		self.current = self.current[len(buff):]
-		return len(buff), nil
-	} else {
-		copy(buff[:length], self.current[:])
-		self.current = []byte{}
-		return length, io.EOF
-	}
-}
-
-func (self *TestNetworkConnection) Write(buff []byte) (n int, err error) {
-	self.Out = append(self.Out, buff)
-	fmt.Printf("net write(%d): %x\n", len(self.Out), buff)
-	return len(buff), nil
-}
-
-func (self *TestNetworkConnection) Close() error {
-	close(self.close)
-	return nil
-}
-
-func (self *TestNetworkConnection) LocalAddr() (addr net.Addr) {
-	return
-}
 
-func (self *TestNetworkConnection) RemoteAddr() (addr net.Addr) {
-	return self.addr
-}
-
-func (self *TestNetworkConnection) SetDeadline(t time.Time) (err error) {
-	return
-}
-
-func (self *TestNetworkConnection) SetReadDeadline(t time.Time) (err error) {
-	return
-}
-
-func (self *TestNetworkConnection) SetWriteDeadline(t time.Time) (err error) {
-	return
-}
-
-func SetupTestServer(handlers Handlers) (network *TestNetwork, server *Server) {
-	network = NewTestNetwork(1)
-	addr := &TestAddr{"test:30303"}
-	identity := NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey")
-	maxPeers := 2
-	if handlers == nil {
-		handlers = make(Handlers)
+	case <-time.After(1 * time.Second):
+		t.Error("server did not connect within one second")
 	}
-	blackist := NewBlacklist()
-	server = New(network, addr, identity, handlers, maxPeers, blackist)
-	fmt.Println(server.identity.Pubkey())
-	return
 }
 
-func TestServerListener(t *testing.T) {
-	t.SkipNow()
-
-	network, server := SetupTestServer(nil)
-	server.Start(true, false)
-	time.Sleep(10 * time.Millisecond)
-	server.Stop()
-	peer1, ok := network.connections["inboundpeer-1"]
-	if !ok {
-		t.Error("not found inbound peer 1")
-	} else {
-		if len(peer1.Out) != 2 {
-			t.Errorf("wrong number of writes to peer 1: got %d, want %d", len(peer1.Out), 2)
+func TestServerBroadcast(t *testing.T) {
+	defer testlog(t).detach()
+	var connected sync.WaitGroup
+	srv := startTestServer(t, func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer {
+		peer := newPeer(c, []Protocol{discard}, dialAddr)
+		peer.startSubprotocols([]Cap{discard.cap()})
+		connected.Done()
+		return peer
+	})
+	defer srv.Stop()
+
+	// dial a bunch of conns
+	var conns = make([]net.Conn, 8)
+	connected.Add(len(conns))
+	deadline := time.Now().Add(3 * time.Second)
+	dialer := &net.Dialer{Deadline: deadline}
+	for i := range conns {
+		conn, err := dialer.Dial("tcp", srv.ListenAddr)
+		if err != nil {
+			t.Fatalf("conn %d: dial error: %v", i, err)
 		}
+		defer conn.Close()
+		conn.SetDeadline(deadline)
+		conns[i] = conn
 	}
-}
-
-func TestServerDialer(t *testing.T) {
-	network, server := SetupTestServer(nil)
-	server.Start(false, true)
-	server.peerConnect <- &TestAddr{"outboundpeer-1"}
-	time.Sleep(10 * time.Millisecond)
-	server.Stop()
-	peer1, ok := network.connections["outboundpeer-1"]
-	if !ok {
-		t.Error("not found outbound peer 1")
-	} else {
-		if len(peer1.Out) != 2 {
-			t.Errorf("wrong number of writes to peer 1: got %d, want %d", len(peer1.Out), 2)
+	connected.Wait()
+
+	// broadcast one message
+	srv.Broadcast("discard", 0, "foo")
+	goldbuf := new(bytes.Buffer)
+	writeMsg(goldbuf, NewMsg(16, "foo"))
+	golden := goldbuf.Bytes()
+
+	// check that the message has been written everywhere
+	for i, conn := range conns {
+		buf := make([]byte, len(golden))
+		if _, err := io.ReadFull(conn, buf); err != nil {
+			t.Errorf("conn %d: read error: %v", i, err)
+		} else if !bytes.Equal(buf, golden) {
+			t.Errorf("conn %d: msg mismatch\ngot:  %x\nwant: %x", i, buf, golden)
 		}
 	}
 }
-
-// func TestServerBroadcast(t *testing.T) {
-// 	handlers := make(Handlers)
-// 	testProtocol := &TestProtocol{Msgs: []*Msg{}}
-// 	handlers["aaa"] = func(p *Peer) Protocol { return testProtocol }
-// 	network, server := SetupTestServer(handlers)
-// 	server.Start(true, true)
-// 	server.peerConnect <- &TestAddr{"outboundpeer-1"}
-// 	time.Sleep(10 * time.Millisecond)
-// 	msg := NewMsg(0)
-// 	server.Broadcast("", msg)
-// 	packet := Packet(0, 0)
-// 	time.Sleep(10 * time.Millisecond)
-// 	server.Stop()
-// 	peer1, ok := network.connections["outboundpeer-1"]
-// 	if !ok {
-// 		t.Error("not found outbound peer 1")
-// 	} else {
-// 		fmt.Printf("out: %v\n", peer1.Out)
-// 		if len(peer1.Out) != 3 {
-// 			t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
-// 		} else {
-// 			if bytes.Compare(peer1.Out[1], packet) != 0 {
-// 				t.Errorf("incorrect broadcast packet %v != %v", peer1.Out[1], packet)
-// 			}
-// 		}
-// 	}
-// 	peer2, ok := network.connections["inboundpeer-1"]
-// 	if !ok {
-// 		t.Error("not found inbound peer 2")
-// 	} else {
-// 		fmt.Printf("out: %v\n", peer2.Out)
-// 		if len(peer1.Out) != 3 {
-// 			t.Errorf("not enough messages sent to peer 2: %v ", len(peer2.Out))
-// 		} else {
-// 			if bytes.Compare(peer2.Out[1], packet) != 0 {
-// 				t.Errorf("incorrect broadcast packet %v != %v", peer2.Out[1], packet)
-// 			}
-// 		}
-// 	}
-// }
-
-func TestServerPeersMessage(t *testing.T) {
-	t.SkipNow()
-	_, server := SetupTestServer(nil)
-	server.Start(true, true)
-	defer server.Stop()
-	server.peerConnect <- &TestAddr{"outboundpeer-1"}
-	time.Sleep(2000 * time.Millisecond)
-
-	pl := server.encodedPeerList()
-	if pl == nil {
-		t.Errorf("expect non-nil peer list")
-	}
-	if c := server.PeerCount(); c != 2 {
-		t.Errorf("expect 2 peers, got %v", c)
-	}
-}

+ 28 - 0
p2p/testlog_test.go

@@ -0,0 +1,28 @@
+package p2p
+
+import (
+	"testing"
+
+	"github.com/ethereum/go-ethereum/logger"
+)
+
+type testLogger struct{ t *testing.T }
+
+func testlog(t *testing.T) testLogger {
+	logger.Reset()
+	l := testLogger{t}
+	logger.AddLogSystem(l)
+	return l
+}
+
+func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugLevel }
+func (testLogger) SetLogLevel(logger.LogLevel)  {}
+
+func (l testLogger) LogPrint(level logger.LogLevel, msg string) {
+	l.t.Logf("%s", msg)
+}
+
+func (testLogger) detach() {
+	logger.Flush()
+	logger.Reset()
+}

+ 40 - 0
p2p/testpoc7.go

@@ -0,0 +1,40 @@
+// +build none
+
+package main
+
+import (
+	"fmt"
+	"log"
+	"net"
+	"os"
+
+	"github.com/ethereum/go-ethereum/logger"
+	"github.com/ethereum/go-ethereum/p2p"
+	"github.com/obscuren/secp256k1-go"
+)
+
+func main() {
+	logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel))
+
+	pub, _ := secp256k1.GenerateKeyPair()
+	srv := p2p.Server{
+		MaxPeers:   10,
+		Identity:   p2p.NewSimpleClientIdentity("test", "1.0", "", string(pub)),
+		ListenAddr: ":30303",
+		NAT:        p2p.PMP(net.ParseIP("10.0.0.1")),
+	}
+	if err := srv.Start(); err != nil {
+		fmt.Println("could not start server:", err)
+		os.Exit(1)
+	}
+
+	// add seed peers
+	seed, err := net.ResolveTCPAddr("tcp", "poc-7.ethdev.com:30303")
+	if err != nil {
+		fmt.Println("couldn't resolve:", err)
+		os.Exit(1)
+	}
+	srv.SuggestPeer(seed.IP, seed.Port, nil)
+
+	select {}
+}