prox_test.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473
  1. package pss
  2. import (
  3. "context"
  4. "crypto/ecdsa"
  5. "encoding/binary"
  6. "errors"
  7. "fmt"
  8. "strconv"
  9. "strings"
  10. "sync"
  11. "testing"
  12. "time"
  13. "github.com/ethereum/go-ethereum/common"
  14. "github.com/ethereum/go-ethereum/common/hexutil"
  15. "github.com/ethereum/go-ethereum/log"
  16. "github.com/ethereum/go-ethereum/node"
  17. "github.com/ethereum/go-ethereum/p2p"
  18. "github.com/ethereum/go-ethereum/p2p/enode"
  19. "github.com/ethereum/go-ethereum/p2p/simulations/adapters"
  20. "github.com/ethereum/go-ethereum/rpc"
  21. "github.com/ethereum/go-ethereum/swarm/network"
  22. "github.com/ethereum/go-ethereum/swarm/network/simulation"
  23. "github.com/ethereum/go-ethereum/swarm/pot"
  24. "github.com/ethereum/go-ethereum/swarm/state"
  25. )
// handlerContextFunc builds the pss handler for one simulated node.
// It is handed the shared test data and the node's adapter config; the latter is
// needed to make the enode id of the receiving node available to the handler for triggers.
type handlerContextFunc func(*testData, *adapters.NodeConfig) *handler
// handlerNotification is the struct used to notify the simulation driver that a
// node's message handler has received a message.
// TODO To make code cleaner:
// - consider a separate pss unwrap to message event in sim framework (this will make eventual message propagation analysis with pss easier/possible in the future)
// - consider also test api calls to inspect handling results of messages
type handlerNotification struct {
	id     enode.ID // enode id of the node whose handler received the message
	serial uint64   // serial number decoded from the message body
}
// testData carries the complete state of one prox network test run: the
// simulation, the precomputed expectations per message and per node, and the
// channels connecting the pss message handlers with the simulation driver.
type testData struct {
	mu               sync.Mutex                     // guards messageCount and handlerDone
	sim              *simulation.Simulation         // the simulation the test runs in
	handlerDone      bool                           // set to true on termination of the simulation run
	requiredMessages int                            // total number of (message, required recipient) pairs
	allowedMessages  int                            // total number of (message, allowed recipient) pairs
	messageCount     int                            // number of handler invocations counted so far
	kademlias        map[enode.ID]*network.Kademlia // kademlia of every node, filled in by the service constructors
	nodeAddrs        map[enode.ID][]byte            // make predictable overlay addresses from the generated random enode ids
	recipients       map[int][]enode.ID             // for logging output only
	allowed          map[int][]enode.ID             // allowed recipients
	expectedMsgs     map[enode.ID][]uint64          // message serials we expect respective nodes to receive
	allowedMsgs      map[enode.ID][]uint64          // message serials respective nodes are allowed to receive
	senders          map[int]enode.ID               // originating nodes of the messages (intention is to choose as far as possible from the receiving neighborhood)
	handlerC         chan handlerNotification       // passes message from pss message handler to simulation driver
	doneC            chan struct{}                  // terminates the handler channel listener
	errC             chan error                     // error to pass to main sim thread
	msgC             chan handlerNotification       // message receipt notification to main sim thread
	msgs             [][]byte                       // recipient addresses of messages
}
var (
	// pof computes the proximity order between a message address and a node
	// address when the test data is initialized.
	pof = pot.DefaultPof(256) // generate messages and index them
	// topic is the pss topic the test handlers are registered under and the
	// test messages are sent with.
	topic = BytesToTopic([]byte{0xf3, 0x9e, 0x06, 0x82})
)
  60. func (d *testData) getMsgCount() int {
  61. d.mu.Lock()
  62. defer d.mu.Unlock()
  63. return d.messageCount
  64. }
  65. func (d *testData) incrementMsgCount() int {
  66. d.mu.Lock()
  67. defer d.mu.Unlock()
  68. d.messageCount++
  69. return d.messageCount
  70. }
  71. func (d *testData) isDone() bool {
  72. d.mu.Lock()
  73. defer d.mu.Unlock()
  74. return d.handlerDone
  75. }
  76. func (d *testData) setDone() {
  77. d.mu.Lock()
  78. defer d.mu.Unlock()
  79. d.handlerDone = true
  80. }
  81. func getCmdParams(t *testing.T) (int, int, time.Duration) {
  82. args := strings.Split(t.Name(), "/")
  83. msgCount, err := strconv.ParseInt(args[2], 10, 16)
  84. if err != nil {
  85. t.Fatal(err)
  86. }
  87. nodeCount, err := strconv.ParseInt(args[1], 10, 16)
  88. if err != nil {
  89. t.Fatal(err)
  90. }
  91. timeoutStr := fmt.Sprintf("%ss", args[3])
  92. timeoutDur, err := time.ParseDuration(timeoutStr)
  93. if err != nil {
  94. t.Fatal(err)
  95. }
  96. return int(msgCount), int(nodeCount), timeoutDur
  97. }
  98. func newTestData() *testData {
  99. return &testData{
  100. kademlias: make(map[enode.ID]*network.Kademlia),
  101. nodeAddrs: make(map[enode.ID][]byte),
  102. recipients: make(map[int][]enode.ID),
  103. allowed: make(map[int][]enode.ID),
  104. expectedMsgs: make(map[enode.ID][]uint64),
  105. allowedMsgs: make(map[enode.ID][]uint64),
  106. senders: make(map[int]enode.ID),
  107. handlerC: make(chan handlerNotification),
  108. doneC: make(chan struct{}),
  109. errC: make(chan error),
  110. msgC: make(chan handlerNotification),
  111. }
  112. }
  113. func (d *testData) getKademlia(nodeId *enode.ID) (*network.Kademlia, error) {
  114. kadif, ok := d.sim.NodeItem(*nodeId, simulation.BucketKeyKademlia)
  115. if !ok {
  116. return nil, fmt.Errorf("no kademlia entry for %v", nodeId)
  117. }
  118. kad, ok := kadif.(*network.Kademlia)
  119. if !ok {
  120. return nil, fmt.Errorf("invalid kademlia entry for %v", nodeId)
  121. }
  122. return kad, nil
  123. }
  124. func (d *testData) init(msgCount int) error {
  125. log.Debug("TestProxNetwork start")
  126. for _, nodeId := range d.sim.NodeIDs() {
  127. kad, err := d.getKademlia(&nodeId)
  128. if err != nil {
  129. return err
  130. }
  131. d.nodeAddrs[nodeId] = kad.BaseAddr()
  132. }
  133. for i := 0; i < int(msgCount); i++ {
  134. msgAddr := pot.RandomAddress() // we choose message addresses randomly
  135. d.msgs = append(d.msgs, msgAddr.Bytes())
  136. smallestPo := 256
  137. var targets []enode.ID
  138. var closestPO int
  139. // loop through all nodes and find the required and allowed recipients of each message
  140. // (for more information, please see the comment to the main test function)
  141. for _, nod := range d.sim.Net.GetNodes() {
  142. po, _ := pof(d.msgs[i], d.nodeAddrs[nod.ID()], 0)
  143. depth := d.kademlias[nod.ID()].NeighbourhoodDepth()
  144. // only nodes with closest IDs (wrt the msg address) will be required recipients
  145. if po > closestPO {
  146. closestPO = po
  147. targets = nil
  148. targets = append(targets, nod.ID())
  149. } else if po == closestPO {
  150. targets = append(targets, nod.ID())
  151. }
  152. if po >= depth {
  153. d.allowedMessages++
  154. d.allowed[i] = append(d.allowed[i], nod.ID())
  155. d.allowedMsgs[nod.ID()] = append(d.allowedMsgs[nod.ID()], uint64(i))
  156. }
  157. // a node with the smallest PO (wrt msg) will be the sender,
  158. // in order to increase the distance the msg must travel
  159. if po < smallestPo {
  160. smallestPo = po
  161. d.senders[i] = nod.ID()
  162. }
  163. }
  164. d.requiredMessages += len(targets)
  165. for _, id := range targets {
  166. d.recipients[i] = append(d.recipients[i], id)
  167. d.expectedMsgs[id] = append(d.expectedMsgs[id], uint64(i))
  168. }
  169. log.Debug("nn for msg", "targets", len(d.recipients[i]), "msgidx", i, "msg", common.Bytes2Hex(msgAddr[:8]), "sender", d.senders[i], "senderpo", smallestPo)
  170. }
  171. log.Debug("msgs to receive", "count", d.requiredMessages)
  172. return nil
  173. }
  174. // Here we test specific functionality of the pss, setting the prox property of
  175. // the handler. The tests generate a number of messages with random addresses.
  176. // Then, for each message it calculates which nodes have the msg address
  177. // within its nearest neighborhood depth, and stores those nodes as possible
  178. // recipients. Those nodes that are the closest to the message address (nodes
  179. // belonging to the deepest PO wrt the msg address) are stored as required
  180. // recipients. The difference between allowed and required recipients results
  181. // from the fact that the nearest neighbours are not necessarily reciprocal.
  182. // Upon sending the messages, the test verifies that the respective message is
  183. // passed to the message handlers of these required recipients. The test fails
  184. // if a message is handled by recipient which is not listed among the allowed
  185. // recipients of this particular message. It also fails after timeout, if not
  186. // all the required recipients have received their respective messages.
  187. //
  188. // For example, if proximity order of certain msg address is 4, and node X
  189. // has PO=5 wrt the message address, and nodes Y and Z have PO=6, then:
  190. // nodes Y and Z will be considered required recipients of the msg,
  191. // whereas nodes X, Y and Z will be allowed recipients.
  192. func TestProxNetwork(t *testing.T) {
  193. t.Run("16/16/15", testProxNetwork)
  194. }
  195. // params in run name: nodes/msgs
  196. func TestProxNetworkLong(t *testing.T) {
  197. if !*longrunning {
  198. t.Skip("run with --longrunning flag to run extensive network tests")
  199. }
  200. t.Run("8/100/30", testProxNetwork)
  201. t.Run("16/100/30", testProxNetwork)
  202. t.Run("32/100/60", testProxNetwork)
  203. t.Run("64/100/60", testProxNetwork)
  204. t.Run("128/100/120", testProxNetwork)
  205. }
  206. func testProxNetwork(t *testing.T) {
  207. tstdata := newTestData()
  208. msgCount, nodeCount, timeout := getCmdParams(t)
  209. handlerContextFuncs := make(map[Topic]handlerContextFunc)
  210. handlerContextFuncs[topic] = nodeMsgHandler
  211. services := newProxServices(tstdata, true, handlerContextFuncs, tstdata.kademlias)
  212. tstdata.sim = simulation.New(services)
  213. defer tstdata.sim.Close()
  214. ctx, cancel := context.WithTimeout(context.Background(), timeout)
  215. defer cancel()
  216. filename := fmt.Sprintf("testdata/snapshot_%d.json", nodeCount)
  217. err := tstdata.sim.UploadSnapshot(ctx, filename)
  218. if err != nil {
  219. t.Fatal(err)
  220. }
  221. err = tstdata.init(msgCount) // initialize the test data
  222. if err != nil {
  223. t.Fatal(err)
  224. }
  225. wrapper := func(c context.Context, _ *simulation.Simulation) error {
  226. return testRoutine(tstdata, c)
  227. }
  228. result := tstdata.sim.Run(ctx, wrapper) // call the main test function
  229. if result.Error != nil {
  230. // context deadline exceeded
  231. // however, it might just mean that not all possible messages are received
  232. // now we must check if all required messages are received
  233. cnt := tstdata.getMsgCount()
  234. log.Debug("TestProxNetwork finished", "rcv", cnt)
  235. if cnt < tstdata.requiredMessages {
  236. t.Fatal(result.Error)
  237. }
  238. }
  239. t.Logf("completed %d", result.Duration)
  240. }
  241. func (tstdata *testData) sendAllMsgs() {
  242. for i, msg := range tstdata.msgs {
  243. log.Debug("sending msg", "idx", i, "from", tstdata.senders[i])
  244. nodeClient, err := tstdata.sim.Net.GetNode(tstdata.senders[i]).Client()
  245. if err != nil {
  246. tstdata.errC <- err
  247. }
  248. var uvarByte [8]byte
  249. binary.PutUvarint(uvarByte[:], uint64(i))
  250. nodeClient.Call(nil, "pss_sendRaw", hexutil.Encode(msg), hexutil.Encode(topic[:]), hexutil.Encode(uvarByte[:]))
  251. }
  252. log.Debug("all messages sent")
  253. }
  254. // testRoutine is the main test function, called by Simulation.Run()
  255. func testRoutine(tstdata *testData, ctx context.Context) error {
  256. go handlerChannelListener(tstdata, ctx)
  257. go tstdata.sendAllMsgs()
  258. received := 0
  259. // collect incoming messages and terminate with corresponding status when message handler listener ends
  260. for {
  261. select {
  262. case err := <-tstdata.errC:
  263. return err
  264. case hn := <-tstdata.msgC:
  265. received++
  266. log.Debug("msg received", "msgs_received", received, "total_expected", tstdata.requiredMessages, "id", hn.id, "serial", hn.serial)
  267. if received == tstdata.allowedMessages {
  268. close(tstdata.doneC)
  269. return nil
  270. }
  271. }
  272. }
  273. return nil
  274. }
  275. func handlerChannelListener(tstdata *testData, ctx context.Context) {
  276. for {
  277. select {
  278. case <-tstdata.doneC: // graceful exit
  279. tstdata.setDone()
  280. tstdata.errC <- nil
  281. return
  282. case <-ctx.Done(): // timeout or cancel
  283. tstdata.setDone()
  284. tstdata.errC <- ctx.Err()
  285. return
  286. // incoming message from pss message handler
  287. case handlerNotification := <-tstdata.handlerC:
  288. // check if recipient has already received all its messages and notify to fail the test if so
  289. aMsgs := tstdata.allowedMsgs[handlerNotification.id]
  290. if len(aMsgs) == 0 {
  291. tstdata.setDone()
  292. tstdata.errC <- fmt.Errorf("too many messages received by recipient %x", handlerNotification.id)
  293. return
  294. }
  295. // check if message serial is in expected messages for this recipient and notify to fail the test if not
  296. idx := -1
  297. for i, msg := range aMsgs {
  298. if handlerNotification.serial == msg {
  299. idx = i
  300. break
  301. }
  302. }
  303. if idx == -1 {
  304. tstdata.setDone()
  305. tstdata.errC <- fmt.Errorf("message %d received by wrong recipient %v", handlerNotification.serial, handlerNotification.id)
  306. return
  307. }
  308. // message is ok, so remove that message serial from the recipient expectation array and notify the main sim thread
  309. aMsgs[idx] = aMsgs[len(aMsgs)-1]
  310. aMsgs = aMsgs[:len(aMsgs)-1]
  311. tstdata.msgC <- handlerNotification
  312. }
  313. }
  314. }
  315. func nodeMsgHandler(tstdata *testData, config *adapters.NodeConfig) *handler {
  316. return &handler{
  317. f: func(msg []byte, p *p2p.Peer, asymmetric bool, keyid string) error {
  318. cnt := tstdata.incrementMsgCount()
  319. log.Debug("nodeMsgHandler rcv", "cnt", cnt)
  320. // using simple serial in message body, makes it easy to keep track of who's getting what
  321. serial, c := binary.Uvarint(msg)
  322. if c <= 0 {
  323. log.Crit(fmt.Sprintf("corrupt message received by %x (uvarint parse returned %d)", config.ID, c))
  324. }
  325. if tstdata.isDone() {
  326. return errors.New("handlers aborted") // terminate if simulation is over
  327. }
  328. // pass message context to the listener in the simulation
  329. tstdata.handlerC <- handlerNotification{
  330. id: config.ID,
  331. serial: serial,
  332. }
  333. return nil
  334. },
  335. caps: &handlerCaps{
  336. raw: true, // we use raw messages for simplicity
  337. prox: true,
  338. },
  339. }
  340. }
  341. // an adaptation of the same services setup as in pss_test.go
  342. // replaces pss_test.go when those tests are rewritten to the new swarm/network/simulation package
  343. func newProxServices(tstdata *testData, allowRaw bool, handlerContextFuncs map[Topic]handlerContextFunc, kademlias map[enode.ID]*network.Kademlia) map[string]simulation.ServiceFunc {
  344. stateStore := state.NewInmemoryStore()
  345. kademlia := func(id enode.ID, bzzkey []byte) *network.Kademlia {
  346. if k, ok := kademlias[id]; ok {
  347. return k
  348. }
  349. params := network.NewKadParams()
  350. params.MaxBinSize = 3
  351. params.MinBinSize = 1
  352. params.MaxRetries = 1000
  353. params.RetryExponent = 2
  354. params.RetryInterval = 1000000
  355. kademlias[id] = network.NewKademlia(bzzkey, params)
  356. return kademlias[id]
  357. }
  358. return map[string]simulation.ServiceFunc{
  359. "bzz": func(ctx *adapters.ServiceContext, b *sync.Map) (node.Service, func(), error) {
  360. var err error
  361. var bzzPrivateKey *ecdsa.PrivateKey
  362. // normally translation of enode id to swarm address is concealed by the network package
  363. // however, we need to keep track of it in the test driver as well.
  364. // if the translation in the network package changes, that can cause these tests to unpredictably fail
  365. // therefore we keep a local copy of the translation here
  366. addr := network.NewAddr(ctx.Config.Node())
  367. bzzPrivateKey, err = simulation.BzzPrivateKeyFromConfig(ctx.Config)
  368. if err != nil {
  369. return nil, nil, err
  370. }
  371. addr.OAddr = network.PrivateKeyToBzzKey(bzzPrivateKey)
  372. b.Store(simulation.BucketKeyBzzPrivateKey, bzzPrivateKey)
  373. hp := network.NewHiveParams()
  374. hp.Discovery = false
  375. config := &network.BzzConfig{
  376. OverlayAddr: addr.Over(),
  377. UnderlayAddr: addr.Under(),
  378. HiveParams: hp,
  379. }
  380. return network.NewBzz(config, kademlia(ctx.Config.ID, addr.OAddr), stateStore, nil, nil), nil, nil
  381. },
  382. "pss": func(ctx *adapters.ServiceContext, b *sync.Map) (node.Service, func(), error) {
  383. // execadapter does not exec init()
  384. initTest()
  385. // create keys in whisper and set up the pss object
  386. ctxlocal, cancel := context.WithTimeout(context.Background(), time.Second*3)
  387. defer cancel()
  388. keys, err := wapi.NewKeyPair(ctxlocal)
  389. privkey, err := w.GetPrivateKey(keys)
  390. pssp := NewPssParams().WithPrivateKey(privkey)
  391. pssp.AllowRaw = allowRaw
  392. bzzPrivateKey, err := simulation.BzzPrivateKeyFromConfig(ctx.Config)
  393. if err != nil {
  394. return nil, nil, err
  395. }
  396. bzzKey := network.PrivateKeyToBzzKey(bzzPrivateKey)
  397. pskad := kademlia(ctx.Config.ID, bzzKey)
  398. ps, err := NewPss(pskad, pssp)
  399. if err != nil {
  400. return nil, nil, err
  401. }
  402. // register the handlers we've been passed
  403. var deregisters []func()
  404. for tpc, hndlrFunc := range handlerContextFuncs {
  405. deregisters = append(deregisters, ps.Register(&tpc, hndlrFunc(tstdata, ctx.Config)))
  406. }
  407. // if handshake mode is set, add the controller
  408. // TODO: This should be hooked to the handshake test file
  409. if useHandshake {
  410. SetHandshakeController(ps, NewHandshakeParams())
  411. }
  412. // we expose some api calls for cheating
  413. ps.addAPI(rpc.API{
  414. Namespace: "psstest",
  415. Version: "0.3",
  416. Service: NewAPITest(ps),
  417. Public: false,
  418. })
  419. b.Store(simulation.BucketKeyKademlia, pskad)
  420. // return Pss and cleanups
  421. return ps, func() {
  422. // run the handler deregister functions in reverse order
  423. for i := len(deregisters); i > 0; i-- {
  424. deregisters[i-1]()
  425. }
  426. }, nil
  427. },
  428. }
  429. }