|
|
@@ -68,50 +68,61 @@ type protoHandshake struct {
|
|
|
// setupConn starts a protocol session on the given connection.
|
|
|
// It runs the encryption handshake and the protocol handshake.
|
|
|
// If dial is non-nil, the connection the local node is the initiator.
|
|
|
-func setupConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
|
|
|
+// If atcap is true, the connection will be disconnected with DiscTooManyPeers
|
|
|
+// after the key exchange.
|
|
|
+func setupConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, atcap bool) (*conn, error) {
|
|
|
if dial == nil {
|
|
|
- return setupInboundConn(fd, prv, our)
|
|
|
+ return setupInboundConn(fd, prv, our, atcap)
|
|
|
} else {
|
|
|
- return setupOutboundConn(fd, prv, our, dial)
|
|
|
+ return setupOutboundConn(fd, prv, our, dial, atcap)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake) (*conn, error) {
|
|
|
+func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, atcap bool) (*conn, error) {
|
|
|
secrets, err := receiverEncHandshake(fd, prv, nil)
|
|
|
if err != nil {
|
|
|
return nil, fmt.Errorf("encryption handshake failed: %v", err)
|
|
|
}
|
|
|
-
|
|
|
- // Run the protocol handshake using authenticated messages.
|
|
|
rw := newRlpxFrameRW(fd, secrets)
|
|
|
- rhs, err := readProtocolHandshake(rw, our)
|
|
|
+ if atcap {
|
|
|
+ SendItems(rw, discMsg, DiscTooManyPeers)
|
|
|
+ return nil, errors.New("we have too many peers")
|
|
|
+ }
|
|
|
+ // Run the protocol handshake using authenticated messages.
|
|
|
+ rhs, err := readProtocolHandshake(rw, secrets.RemoteID, our)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
- if rhs.ID != secrets.RemoteID {
|
|
|
- return nil, errors.New("node ID in protocol handshake does not match encryption handshake")
|
|
|
- }
|
|
|
- // TODO: validate that handshake node ID matches
|
|
|
if err := Send(rw, handshakeMsg, our); err != nil {
|
|
|
- return nil, fmt.Errorf("protocol write error: %v", err)
|
|
|
+ return nil, fmt.Errorf("protocol handshake write error: %v", err)
|
|
|
}
|
|
|
return &conn{rw, rhs}, nil
|
|
|
}
|
|
|
|
|
|
-func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
|
|
|
+func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, atcap bool) (*conn, error) {
|
|
|
secrets, err := initiatorEncHandshake(fd, prv, dial.ID, nil)
|
|
|
if err != nil {
|
|
|
return nil, fmt.Errorf("encryption handshake failed: %v", err)
|
|
|
}
|
|
|
-
|
|
|
- // Run the protocol handshake using authenticated messages.
|
|
|
rw := newRlpxFrameRW(fd, secrets)
|
|
|
- if err := Send(rw, handshakeMsg, our); err != nil {
|
|
|
- return nil, fmt.Errorf("protocol write error: %v", err)
|
|
|
+ if atcap {
|
|
|
+ SendItems(rw, discMsg, DiscTooManyPeers)
|
|
|
+ return nil, errors.New("we have too many peers")
|
|
|
}
|
|
|
- rhs, err := readProtocolHandshake(rw, our)
|
|
|
+ // Run the protocol handshake using authenticated messages.
|
|
|
+ //
|
|
|
+ // Note that even though writing the handshake is first, we prefer
|
|
|
+ // returning the handshake read error. If the remote side
|
|
|
+ // disconnects us early with a valid reason, we should return it
|
|
|
+ // as the error so it can be tracked elsewhere.
|
|
|
+ werr := make(chan error)
|
|
|
+ go func() { werr <- Send(rw, handshakeMsg, our) }()
|
|
|
+ rhs, err := readProtocolHandshake(rw, secrets.RemoteID, our)
|
|
|
if err != nil {
|
|
|
- return nil, fmt.Errorf("protocol handshake read error: %v", err)
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ if err := <-werr; err != nil {
|
|
|
+ return nil, fmt.Errorf("protocol handshake write error: %v", err)
|
|
|
}
|
|
|
if rhs.ID != dial.ID {
|
|
|
return nil, errors.New("dialed node id mismatch")
|
|
|
@@ -398,18 +409,17 @@ func xor(one, other []byte) (xor []byte) {
|
|
|
return xor
|
|
|
}
|
|
|
|
|
|
-func readProtocolHandshake(r MsgReader, our *protoHandshake) (*protoHandshake, error) {
|
|
|
- // read and handle remote handshake
|
|
|
- msg, err := r.ReadMsg()
|
|
|
+func readProtocolHandshake(rw MsgReadWriter, wantID discover.NodeID, our *protoHandshake) (*protoHandshake, error) {
|
|
|
+ msg, err := rw.ReadMsg()
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
if msg.Code == discMsg {
|
|
|
// disconnect before protocol handshake is valid according to the
|
|
|
// spec and we send it ourself if Server.addPeer fails.
|
|
|
- var reason DiscReason
|
|
|
+ var reason [1]DiscReason
|
|
|
rlp.Decode(msg.Payload, &reason)
|
|
|
- return nil, reason
|
|
|
+ return nil, reason[0]
|
|
|
}
|
|
|
if msg.Code != handshakeMsg {
|
|
|
return nil, fmt.Errorf("expected handshake, got %x", msg.Code)
|
|
|
@@ -423,10 +433,16 @@ func readProtocolHandshake(r MsgReader, our *protoHandshake) (*protoHandshake, e
|
|
|
}
|
|
|
// validate handshake info
|
|
|
if hs.Version != our.Version {
|
|
|
- return nil, newPeerError(errP2PVersionMismatch, "required version %d, received %d\n", baseProtocolVersion, hs.Version)
|
|
|
+ SendItems(rw, discMsg, DiscIncompatibleVersion)
|
|
|
+ return nil, fmt.Errorf("required version %d, received %d\n", baseProtocolVersion, hs.Version)
|
|
|
}
|
|
|
if (hs.ID == discover.NodeID{}) {
|
|
|
- return nil, newPeerError(errPubkeyInvalid, "missing")
|
|
|
+ SendItems(rw, discMsg, DiscInvalidIdentity)
|
|
|
+ return nil, errors.New("invalid public key in handshake")
|
|
|
+ }
|
|
|
+ if hs.ID != wantID {
|
|
|
+ SendItems(rw, discMsg, DiscUnexpectedIdentity)
|
|
|
+ return nil, errors.New("handshake node ID does not match encryption handshake")
|
|
|
}
|
|
|
return &hs, nil
|
|
|
}
|