|
@@ -1,6 +1,7 @@
|
|
|
package p2p
|
|
package p2p
|
|
|
|
|
|
|
|
import (
|
|
import (
|
|
|
|
|
+ "errors"
|
|
|
"fmt"
|
|
"fmt"
|
|
|
"io"
|
|
"io"
|
|
|
"io/ioutil"
|
|
"io/ioutil"
|
|
@@ -71,7 +72,8 @@ type Peer struct {
|
|
|
runlock sync.RWMutex // protects running
|
|
runlock sync.RWMutex // protects running
|
|
|
running map[string]*proto
|
|
running map[string]*proto
|
|
|
|
|
|
|
|
- protocolHandshakeEnabled bool
|
|
|
|
|
|
|
+ // disables protocol handshake, for testing
|
|
|
|
|
+ noHandshake bool
|
|
|
|
|
|
|
|
protoWG sync.WaitGroup
|
|
protoWG sync.WaitGroup
|
|
|
protoErr chan error
|
|
protoErr chan error
|
|
@@ -134,11 +136,11 @@ func (p *Peer) Disconnect(reason DiscReason) {
|
|
|
|
|
|
|
|
// String implements fmt.Stringer.
|
|
// String implements fmt.Stringer.
|
|
|
func (p *Peer) String() string {
|
|
func (p *Peer) String() string {
|
|
|
- return fmt.Sprintf("Peer %.8x %v", p.remoteID, p.RemoteAddr())
|
|
|
|
|
|
|
+ return fmt.Sprintf("Peer %.8x %v", p.remoteID[:], p.RemoteAddr())
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer {
|
|
func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer {
|
|
|
- logtag := fmt.Sprintf("Peer %.8x %v", remoteID, conn.RemoteAddr())
|
|
|
|
|
|
|
+ logtag := fmt.Sprintf("Peer %.8x %v", remoteID[:], conn.RemoteAddr())
|
|
|
return &Peer{
|
|
return &Peer{
|
|
|
Logger: logger.NewLogger(logtag),
|
|
Logger: logger.NewLogger(logtag),
|
|
|
rw: newFrameRW(conn, msgWriteTimeout),
|
|
rw: newFrameRW(conn, msgWriteTimeout),
|
|
@@ -164,33 +166,35 @@ func (p *Peer) run() DiscReason {
|
|
|
var readErr = make(chan error, 1)
|
|
var readErr = make(chan error, 1)
|
|
|
defer p.closeProtocols()
|
|
defer p.closeProtocols()
|
|
|
defer close(p.closed)
|
|
defer close(p.closed)
|
|
|
- defer p.rw.Close()
|
|
|
|
|
|
|
|
|
|
- // start the read loop
|
|
|
|
|
go func() { readErr <- p.readLoop() }()
|
|
go func() { readErr <- p.readLoop() }()
|
|
|
|
|
|
|
|
- if p.protocolHandshakeEnabled {
|
|
|
|
|
|
|
+ if !p.noHandshake {
|
|
|
if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil {
|
|
if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil {
|
|
|
p.DebugDetailf("Protocol handshake error: %v\n", err)
|
|
p.DebugDetailf("Protocol handshake error: %v\n", err)
|
|
|
|
|
+ p.rw.Close()
|
|
|
return DiscProtocolError
|
|
return DiscProtocolError
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // wait for an error or disconnect
|
|
|
|
|
|
|
+ // Wait for an error or disconnect.
|
|
|
var reason DiscReason
|
|
var reason DiscReason
|
|
|
select {
|
|
select {
|
|
|
case err := <-readErr:
|
|
case err := <-readErr:
|
|
|
// We rely on protocols to abort if there is a write error. It
|
|
// We rely on protocols to abort if there is a write error. It
|
|
|
// might be more robust to handle them here as well.
|
|
// might be more robust to handle them here as well.
|
|
|
p.DebugDetailf("Read error: %v\n", err)
|
|
p.DebugDetailf("Read error: %v\n", err)
|
|
|
- reason = DiscNetworkError
|
|
|
|
|
|
|
+ p.rw.Close()
|
|
|
|
|
+ return DiscNetworkError
|
|
|
|
|
+
|
|
|
case err := <-p.protoErr:
|
|
case err := <-p.protoErr:
|
|
|
reason = discReasonForError(err)
|
|
reason = discReasonForError(err)
|
|
|
case reason = <-p.disc:
|
|
case reason = <-p.disc:
|
|
|
}
|
|
}
|
|
|
- if reason != DiscNetworkError {
|
|
|
|
|
- p.politeDisconnect(reason)
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ p.politeDisconnect(reason)
|
|
|
|
|
+
|
|
|
|
|
+ // Wait for readLoop. It will end because conn is now closed.
|
|
|
|
|
+ <-readErr
|
|
|
p.Debugf("Disconnected: %v\n", reason)
|
|
p.Debugf("Disconnected: %v\n", reason)
|
|
|
return reason
|
|
return reason
|
|
|
}
|
|
}
|
|
@@ -198,9 +202,9 @@ func (p *Peer) run() DiscReason {
|
|
|
func (p *Peer) politeDisconnect(reason DiscReason) {
|
|
func (p *Peer) politeDisconnect(reason DiscReason) {
|
|
|
done := make(chan struct{})
|
|
done := make(chan struct{})
|
|
|
go func() {
|
|
go func() {
|
|
|
- // send reason
|
|
|
|
|
EncodeMsg(p.rw, discMsg, uint(reason))
|
|
EncodeMsg(p.rw, discMsg, uint(reason))
|
|
|
- // discard any data that might arrive
|
|
|
|
|
|
|
+ // Wait for the other side to close the connection.
|
|
|
|
|
+ // Discard any data that they send until then.
|
|
|
io.Copy(ioutil.Discard, p.rw)
|
|
io.Copy(ioutil.Discard, p.rw)
|
|
|
close(done)
|
|
close(done)
|
|
|
}()
|
|
}()
|
|
@@ -208,10 +212,11 @@ func (p *Peer) politeDisconnect(reason DiscReason) {
|
|
|
case <-done:
|
|
case <-done:
|
|
|
case <-time.After(disconnectGracePeriod):
|
|
case <-time.After(disconnectGracePeriod):
|
|
|
}
|
|
}
|
|
|
|
|
+ p.rw.Close()
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (p *Peer) readLoop() error {
|
|
func (p *Peer) readLoop() error {
|
|
|
- if p.protocolHandshakeEnabled {
|
|
|
|
|
|
|
+ if !p.noHandshake {
|
|
|
if err := readProtocolHandshake(p, p.rw); err != nil {
|
|
if err := readProtocolHandshake(p, p.rw); err != nil {
|
|
|
return err
|
|
return err
|
|
|
}
|
|
}
|
|
@@ -264,7 +269,7 @@ func readProtocolHandshake(p *Peer, rw MsgReadWriter) error {
|
|
|
return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code)
|
|
return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code)
|
|
|
}
|
|
}
|
|
|
if msg.Size > baseProtocolMaxMsgSize {
|
|
if msg.Size > baseProtocolMaxMsgSize {
|
|
|
- return newPeerError(errMisc, "message too big")
|
|
|
|
|
|
|
+ return newPeerError(errInvalidMsg, "message too big")
|
|
|
}
|
|
}
|
|
|
var hs handshake
|
|
var hs handshake
|
|
|
if err := msg.Decode(&hs); err != nil {
|
|
if err := msg.Decode(&hs); err != nil {
|
|
@@ -326,7 +331,7 @@ func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
|
|
|
err := impl.Run(p, rw)
|
|
err := impl.Run(p, rw)
|
|
|
if err == nil {
|
|
if err == nil {
|
|
|
p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version)
|
|
p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version)
|
|
|
- err = newPeerError(errMisc, "protocol returned")
|
|
|
|
|
|
|
+ err = errors.New("protocol returned")
|
|
|
} else {
|
|
} else {
|
|
|
p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err)
|
|
p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err)
|
|
|
}
|
|
}
|